1. 程式人生 > >BP神經網路,BP推導過程,反向傳播演算法,誤差反向傳播,梯度下降,權值閾值更新推導,隱含層權重更新公式

BP神經網路,BP推導過程,反向傳播演算法,誤差反向傳播,梯度下降,權值閾值更新推導,隱含層權重更新公式

  1. %% BP的主函式  
  2. % 清空  
  3. clear all;  
  4. clc;  
  5. % 匯入資料  
  6. load data;  
  7. %從1到2000間隨機排序  
  8. k=rand(1,2000);  
  9. [m,n]=sort(k);  
  10. %輸入輸出資料  
  11. input=data(:,2:25);  
  12. output1 =data(:,1);  
  13. %把輸出從1維變成4維  
  14. for i=1:2000  
  15.     switch output1(i)  
  16.         case 1  
  17.             output(i,:)=[1 0 0 0];  
  18.         case 2  
  19.             output(i,:)=[0 1 0 0];  
  20.         case 3  
  21.             output(i,:)=[0 0 1 0];  
  22.         case 4  
  23.             output(i,:)=[0 0 0 1];  
  24.     end  
  25. end  
  26. %隨機提取1500個樣本為訓練樣本,500個樣本為預測樣本  
  27. trainCharacter=input(n(1:1600),:);  
  28. trainOutput=output(n(1:1600),:);  
  29. testCharacter=input(n(1601:2000),:);  
  30. testOutput=output(n(1601:2000),:);  
  31. % 對訓練的特徵進行歸一化  
  32. [trainInput,inputps]=mapminmax(trainCharacter‘);  
  33. %% 引數的初始化  
  34. % 引數的初始化  
  35. inputNum = 24;%輸入層的節點數  
  36. hiddenNum = 50;%隱含層的節點數  
  37. outputNum = 4;%輸出層的節點數  
  38. % 權重和偏置的初始化  
  39. w1 = rands(inputNum,hiddenNum);  
  40. b1 = rands(hiddenNum,1);  
  41. w2 = rands(hiddenNum,outputNum);  
  42. b2 = rands(outputNum,1);  
  43. % 學習率  
  44. yita = 0.1;  
  45. %% 網路的訓練  
  46. for r = 1:30  
  47.     E(r) = 0;% 統計誤差  
  48.     for m = 1:1600  
  49.         % 資訊的正向流動  
  50.         x = trainInput(:,m);  
  51.         % 隱含層的輸出  
  52.         for j = 1:hiddenNum  
  53.             hidden(j,:) = w1(:,j)‘*x+b1(j,:);  
  54.             hiddenOutput(j,:) = g(hidden(j,:));  
  55.         end  
  56.         % 輸出層的輸出  
  57.         outputOutput = w2‘*hiddenOutput+b2;  
  58.         % 計算誤差  
  59.         e = trainOutput(m,:)‘-outputOutput;  
  60.         E(r) = E(r) + sum(abs(e));  
  61.         % 修改權重和偏置  
  62.         % 隱含層到輸出層的權重和偏置調整  
  63.         dw2 = hiddenOutput*e‘;  
  64.         db2 = e;  
  65.         % 輸入層到隱含層的權重和偏置調整  
  66.         for j = 1:hiddenNum  
  67.             partOne(j) = hiddenOutput(j)*(1-hiddenOutput(j));  
  68.             partTwo(j) = w2(j,:)*e;  
  69.         end  
  70.         for i = 1:inputNum  
  71.             for j = 1:hiddenNum  
  72.                 dw1(i,j) = partOne(j)*x(i,:)*partTwo(j);  
  73.                 db1(j,:) = partOne(j)*partTwo(j);  
  74.             end  
  75.         end  
  76.         w1 = w1 + yita*dw1;  
  77.         w2 = w2 + yita*dw2;  
  78.         b1 = b1 + yita*db1;  
  79.         b2 = b2 + yita*db2;    
  80.     end  
  81. end  
  82. %% 語音特徵訊號分類  
  83. testInput=mapminmax(‘apply‘,testCharacter‘,inputps);  
  84. for m = 1:400  
  85.     for j = 1:hiddenNum  
  86.         hiddenTest(j,:) = w1(:,j)‘*testInput(:,m)+b1(j,:);  
  87.         hiddenTestOutput(j,:) = g(hiddenTest(j,:));  
  88.     end  
  89.     outputOfTest(:,m) = w2‘*hiddenTestOutput+b2;  
  90. end  
  91. %% 結果分析  
  92. %根據網路輸出找出資料屬於哪類  
  93. for m=1:400  
  94.     output_fore(m)=find(outputOfTest(:,m)==max(outputOfTest(:,m)));  
  95. end  
  96. %BP網路預測誤差  
  97. error=output_fore-output1(n(1601:2000))‘;  
  98. k=zeros(1,4);    
  99. %找出判斷錯誤的分類屬於哪一類  
  100. for i=1:400  
  101.     if error(i)~=0  
  102.         [b,c]=max(testOutput(i,:));  
  103.         switch c  
  104.             case 1   
  105.                 k(1)=k(1)+1;  
  106.             case 2   
  107.                 k(2)=k(2)+1;  
  108.             case 3   
  109.                 k(3)=k(3)+1;  
  110.             case 4   
  111.                 k(4)=k(4)+1;  
  112.         end  
  113.     end  
  114. end  
  115. %找出每類的個體和  
  116. kk=zeros(1,4);  
  117. for i=1:400  
  118.     [b,c]=max(testOutput(i,:));  
  119.     switch c  
  120.         case 1  
  121.             kk(1)=kk(1)+1;  
  122.         case 2  
  123.             kk(2)=kk(2)+1;  
  124.         case 3  
  125.             kk(3)=kk(3)+1;  
  126.         case 4  
  127.             kk(4)=kk(4)+1;  
  128.     end  
  129. end  
  130. %正確率  
  131. rightridio=(kk-k)./kk