1. 程式人生 > >莫煩Pytorch之快速構建網路程式碼

莫煩Pytorch之快速構建網路程式碼

import torch
import torch.nn.functional as F

class Net(torch.nn.Module):
    """Small fully connected network: Linear -> ReLU -> Linear.

    Args:
        n_feature: Number of input features fed to the hidden layer.
        n_hidden: Number of units produced by the hidden layer.
        n_output: Number of values produced by the output layer.
    """

    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        # The attribute names become the submodule names shown by print(net)
        # and used as prefixes of the parameter names.
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        """Apply the hidden layer with ReLU, then the output layer, to ``x``."""
        activated = F.relu(self.hidden(x))
        # No activation after the last layer: the raw linear output is returned.
        return self.predict(activated)
# Approach 1: instantiate the hand-written Module subclass
# (1 input feature, 10 hidden units, 1 output).
net1 = Net(1, 10, 1)

# Approach 2: assemble the same layer stack directly with Sequential.
net2 = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    # Capitalized ReLU is a class, so it is instantiated here as a layer
    # object (unlike the functional F.relu, which is called on a tensor).
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
)

# Show both architectures, one after the other.
print(net1, net2, sep="\n")

# Output:
# Net(
#   (hidden): Linear(in_features=1, out_features=10, bias=True)
#   (predict): Linear(in_features=10, out_features=1, bias=True)
# )
# Sequential(
#   (0): Linear(in_features=1, out_features=10, bias=True)
#   (1): ReLU()
#   (2): Linear(in_features=10, out_features=1, bias=True)
# )