自己寫簡單的感知機演算法
阿新 • • 發佈:2019-02-20
自己動手寫感知機
1 什麼是感知機?
感知機(perceptron)是一種二分類的線性分類模型,可以將所有輸入的例項劃分為True或是False兩類。感知機模型的輸入向量是例項的特徵向量,其任務就是在N維空間中尋找一個平面,讓這個平面可以正好將例項劃分為正負兩類,這個平面我們稱其為超平面。感知機是神經網路與支援向量機的基礎。
2 感知機模型
輸入空間:
輸出空間:
輸入空間到輸出的函式:
其中:
幾何解釋:
函式中的可以理解為超平面的法向量,為超平面的截距,超平面將整個空間分為兩部分,就是分類出來的正、負兩類。
3 感知機學習演算法
感知機學習演算法的具體策略就是使用誤分類點到超平面的距離和為損失函式,並使用梯度下降法最小化損失函式,即可求解出超平面。
原始形式
首先初始化所有引數為0,然後根據梯度下降法,用每個誤分類點來更新引數,封裝成類:
class PreceptronClassifier: def __init__(self,learning_rate)
對偶形式
前面提到過,每次更新引數,使用 這個式子是模型對求的偏導,表示學習率。
所以相對於每一個例項,他們的是不會變的,所以最終的的模型引數可以表示為:
表示第個點被誤分類用來更新模型的次數,這就是對偶形式,下面放程式碼:class PreceptronClassifier: def __init__(self,learning_rate): self.b = 0 self.w = [] self.rate = learning_rate def func(self, x): # 定義模型 res = 0.0 for i in range(len(x)): res += self.w[i]*x[i] res += self.b return res def update(self,error_index, trainY): # 更新引數 self.a[error_index] += 1 for i in range(len(self.w)): self.w[i] += self.rate*self.matrix[error_index][i] self.b += self.rate*trainY[error_index] def error_label(self,temp_res:list,trainY): # 標記誤分類點 for i in range(len(self.w)): if temp_res[i] == trainY[i]: temp_res[i] = 0 else: temp_res[i] = 1 return temp_res def fit(self, trainX,trainY): # 初始化引數 self.w = [0]*len(trainX[0]) self.a = [0]*len(trainX) self.b = 0 self.matrix = [] for i in range(len(trainX)): self.matrix.append([xi*trainY[i] for xi in trainX[i]]) temp_res = self.prediction(trainX) temp_res = self.error_label(temp_res, trainY) train_iter = 0 while(1 in temp_res): train_iter += 1 error_index = temp_res.index(1) self.update(error_index,trainY) temp_res = self.prediction(trainX) temp_res = self.error_label(temp_res, trainY) print('第'+str(train_iter)+'次迭代','w:',self.w,'b:',self.b) def prediction(self, testX): # 進行預測 res = [] for x in testX: res.append(1 if self.func(x) > 0 else -1 ) return res if __name__ == '__main__': # 準備訓練資料和測試資料 trainX = [[3,3],[4,3],[1,2]] trainY = [1,1,-1] testX = [[1,2],[3,4]] # 建立物件,指定學習率 pc = PreceptronClassifier(0.5) pc.fit(trainX,trainY) # 訓練 print(pc.prediction(testX)) # 測試
4 總結
感知機模型簡單且易於實現,是入門級演算法,同時又是神經網路和支援向量機的基礎。比如說神經網路的全連線層就和感知機很相似,所以這個演算法還是值得理解一下的。
初學乍練,請多多指正!