李航-機器學習-感知機(perceptron)-原始形式
阿新 • • 發佈:2018-12-15
機器學習-感知機(perceptron)
感知機模型
感知機是一種線性的、二類分類模型,可以將空間劃分為正類和負類,是一種判別模型,輸入為具體的例項,輸出為例項的類別(+1,-1)。有原始形式和對偶形式兩種。感知機是神經網路和支援向量機的基礎。
感知機預測是利用學習到的模型對輸入例項進行類別的劃分。
由輸入空間到輸出空間有如下函式:
f(x) = sign(w*x+b)
感知機學習策略
假設資料集是線性可分(可以使用一個超平面完全將正負資料劃分開來),我們要做的就是找到一個線性函式將訓練集中的正負例項點完全劃分開來。 線性函式中的未知引數是w、b,所以我們的任務就是找到滿足上述條件的引數值,若要找到相應的引數值,就需要通過損失函式來尋找w、b。 我們可以選用誤分類點的總數作為我們的損失函式,但是這樣的損失函式中的引數w、b不是連續變化的,在後面不利於求導優化。所以我們選擇的損失函式是誤分類點到初始超平面的距離總和。公式如下:
||w||為w的L2範數:w的平方開根號 所有誤分類點到超平面的距離為: 在李航的書中不考慮1/||w||,由此得到損失函式,我還沒弄清怎麼回事。 綜上,感知機sign(w*x+b)的損失函式為: M是誤分類點的總數,這個函式就是感知機的經驗風險函式。
感知機演算法實現
輸入:訓練資料集T={(x1,y1),(x2,y2),(x3,y3)…},xi屬於實數集,yi={1,-1},學習率n 輸出:w,b;感知機函式f(x) = sign(wx+b) (1)、選取初值w0,b0 (2)、在訓練集中選取資料(xi,yi) (3)、將訓練集中選取的資料帶入yi(wxi+b),如果該式小於零,那麼就對w,b進行更新:w<-w+nxi
程式碼實現
import numpy as np //用於各種數值運算,例如矩陣,矩陣內積等 import random as random //用於打亂資料順序 import matplotlib as mpl import matplotlib.pyplib as plt //定義資料集 datas =[[(1,2),-1],[(2,1),-1],[(2,2),-1],[(1,4),1],[(3,3),1], [(5,4),1],[(3, 3), 1], [(4, 3), 1], [(1, 1), -1],[(2, 3), -1], [(4, 2), 1]] //將資料集打亂 random.shuffle(datas) //輸入影象標題 fig = plt.figure('Input data') //從datas中依次取出元素 xArr = np.array([x[0] for x in datas]) //取出一個11*2的陣列 yArr = np.array([x[1] for x in datas])//取出每一個標籤的分類值 //分別初始化資料集中正類和負類的空陣列 xPlotx,xPloty,xPlotx_,xPloty_ = [],[],[],[] // 利用迴圈分別儲存正類和負類資料 for i in range(len(datas)): y = yArr[i] //讀取標籤值 if y>0: xPlotx.append(xArr[i][0])//x為陣列中第一列的值 xPloty.append(xArr[i][1])//y為陣列中第二列的值 else: xPlotx_.append(xArr[i][0]) xPloty_.append(xArr[i][2]) //影象標題 plt.title('percetron 輸入資料‘) //影象顯示網格線 plt.grid(True) //繪製資料集中的正負類點 pPlot1,pPlot2 = plt.plot(xPlotx,xPloty,'b+',xPlotx_,xPloty_,'r+') //給影象貼標籤,指定標籤位置 plt.legend(handles=[pPlot1,pPlot2],labels=['pos','neg'],loc='upper center') plt.show()
執行程式可得
//對w、b進行初始化
//給w賦初值[1,1]
w = np.array([1,1])
b = 3
//學習率
n = 1
//依次檢驗訓練集中的資料
while True:
num = 0
for i in range(len(datas)):
num+=1 //每取出一個數據,記錄加一
x = xArr[i]
y = yArr[i]
z = y*(np.dot[w,x]+b) //帶入公式y(w*x+b),np.dot用於計算矩陣的內積
if z<=0:
//對w、b進行更新
w = w + n*x*y
b = b+n*b
break
//當取出的資料個數大於或者等於陣列長度
if num>=len(datas):
break
fig = plt.figure('Output figure')
x0 = np.linespce(0,5,100)//生成從0到5的100個等間隔數
w0 =w[0]
w1 = w[1]
x1 = -(w0/w1)*x0-b/w1
plt.title("Perception 輸出平面")
plt.xlabel('x0')
plt.ylabel('x1')
plt.annotate('Output Hyperplane',xy=(0.5,4.5),xytext=(1.7,3.5))//第一個引數為註釋文字內容,xy為被註釋的座標點,xytext為註釋文字座標位置
plt.plot(x0,x1,'k', lw=1)
pPlot3, pPlot4= plt.plot(xPlotx,xPloty,'b+',xPlotx_,xPloty_,'rx')//繪點
plt.legend(handles = [pPlot3,pPlot4],labels=['Positive Sample','Negative Sample'],loc='upper right')
plt.show()