Python機器學習演算法實踐——梯度上升演算法
一:理論部分
給定一個樣本集,每個樣本點有兩個維度值(X1,X2)和一個類別值,類別只有兩類,我們以0和1代表。資料如下所示:
樣本 |
X1 |
X2 |
類別 |
1 |
-1.4 |
4.7 |
1 |
2 |
-2.5 |
6.9 |
0 |
... |
... |
... |
... |
機器學習的任務是找一個函式,給定一個數據兩個維度的值,該函式能夠預測其屬於類別1的概率。
假設這個函式的模樣如下:
h(x) =sigmoid(z)
z = w0 +w1*X1+w2*X2
問題轉化成了,根據現有的樣本資料,找出最佳的引數w
為進一步簡化問題,我們假設樣本集只有上表中的兩個。
假設現在手上已經有一個wt,也就是有了一個函式h(x),那麼我們可以把樣本1和樣本2的資料代進去,看看這個函式的預測效果如何,假設樣本1的預測值是p1 = 0.8,樣本2的預測值是:p2 = 0.4。
函式在樣本1上犯的錯誤為e1=(1-0.8)= 0.2,在樣本2上犯的錯誤為e2=(0-0.4)= -0.4,總的錯誤E為-0.20(e1+e2)。如下表所示:
樣本 |
X1 |
X2 |
類別 |
預測值 |
error |
1 |
-1.4 |
4.7 |
1 |
0.8 |
0.2 |
2 |
-2.5 |
6.9 |
0 |
0.4 |
-0.4 |
... |
... |
... |
... |
... |
... |
現在我們要改進wt的值,使得函式在樣本1和2上犯的總錯誤E減小。
將wt的改進拆開來,就是分別改進它的三個分量的值,我們以w1為例。
對於樣本1:
X1*e1=-1.4*0.2= -0.28
-0.28告訴我們什麼呢?它告訴我們,樣本1的X1和e1是異號的,減小w1的值,能夠減小函式在樣本1上犯的錯誤。為什麼呢?
w1減小,則X1*w1增大(因為樣本1的X1是負的),進而 z = w0 +w1*X1+w2*X2增大,又由於sigmoid函式是單調遞增的,則h(x)會增大。當前的h(x)是0.8,增大的話就是在向1靠近,也就是減小了在樣本1上犯的錯。
對於樣本2:
X1*e2=-2.5*-0.4= 1
1告訴我們,樣本2的X1和e2是同號的,增大w1的值,能夠減小函式在樣本2上犯的錯誤。為什麼呢?
w1增大,則X1*w1減小,進而 z = w0 +w1*X1+w2*X2減小,又由於sigmoid函式是單調遞增的,則h(x)會減小。當前的h(x)是0.4,減小的話就是在向0靠近,也就是減小了在樣本2上犯的錯。
現在的問題就是這樣的,樣本1說,要減小w1的值,這樣函式對我的判斷就更準確了,樣本2說,要增大w1的值,這樣函式對我的判斷就更準確了。
顯然,樣本1和樣本2都只從自己的角度出發,對改進w1提出了各自不同意見,我們要綜合它們的意見,以決定是增大w1還是減小w1,如下:
-0.28+1 = 0.72
最後的結果0.72是正的,說明,增大w1對函式的總體表現更有利。就是說,增大w1後,雖然在樣本1上犯的錯誤會稍稍增大,但在樣本2上犯的錯誤會大大減小,一個是稍稍增大,一個是大大減小,為了函式總體表現,肯定是增大w1的值啦。
那麼具體增加多大呢?我們可以用一個專門的引數alpha來控制。
二 Python程式碼(核心部分)
from numpy import *
def gradientAscent(dataMat,labelMat):
dataMat=mat(dataMat)
m,n=shape(dataMat)
labelMat=mat(labelMat).T
#假設weight =1
weights=ones((n,1))
alpha=0.00001 #學習率
num=500000 #迴圈次數
for k in range(num):
#整個資料集全部運算:num*m
#計算 z 值
z=dataMat*weights
y=sigmoid(z)
error=labelMat-y
#更新weights
weights=weights+alpha*dataMat.T * error
return weights
三.畫出影象
def showPlot(weights):
import matplotlib.pyplot as plt
dataMat,labelMat=loadDataSet()
dataArr=array(dataMat)
n=shape(dataArr)[0]
#正樣本
xcord1=[]
ycord1=[]
xcord2=[]
ycord2=[]
#迴圈資料,存到正負樣本中
for i in range(n):
if int(labelMat[i])==1:
xcord1.append(dataArr[i,1])
ycord1.append(dataArr[i,2])
else:
xcord2.append(dataArr[i,1])
ycord2.append(dataArr[i,2])
fig=plt.figure()
ax=fig.add_subplot(111)
ax.scatter(xcord1,ycord1,c='red')
ax.scatter(xcord2,ycord2,c='green')
x=arange(-5,5,0.1)
y=arange(-5,5,0.1)
y=(-weights[0]-weights[1]*x)/weights[2]
y=y.T
ax.plot(x,y)
plt.show()
四.其餘程式碼
def loadDataSet():
dataMat=[]
labelMat=[]
fr=open('dataset/testSet.txt')
for line in fr.readlines():
array=line.strip().split()
#截距 特徵
dataMat.append([1.0,float(array[0]),float(array[1])])
labelMat.append(int(array[2]))
return dataMat,labelMat
#方式一的函式圖
dataMat,labelMat=loadDataSet()
weights1=gradientAscent(dataMat,labelMat)
showPlot(weights1)
五.資料集
-0.017612 14.053064 0
-1.395634 4.662541 1
-0.752157 6.538620 0
-1.322371 7.152853 0
0.423363 11.054677 0
0.406704 7.067335 1
0.667394 12.741452 0
-2.460150 6.866805 1
0.569411 9.548755 0
-0.026632 10.427743 0
0.850433 6.920334 1
1.347183 13.175500 0
1.176813 3.167020 1
-1.781871 9.097953 0
-0.566606 5.749003 1
0.931635 1.589505 1
-0.024205 6.151823 1
-0.036453 2.690988 1
-0.196949 0.444165 1
1.014459 5.754399 1
1.985298 3.230619 1
-1.693453 -0.557540 1
-0.576525 11.778922 0
-0.346811 -1.678730 1
-2.124484 2.672471 1
1.217916 9.597015 0
-0.733928 9.098687 0
-3.642001 -1.618087 1
0.315985 3.523953 1
1.416614 9.619232 0
-0.386323 3.989286 1
0.556921 8.294984 1
1.224863 11.587360 0
-1.347803 -2.406051 1
1.196604 4.951851 1
0.275221 9.543647 0
0.470575 9.332488 0
-1.889567 9.542662 0
-1.527893 12.150579 0
-1.185247 11.309318 0
-0.445678 3.297303 1
1.042222 6.105155 1
-0.618787 10.320986 0
1.152083 0.548467 1
0.828534 2.676045 1
-1.237728 10.549033 0
-0.683565 -2.166125 1
0.229456 5.921938 1
-0.959885 11.555336 0
0.492911 10.993324 0
0.184992 8.721488 0
-0.355715 10.325976 0
-0.397822 8.058397 0
0.824839 13.730343 0
1.507278 5.027866 1
0.099671 6.835839 1
-0.344008 10.717485 0
1.785928 7.718645 1
-0.918801 11.560217 0
-0.364009 4.747300 1
-0.841722 4.119083 1
0.490426 1.960539 1
-0.007194 9.075792 0
0.356107 12.447863 0
0.342578 12.281162 0
-0.810823 -1.466018 1
2.530777 6.476801 1
1.296683 11.607559 0
0.475487 12.040035 0
-0.783277 11.009725 0
0.074798 11.023650 0
-1.337472 0.468339 1
-0.102781 13.763651 0
-0.147324 2.874846 1
0.518389 9.887035 0
1.015399 7.571882 0
-1.658086 -0.027255 1
1.319944 2.171228 1
2.056216 5.019981 1
-0.851633 4.375691 1
-1.510047 6.061992 0
-1.076637 -3.181888 1
1.821096 10.283990 0
3.010150 8.401766 1
-1.099458 1.688274 1
-0.834872 -1.733869 1
-0.846637 3.849075 1
1.400102 12.628781 0
1.752842 5.468166 1
0.078557 0.059736 1
0.089392 -0.715300 1
1.825662 12.693808 0
0.197445 9.744638 0
0.126117 0.922311 1
-0.679797 1.220530 1
0.677983 2.556666 1
0.761349 10.693862 0
-2.168791 0.143632 1
1.388610 9.341997 0
0.317029 14.739025 0