三層神經網路實現分類器
阿新 • • 發佈:2019-02-15
一、簡介
神經網路模型是一種模仿生物大腦的神經元連線,提出的一種計算模型。目前已在人工智慧領域取得了廣泛的應用。
下圖為一個神經元的抽象模型,一個神經元接收來自其他神經元的訊號,對訊號進行加和後作一個激勵,然後輸出。
二、實現
原始資料共有30個,分為3類,每類10個:samples1=[ 1.58 2.32 -5.8; 0.67 1.58 -4.78; 1.04 1.01 -3.63; -1.49 2.18 -3.39; -0.41 1.21 -4.73; 1.39 3.16 2.87; 1.20 1.40 -1.89; -0.92 1.44 -3.22; 0.45 1.33 -4.38; -0.76 0.84 -1.96; ]; samples2=[ 0.21 0.03 -2.21; 0.37 0.28 -1.8; 0.18 1.22 0.16; -0.24 0.93 -1.01; -1.18 0.39 -0.39; 0.74 0.96 -1.16; -0.38 1.94 -0.48; 0.02 0.72 -0.17; 0.44 1.31 -0.14; 0.46 1.49 0.68; ]; samples3=[ -1.54 1.17 0.64; 5.41 3.45 -1.33; 1.55 0.99 2.69; 1.86 3.19 1.51; 1.68 1.79 -0.87; 3.51 -0.22 -1.39; 1.40 -0.44 -0.92; 0.44 0.83 1.97; 0.25 0.68 -0.99; 0.66 -0.45 0.08; ]; samples=[samples1;samples2;samples3]; labels=zeros(size(samples)); labels(1:10,1)=1; labels(11:20,2)=1; labels(21:30,3)=1; save -v7 'data.mat' samples labels;
除了原始資料,還要構造one-hot vector指示向量。因為一共有3類,所以網路的輸出層的節點數為3,對於每個資料,用一個三維向量來指示資料具體屬於哪一類,one-hot vector為對應位上為1,其餘位為0,這就是網路訓練的目標值。這樣通過對輸入資料的訓練,我們的網路能對將輸入資料輸出為一個one-hot vector即完成了對資料的分類。
為了實現對上述資料的分類,採用包含一個隱含層的三層前饋神經網路,結構如下圖。並且採用梯度下降演算法對網路引數進行優化。
matlab/octave 程式碼:
clear; load data.mat samples=[samples ones(size(samples,1),1)]; nn=size(samples,1); %number of samples ni=size(samples,2); %number of input layer nodes nh=8; %number of hidden layer nodes nj=3; %number of output layer nodes eta=0.5; %learning rate theta=0.0001; %criterion threshold Jw=[0]; %errors of traning wih=rand(ni,nh)-0.5; whj=rand(nh,nj)-0.5; for r=0:1000 for k=1:nn rk=randi(nn); xi=samples(rk,:); %kth sample tj=labels(rk,:); %expected output neth=xi*wih; %hidden layer net yh=tanh(neth); %hidden layer out and activation function is tanh netj=yh*whj; %output lay net zj=sigmoid(netj); %output layer out and activation function is sigmoid delta_j=(sigmoid(netj).*(1-sigmoid(netj))).*(tj-zj); delta_h=(1-tanh(neth).^2).*(delta_j*whj'); delta_whj=eta*yh'*delta_j; delta_wih=eta*xi'*delta_h; whj=whj+delta_whj; wih=wih+delta_wih; end result=sigmoid(tanh(samples*wih)*whj); Jw=[Jw sum(sum((labels-result).^2))]; delta_Jw=abs(Jw(end)-Jw(end-1)); if delta_Jw