BP神經網路的python實現
引用:
啟用函式:
變數定義:
正向計算:
反向調節權值,閾值:
預測樣本:
函式呼叫:
訓練用例以及測試用例:
訓練結果:
原始碼如下:
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 16 23:07:14 2022
@author: 12234
"""
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
def sigmoid(x):
"""
隱含層和輸出層對應的函式法則
"""
return 1/(1+np.exp(-x))
def BP(DateTry, DateTest, maxiter=11):
# --pandas是基於numpy設計的,效率略低
# 為提高處理效率,轉換為陣列
DateTry, DateTest = np.array(DateTry), np.array(DateTest)
# --隱層輸入
# -1: 代表的是隱層的閾值
hidden_in = np.array([0.0, 0, -1])
w_hidden = np.random.rand(3, 4) # 隱層權值閾值(-1x其中一個值:閾值)
# 輸出層輸入
# -1:代表輸出層閾值
out_in = np.array([0.0, 0, 0, 0, -1])
w_out = np.random.rand(5) # 輸出層權值閾值(-1x其中一個值:閾值)
delta_w_out = np.zeros([5]) # 存放輸出層權值閾值的逆向計算誤差
delta_w_hidden = np.zeros([3, 4]) # 存放因此能權值閾值的逆向計算誤差
yita = 1.75 # η: 學習速率
Err = np.zeros([maxiter]) # 記錄總體樣本每迭代一次的錯誤率
# 1.樣本總體訓練的次數
for it in range(maxiter):
# 衡量每一個樣本的誤差
err = np.zeros([len(DateTry)])
# 2.訓練集訓練一遍
for j in range(len(DateTry)):
hidden_in[:2] = DateTry[j, :2] # 儲存當前物件前兩個屬性值
real = DateTry[j, 2]
# 3.當前物件進行訓練
for i in range(4):
out_in[i] = sigmoid(sum(hidden_in*w_hidden[:, i])) # 計算輸出層輸入
res = sigmoid(sum(out_in * w_out)) # 獲得訓練結果
err[j] = abs(real - res)
# --先調節輸出層的權值與閾值
delta_w_out = yita*res*(1-res)*(real-res)*out_in # 權值調整
delta_w_out[4] = -yita*res*(1-res)*(real-res) # 閾值調整
w_out = w_out + delta_w_out
# --隱層權值和閾值的調節
for i in range(4):
# 權值調整
delta_w_hidden[:, i] = yita * out_in[i] * (1 - out_in[i]) * w_out[i] * res * (1 - res) * (real - res) * hidden_in
# 閾值調整
delta_w_hidden[2, i] = -yita * out_in[i] * (1 - out_in[i]) * w_out[i] * res * (1 - res) * (real - res)
w_hidden = w_hidden + delta_w_hidden
Err[it] = err.mean()
plt.plot(Err)
plt.show()
# 儲存預測誤差
err_te = np.zeros([11])
# 預測樣本11個
for j in range(11):
hidden_in[:2] = DateTest[j, :2] # 儲存資料
real = DateTest[j, 2] # 真實結果
# net_in和w_mid的相乘過程
for i in range(4):
# 輸入層到隱層的傳輸過程
out_in[i] = sigmoid(sum(hidden_in*w_hidden[:, i]))
res = sigmoid(sum(out_in*w_out)) # 網路預測結果輸出
err_te[j] = abs(real-res) # 預測誤差
print('res:', res, ' real:', real)
plt.plot(err_te)
plt.show()
if "__main__" == __name__:
# 1.讀取樣本
DateTry = pd.read_csv("NetTry.txt")
DateTest = pd.read_csv("NetTest.txt")
BP(DateTry, DateTest, maxiter=11)