Rosenblatt感知機-神經網路與機器學習筆記1
阿新 • • 發佈:2019-02-02
一、Rosenblatt感知機小結
感知機模型
輸入向量:
權重向量:
當輸入向量線性可分時,我們需要尋找到一個合適的w,正確地分開兩類資料
當
當
選取符合函式sign(x)為硬限幅函式,我們的模型可以描述為
所以訓練過程可描述如下:
x(n):輸入向量
w(n): 權重向量
y(n):模型產生的響應
d(n):期望的響應
eta:學習速率
初始化w=[0 ; 0; ...0];
for i = 1 :N//批次迴圈N次
for j = 1 : n//遍歷一次樣本,n為樣本個數
y(j) = sign(w * x(j));
w = w + eta * (d(j) - y(j)) * x(j)//調整權重
end
計算均方差;
end
二、程式碼實現
1.產生隨機數
產生兩類半圓形資料,引數r,dr,d如下圖所示
matlab程式碼如下
function [x1,y1,x2,y2] = GenRandomData(r, d,dr, n1, n2)
if nargin<4
r=1;
d=1;
n1=100;n2=100;
end
dr1=rand([n1,1])*dr;theta1=rand([n1,1])*pi;
dr2=rand([n2,1])*dr;theta2=-rand([n2,1])*pi;
x1=(r-dr/2+dr1).*cos(theta1);
y1=(r-dr/2+dr1).*sin(theta1);
x2=r+(r-dr/2+dr2).*cos(theta2);
y2=-d+(r-dr/2+dr2).*sin(theta2);
end
2.Rosenblatt感知機
function [w, err] = Rosenblatt(x,d,w0,etaLim,epochN)
%x: m x n, m: feature numbers, n: sample numbers
isPlotLine = 1;%是否畫出擬閤中曲線的變化
if isPlotLine
figure;
plot(x(2, d == 1),x(3, d == 1),'LineStyle','none','Marker','x','Color','r');
hold on;
plot(x(2, d == -1),x(3, d == -1),'LineStyle','none','Marker','o','Color','b');
end
[m,n] = size(x);
MSEarr = [];w = w0;epochi = 0;
%訓練速率eta採用線性退火
eta=linspace(etaLim(1),etaLim(2),epochN);
while epochi < epochN
epochi=epochi + 1;
for i=1:n
y= sign(w'*x(:,i));
w=w + eta(epochi)*(d(i)-y)*x(:,i);
end
yarr = sign(x'*w);
MSE = sum((d-yarr).^2);
MSE = sqrt(MSE/n);
MSEarr = [MSEarr MSE];
if isPlotLine
xx = -10:0.1:10;
yy=-w(2,1) * xx/w(3,1)-w(1,1)/w(3,1);
plot(xx,yy);
end
end
yarr = sign(x'*w);
errN=sum(yarr~=d);
err=errN/n;
figure;plot(MSEarr);title('MSE');
end
3.程式碼測試
close all;
%產生隨機數
n1=1000;n2=1000;n=n1+n2;
[x1,y1,x2,y2]=GenRandomData(8,1,3,n1,n2);
samps=[ones(n1,1),x1, y1;ones(n2,1),x2,y2];
d=[ones(n1,1);-ones(n2,1)];
%打亂順序
randI=randperm(n);
samps(randI,:)=samps(1:n,:);
d(randI)=d(1:n);
%利用感知機進行訓練
w0 = [0;0;0];etaLim = [1e-3 1e-5];x = samps';epochN = 50;
[w, err] = Rosenblatt(x,d,w0,etaLim,epochN)
%畫圖
figure;
plot(x1,y1,'Marker','x','Color','r','LineStyle','none');
hold on;
plot(x2,y2,'Marker','o','Color','b','LineStyle','none');
xx=[-10:0.1:10];
yy=-w(2,1)*xx/w(3,1)-w(1,1)/w(3,1);
plot(xx,yy,'k');
當兩個半圓距離為4時,均方差開始就很小。
當兩個半圓距離為0時,均方差震盪了一段時間開始收斂,最後結果有較小的錯誤率err =5.000000000000000e-04
當兩個半圓距離為-4時,均方差一直震盪,最後錯誤率err=0.195500000000000