PyTorch 產生 loss 的計算圖
import torch as t
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    """LeNet-style CNN mapping a single-channel image to 10 output scores.

    The layer sizes assume 32x32 inputs: each 5x5 convolution (no padding)
    trims 4 pixels and each 2x2 max-pool halves the size, so
    32 -> 28 -> 14 -> 10 -> 5, leaving 16 feature maps of 5x5 for ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 1 -> 6 -> 16 channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Fully connected head: 400 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Run the network on a batch.

        Args:
            x: Tensor of shape ``(batch, 1, 32, 32)``.

        Returns:
            Tensor of shape ``(batch, 10)`` holding the raw outputs of the
            last linear layer (no activation applied).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


net = Net()
# To inspect the model, print `net` itself or iterate over
# `net.named_parameters()` and print each name with `parameter.size()`.

# Plain tensors take part in autograd directly; the deprecated `Variable`
# wrapper is not needed.
# NOTE: the name `input` shadows the builtin; it is kept so that any code
# following this snippet that refers to it keeps working.
input = t.randn(1, 1, 32, 32)
output = net(input)

# MSELoss needs a floating-point target with the same shape as the output:
# `t.arange(0, 10)` alone is an int64 tensor of shape (10,), so convert it
# to float and reshape it to (1, 10).
target = t.arange(0, 10).float().view(1, 10)
criterion = nn.MSELoss()
loss = criterion(output, target)

# `grad_fn` is the root node of the loss's autograd graph. It is printed
# because a bare expression only displays a value in a REPL/notebook.
print(loss.grad_fn)