Logistic Regression (Logit Regression) - Machine Learning
阿新 • Published: 2019-01-04
The dataset (shared via Baidu Netdisk) is dataset 3.0a from the "Watermelon Book", Zhou Zhihua's Machine Learning.
First, load the data with the load_data(file) function.
```python
def load_data(file):
    s = []
    with open(file) as f:
        for line in f.readlines():
            line = line.replace('\n', '')  # read line by line; each line is a str
            s.append(line.split(' '))      # split each record on spaces
    return s
```

Visualizing this data, it looks roughly like this.
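Before moving on to the plot, here is a quick check of what load_data returns, run on a hypothetical two-row file in the assumed column layout (id, label, density, sugar content). The sample values and header are made up for illustration; the function is repeated so the snippet runs on its own.

```python
import os
import tempfile

def load_data(file):
    # same parsing as above: one whitespace-separated record per line
    s = []
    with open(file, encoding='utf-8') as f:
        for line in f.readlines():
            line = line.replace('\n', '')
            s.append(line.split(' '))
    return s

# hypothetical file contents in the assumed column order
sample = "id label density sugar\n1 是 0.697 0.460\n9 否 0.666 0.091\n"
fd, path = tempfile.mkstemp()
os.close(fd)
with open(path, 'w', encoding='utf-8') as f:
    f.write(sample)
rows = load_data(path)
os.remove(path)
print(rows[1])  # ['1', '是', '0.697', '0.460']
```

Every field comes back as a string, which is why the plotting code below wraps each value in float().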
Blue and red mark the good and bad melons respectively. Clearly, a single straight line does not separate them well.
The code for this part is as follows:

```python
file = '../data/3_0a.txt'  # path to the data file
s = load_data(file)
print(type(s))

x = []   # sugar content of good melons
y = []   # density of good melons
x1 = []  # sugar content of bad melons
y1 = []  # density of bad melons
for i in range(1, 8):          # read the good melons
    for j in range(len(s[i])):
        if j == 2:
            x.append(float(s[i][j]))
        if j == 3:
            y.append(float(s[i][j]))
for i in range(8, len(s)):     # the bad melons
    for j in range(len(s[i])):
        if j == 2:
            x1.append(float(s[i][j]))
        if j == 3:
            y1.append(float(s[i][j]))

import pylab as pl
pl.plot(x, y, 'o')     # good melons (blue)
pl.plot(x1, y1, 'ro')  # bad melons (red)
pl.show()
```

Next we use the logistic regression model. The exact formula is Eq. (3.27) in Chapter 3 of Zhou Zhihua's Machine Learning; it can be found elsewhere too. This is an unconstrained optimization problem, so gradient descent applies directly; if the derivative gives you trouble, look up a machine-learning differentiation reference.
In Eq. (3.27), y_i is the label of sample i (1 for a good melon, 0 for a bad one) and x_i holds the sample's attributes; we have two attributes here. Writing β = (w; b) and x̂_i = (x_i; 1), the quantity to minimize is

ℓ(β) = Σ_{i=1..m} ( −y_i β^T x̂_i + ln(1 + e^{β^T x̂_i}) )

Below we read x_i and y_i out of the data loaded earlier and plug them into the derivative term used by gradient descent. Any initial values for w and b will do; then just iterate. The code is as follows:
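As a sanity check on the derivative used below, here is a small vectorized sketch of the gradient of Eq. (3.27); the function name and the toy data are illustrative, not from the original post.

```python
import numpy as np

def grad_327(beta, X, y):
    """Gradient of Eq. (3.27): sum_i x_i * (sigmoid(beta^T x_i) - y_i).

    X has one row [attribute1, attribute2, 1] per sample,
    y holds the 0/1 labels, beta is [w1, w2, b].
    """
    z = X @ beta                        # beta^T x_i for every sample at once
    p1 = np.exp(z) / (1.0 + np.exp(z))  # predicted probability of class 1
    return X.T @ (p1 - y)

# toy data: with beta = 0 every p1 is 0.5, so the gradient is X^T (0.5 - y)
X = np.array([[0.1, 0.2, 1.0],
              [0.8, 0.9, 1.0]])
y = np.array([0.0, 1.0])
g = grad_327(np.zeros(3), X, y)
print(g)  # approximately [-0.35, -0.35, 0.]
```

This is the same per-sample expression, -y_j * x_j + sigmoid(β^T x_j) * x_j, that the training loop below accumulates one sample at a time.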
```python
import numpy as np
import my_load_data as mld  # just the load_data function from above; you can paste it in here instead

file = '../data/3_0a.txt'
s = mld.load_data(file)

# np.zeros takes a shape tuple, hence the double parentheses.
# np.mat is used because * on plain ndarrays is not matrix multiplication.
x = np.mat(np.zeros((17, 3)))
y = np.mat(np.zeros((17, 1)))
for i in range(1, 18):  # read x_i and y_i (row 0 is the header)
    x[i - 1] = np.mat([float(s[i][2]), float(s[i][3]), 1])
    if s[i][1] == '是':  # '是' marks a good melon
        y[i - 1] = np.mat([1])
    else:
        y[i - 1] = np.mat([0])

start = np.mat([[0.1], [10], [8]])  # initial parameters [w1; w2; b]
i = 0
xishu = 0.01  # learning rate
while i < 2 * 10**5:  # about ten thousand iterations is already enough
    grad = 0
    for j in range(17):  # the sum over i = 1..m in front of Eq. (3.27)
        startT = np.transpose(start)
        xT = np.transpose(x[j])
        bx = startT * xT           # beta^T x_j
        bx_1 = np.array(bx)[0][0]
        # per-sample derivative: -y_j * x_j + sigmoid(beta^T x_j) * x_j
        c = -y[j] * x[j] + (np.exp(bx_1) / (1 + np.exp(bx_1))) * x[j]
        grad = grad + c
    s_1 = np.transpose(grad)     # the full gradient
    start = start - xishu * s_1  # the familiar gradient-descent update
    i = i + 1
    if i % 10000 == 0:
        print('no%s' % i, 'start is %s' % start)
print(start)
```
Iteration output:

```
no10000 start is [[ 2.98758124]
 [ 11.91671654]
 [ -4.21286642]]
no20000 start is [[ 3.13439493]
 [ 12.43023225]
 [ -4.39732669]]
no30000 start is [[ 3.15464273]
 [ 12.50721714]
 [ -4.42401375]]
no40000 start is [[ 3.15776018]
 [ 12.5190382 ]
 [ -4.42811559]]
no50000 start is [[ 3.15824169]
 [ 12.52086253]
 [ -4.42874883]]
no60000 start is [[ 3.15831607]
 [ 12.52114431]
 [ -4.42884664]]
no70000 start is [[ 3.15832756]
 [ 12.52118784]
 [ -4.42886175]]
no80000 start is [[ 3.15832934]
 [ 12.52119456]
 [ -4.42886408]]
no90000 start is [[ 3.15832961]
 [ 12.5211956 ]
 [ -4.42886444]]
no100000 start is [[ 3.15832965]
 [ 12.52119576]
 [ -4.4288645 ]]
no110000 start is [[ 3.15832966]
 [ 12.52119579]
 [ -4.42886451]]
```

From iteration 110000 to 200000 the printed values no longer change, and the final result is

```
[[ 3.15832966]
 [ 12.52119579]
 [ -4.42886451]]
```
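With the converged parameters above, classifying a new melon is a single sigmoid evaluation. This is a sketch: the attribute values below are made up for illustration, and which attribute maps to which weight follows the loading code above.

```python
import numpy as np

# converged [w1; w2; b] taken from the run above
beta = np.array([3.15832966, 12.52119579, -4.42886451])

def predict(a1, a2):
    """Probability that a melon with attributes (a1, a2) is good."""
    z = beta @ np.array([a1, a2, 1.0])  # beta^T x_hat
    return 1.0 / (1.0 + np.exp(-z))    # sigmoid

p = predict(0.697, 0.460)  # hypothetical attribute values
print(round(p, 3), p > 0.5)
```

A probability above 0.5 corresponds to predicting "good melon"; the decision boundary is the line w1*a1 + w2*a2 + b = 0.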