1. 程式人生 > 其它 >[自己看的]pytorch實現簡單的徑向基網路

[自己看的]pytorch實現簡單的徑向基網路

網路只有三個節點,如下圖

 

 

用來擬合一個函式,如下圖

 

 

資料如下

x為[[1],[2],[3],[4],[5],[6],[7],[8],[9],[10],[11],[12],[13],[14],[15]]

y為[[1.103638323514327], [3.0], [1.103638323514327], [0.054946916666202536], [0.0003702294122600387], [3.3760552415777734e-07], [4.166383159489206e-11], [6.958568490730709e-16], [1.5728656990090393e-21], [4.811432671645914e-28], [1.9919031598742206e-35], [1.1160227928062509e-43], [8.462310265380405e-53], [8.683920934944901e-63], [1.2060180647230065e-73]]

這些資料用這個程式碼生成:

import math
a = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
b = []
def yunsuan(a,b):
    for i in range(len(a)):
        b.append(math.exp(-(a[i]-2)**2)*3)
    return b
print(yunsuan(a,b))

 

搭建及訓練網路的程式碼參考網址:https://www.pythonheidong.com/blog/article/608644/9530a7202e0ab840d104/

如下:

#1,匯入包
import torch, random
import torch.nn as nn import torch.optim as optim #2,設定隨機數種子 torch.manual_seed(42) #3,搭建神經網路類 class RBFN(nn.Module): """ 以高斯核作為徑向基函式 """ def __init__(self, centers, n_out=1): """ :param centers: shape=[center_num,data_dim] :param n_out: """ super(RBFN, self).
__init__() self.n_out = n_out #這個是輸出維度?? self.num_centers = 1#centers.size(0) # 隱層節點的個數 self.dim_centure = 1#centers.size(1) # 這個centers是隱藏層?,num是節點的話,dim是該節點處的維度嗎? self.centers = nn.Parameter(centers) #nn.parameter的作用是設定網路引數??,centers變成了模型的一部分 self.beta = torch.ones(1, self.num_centers) #生成一行有八個值為10的張量 #self.beta = nn.Parameter(self.beta) # 對線性層的輸入節點數目進行了修改 self.linear = nn.Linear(self.num_centers, self.n_out, bias=True) #不懂這裡為什麼輸入不是num self.initialize_weights()# 建立物件時自動執行,初始化權重 def kernel_fun(self, batches): n_input = batches.size(0) # number of inputs,0代表行數 A = self.centers.view(self.num_centers, -1).repeat(n_input, 1, 1) #view成num_centers行,然後再重複n_input次 B = batches.view(n_input, -1).unsqueeze(1).repeat(1, self.num_centers, 1) #A是中心,B是資料 #把batches展開成n_input行, #unsqueeze新增一個維度 C = torch.exp(-self.beta.mul((A - B).pow(2).sum(1, keepdim=False))) #pow是2次方, #sum(2, keepdim=False),False不儲存原來的維度,2是按照第二維度進行求和後用false去掉多餘維度 return C def forward(self, batches): radial_val = self.kernel_fun(batches) class_score = self.linear((radial_val)) #linear這一行,是隻定義輸入?? return class_score def initialize_weights(self, ): """ 網路權重初始化 :return: """ for m in self.modules(): #modules負責返回所有的modules if isinstance(m, nn.Conv2d): #isinstance是一個內建函式,用於判斷一個物件是否是一個已知的型別 m.weight.data.normal_(0, 0.02) m.bias.data.zero_() elif isinstance(m, nn.ConvTranspose2d): #轉置卷積 m.weight.data.normal_(0, 0.02) m.bias.data.zero_() elif isinstance(m, nn.Linear): #線性 m.weight.data.normal_(0, 0.02) m.bias.data.zero_() def print_network(self): num_params = 0 for param in self.parameters(): num_params += param.numel() #numel統計模型引數量,tensor裡的每個元素算一個 print(self) print('Total number of parameters: %d' % num_params) if __name__ =="__main__": #4,資料 data1 = torch.tensor([0.25], dtype=torch.float32) data = torch.tensor([[1],[2],[3],[4],[5], [6],[7],[8],[9],[10],[11],[12],[13],[14],[15]], dtype=torch.float32) label = torch.tensor([[1.103638323514327], [3.0], [1.103638323514327], [0.054946916666202536], [0.0003702294122600387], [3.3760552415777734e-07], [4.166383159489206e-11], [6.958568490730709e-16], [1.5728656990090393e-21], [4.811432671645914e-28], [1.9919031598742206e-35], [1.1160227928062509e-43], [8.462310265380405e-53], [8.683920934944901e-63], [1.2060180647230065e-73]], dtype=torch.float32) #5,設定網路 centers = data1 #data1就是中心引數 rbf = RBFN(centers,1) #設定中心引數 params = rbf.parameters() loss_fn = torch.nn.MSELoss() optimizer = torch.optim.SGD(params,lr=0.1,momentum=0.9) #6,訓練網路 for i in range(10000): optimizer.zero_grad() #1,梯度歸零 y = rbf.forward(data) #2,前向傳播 loss = loss_fn(y,label) #3,求損失函式 loss.backward() #4,反向傳播 optimizer.step() #5,對引數進行更新 print(i,"\t",loss.data) # 7,載入使用 y = rbf.forward(data) print("預測值") print(y.data) print("真實值") print(label.data) print("網路的引數") print(rbf.state_dict())

最後計算出的網路的引數如下:

網路的引數
OrderedDict([('centers', tensor([2.])), ('linear.weight', tensor([[3.]])), ('linear.bias', tensor([-1.9980e-08]))])

比較符合原函式。

 

注意:這個程式碼不是通用的,用它去擬合其他函式就會有些問題,比如網路的輸出全為Nan,或者與原函式的引數不太一樣。具體原因我也不是很清楚。

比如用來擬合這個函式就會有問題

 

 把beta變成引數再擬合其他的一些函式也會有問題,如下面兩個函式