1. 程式人生 > PyTorch 產生 loss 的計算圖

PyTorch 產生 loss 的計算圖

import torch as t
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Small LeNet-5-style CNN mapping single-channel 32x32 images to 10 scores.

    The layer sizes only fit a (N, 1, 32, 32) input. Its spatial size evolves
    32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5, so ``fc1`` receives
    16 * 5 * 5 features. Any other spatial size fails at ``fc1`` with a
    shape mismatch.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)   # 1 input channel -> 6 maps, 5x5 kernel
        self.conv2 = nn.Conv2d(6, 16, 5)  # 6 -> 16 maps, 5x5 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw (un-normalized) scores of shape (N, 10).

        ``x`` must be (N, 1, 32, 32) so the flattened features match ``fc1``.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten every dimension except the batch one for the linear layers.
        x = t.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


net = Net()

# Plain tensors take part in autograd directly; torch.autograd.Variable has
# been a no-op wrapper since PyTorch 0.4 and is not needed.
inputs = t.randn(1, 1, 32, 32)
output = net(inputs)

# MSELoss needs a floating-point target with the same shape as the output.
# An integer arange of shape (10,) leads to a dtype error (at the latest in
# backward()) and only broadcasts against the (1, 10) output with a warning.
target = t.arange(0, 10, dtype=output.dtype).view(1, -1)
criterion = nn.MSELoss()
loss = criterion(output, target)

# Root node of the computation graph that loss.backward() will traverse.
print(loss.grad_fn)