1. 程式人生 > >Rosenblatt感知機-神經網路與機器學習筆記1

Rosenblatt感知機-神經網路與機器學習筆記1

一、Rosenblatt感知機小結

感知機模型
這裡寫圖片描述
輸入向量:x(n)=[+1,x1(n),x2(n),...,xm(n)]T
權重向量:w(n)=[b,w1(n),w2(n),...wm(n)]T
當輸入向量線性可分時,我們需要尋找到一個合適的w,正確地分開兩類資料
wTx>0時,x屬於分類1
wTx0時,x屬於分類-1
這裡寫圖片描述
選取符合函式sign(x)為硬限幅函式,我們的模型可以描述為

y=sign(wTx)
所以訓練過程可描述如下:
x(n):輸入向量
w(n): 權重向量
y(n):模型產生的響應
d(n):期望的響應
eta:學習速率
初始化w=[0
; 0; ...0]; for i = 1 :N//批次迴圈N次 for j = 1 : n//遍歷一次樣本,n為樣本個數 y(j) = sign(w * x(j)); w = w + eta * (d(j) - y(j)) * x(j)//調整權重 end 計算均方差; end

二、程式碼實現

1.產生隨機數

產生兩類半圓形資料,引數r,dr,d如下圖所示

這裡寫圖片描述

matlab程式碼如下

function [x1,y1,x2,y2] = GenRandomData(r, d,dr, n1, n2)
if nargin<4 r=1; d=1; n1=100;n2=100; end dr1=rand([n1,1])*dr;theta1=rand([n1,1])*pi; dr2=rand([n2,1])*dr;theta2=-rand([n2,1])*pi; x1=(r-dr/2+dr1).*cos(theta1); y1=(r-dr/2+dr1).*sin(theta1); x2=r+(r-dr/2+dr2).*cos(theta2); y2=-d+(r-dr/2+dr2).*sin(theta2); end

2.Rosenblatt感知機

function [w, err]
= Rosenblatt(x,d,w0,etaLim,epochN)
%x: m x n, m: feature numbers, n: sample numbers isPlotLine = 1;%是否畫出擬閤中曲線的變化 if isPlotLine figure; plot(x(2, d == 1),x(3, d == 1),'LineStyle','none','Marker','x','Color','r'); hold on; plot(x(2, d == -1),x(3, d == -1),'LineStyle','none','Marker','o','Color','b'); end [m,n] = size(x); MSEarr = [];w = w0;epochi = 0; %訓練速率eta採用線性退火 eta=linspace(etaLim(1),etaLim(2),epochN); while epochi < epochN epochi=epochi + 1; for i=1:n y= sign(w'*x(:,i)); w=w + eta(epochi)*(d(i)-y)*x(:,i); end yarr = sign(x'*w); MSE = sum((d-yarr).^2); MSE = sqrt(MSE/n); MSEarr = [MSEarr MSE]; if isPlotLine xx = -10:0.1:10; yy=-w(2,1) * xx/w(3,1)-w(1,1)/w(3,1); plot(xx,yy); end end yarr = sign(x'*w); errN=sum(yarr~=d); err=errN/n; figure;plot(MSEarr);title('MSE'); end

3.程式碼測試

close all;
%產生隨機數
n1=1000;n2=1000;n=n1+n2;
[x1,y1,x2,y2]=GenRandomData(8,1,3,n1,n2);
samps=[ones(n1,1),x1, y1;ones(n2,1),x2,y2];
d=[ones(n1,1);-ones(n2,1)];
%打亂順序
randI=randperm(n);
samps(randI,:)=samps(1:n,:);
d(randI)=d(1:n);
%利用感知機進行訓練
w0 = [0;0;0];etaLim = [1e-3 1e-5];x = samps';epochN = 50;
[w, err] = Rosenblatt(x,d,w0,etaLim,epochN)
%畫圖
figure;
plot(x1,y1,'Marker','x','Color','r','LineStyle','none');
hold on;
plot(x2,y2,'Marker','o','Color','b','LineStyle','none');
xx=[-10:0.1:10];
yy=-w(2,1)*xx/w(3,1)-w(1,1)/w(3,1);
plot(xx,yy,'k');

當兩個半圓距離為4時,均方差開始就很小。
這裡寫圖片描述
這裡寫圖片描述

當兩個半圓距離為0時,均方差震盪了一段時間開始收斂,最後結果有較小的錯誤率err =5.000000000000000e-04
最後分類結果
迭代過程
誤差

當兩個半圓距離為-4時,均方差一直震盪,最後錯誤率err=0.195500000000000
這裡寫圖片描述
這裡寫圖片描述
這裡寫圖片描述