莫煩PyTorch之快速構建網路程式碼
阿新 • 發佈:2018-12-19
import torch
import torch.nn.functional as F


class Net(torch.nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    This is the "class" way of building a network: subclass
    ``torch.nn.Module`` and write ``forward`` explicitly.

    Args:
        n_feature: Number of input features.
        n_hidden: Number of units in the hidden layer.
        n_output: Number of output features.
    """

    def __init__(self, n_feature: int, n_hidden: int, n_output: int) -> None:
        super().__init__()
        # The attribute names become the keys shown by print(net) and used
        # in state_dict(), so they are part of the model's external shape.
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the hidden layer, a ReLU, then the output layer to *x*."""
        # F.relu is a plain function rather than a Module, so it is not
        # registered as a submodule and does not show up in print(net1).
        x = F.relu(self.hidden(x))
        return self.predict(x)


# The "class" way: a custom Module subclass.
net1 = Net(1, 10, 1)

# The "quick" way: the same Linear -> ReLU -> Linear stack via Sequential.
net2 = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),  # Capitalized: a Module class, unlike the function F.relu.
    torch.nn.Linear(10, 1),
)

print(net1)
print(net2)
# 輸出結果 (output):
# Net(
#   (hidden): Linear(in_features=1, out_features=10, bias=True)
#   (predict): Linear(in_features=10, out_features=1, bias=True)
# )
# Sequential(
#   (0): Linear(in_features=1, out_features=10, bias=True)
#   (1): ReLU()
#   (2): Linear(in_features=10, out_features=1, bias=True)
# )