1. 程式人生 > 程式設計 >淺談matplotlib 繪製梯度下降求解過程

淺談matplotlib 繪製梯度下降求解過程

機器學習過程中經常需要視覺化,有助於加強對模型和引數的理解。

下面對梯度下降過程進行動圖演示,可以修改不同的學習率,觀看效果。

import numpy as np
import matplotlib.pyplot as plt
from IPython import display

X = 2*np.random.rand(100,1)
y = 4+3*X+np.random.randn(100,1) # randn正態分佈
X_b = np.c_[np.ones((100,1)),X] # c_行數相等,左右拼接

eta = 0.1 # 學習率
n_iter = 1000 # 迭代次數
m = 100 # 樣本點個數
theta = np.random.randn(2,1) # 引數初始值

plt.figure(figsize=(8,6))
mngr = plt.get_current_fig_manager() # 獲取當前figure manager
mngr.window.wm_geometry("+520+520") # 調整視窗在螢幕上彈出的位置,注意寫在開啟互動模式之前
# 上面固定視窗,方便screentogif定位錄製,只會這種弱弱的方法
plt.ion()# 開啟互動模式
plt.rcParams["font.sans-serif"] = "SimHei"# 消除中文亂碼

for iter in range(n_iter):
  plt.cla() # 清除原影象

  gradients = 2/m*X_b.T.dot(X_b.dot(theta)-y)
  theta = theta - eta*gradients
  X_new = np.array([[0],[2]])
  X_new_b = np.c_[np.ones((2,X_new]
  y_pred = X_new_b.dot(theta)

  plt.axis([0,2,15])
  plt.plot(X,y,"b.")
  plt.plot(X_new,y_pred,"r-")
  plt.title("學習率:{:.2f}".format(eta))
  plt.pause(0.3) # 暫停一會
  display.clear_output(wait=True)# 重新整理影象


plt.ioff()# 關閉互動模式  
plt.show()

淺談matplotlib 繪製梯度下降求解過程

學習率:0.1,較合適

淺談matplotlib 繪製梯度下降求解過程

學習率:0.02,收斂變慢了

淺談matplotlib 繪製梯度下降求解過程

學習率:0.45,在最佳引數附近震盪

淺談matplotlib 繪製梯度下降求解過程

學習率:0.5,不收斂

到此這篇關於淺談matplotlib 繪製梯度下降求解過程的文章就介紹到這了,更多相關matplotlib 梯度下降內容請搜尋我們以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援我們!