1. 程式人生 > 其它 >PyTorch的Variable已經不需要用了!!!

PyTorch的Variable已經不需要用了!!!

轉載自:https://blog.csdn.net/rambo_csdn_123/article/details/119056123

 

Pytorch的torch.autograd.Variable
今天在看《莫凡Python》的PyTorch教程的時候發現他的程式碼還在使用Variable,並且我記得過去讀一些GitHub上面的程式碼的時候也發現了Variable這個東西,根據教程中所說的,想要計算tensor的梯度等等,就必須將tensor放入Variable中並指定required_grad的值為True,通過這個Variable的容器才能進行梯度求導等等操作,程式碼如下:

import torch
from torch.autograd import Variable

tensor = torch.FloatTensor([[1, 2], [3, 4]])
variable = Variable(tensor, requires_grad=True)
v_out = torch.mean(variable * variable)
v_out.backward()
print(variable.grad)

  在我查閱PyTorch的官方文件之後,發現Variable已經被放棄使用了,因為tensor自己已經支援自動求導的功能了,只要把requires_grad屬性設定成True就可以了,所以下次見到Variable可以大膽地更改程式碼

 

 例如之前的程式碼可以改成

import torch

tensor = torch.FloatTensor([[1, 2], [3, 4]])
tensor.requires_grad = True  # 這個尤其重要!!!
t_out = torch.mean(tensor * tensor)
t_out.backward()
print(tensor.grad)