1. 程式人生 > >torch畫散點圖

torch畫散點圖

很多 pyplot 維度 bubuko squeeze 結果 技術 ria num

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)   #torch.linspace本身是一維向量,unsqueeze是增加維度,把一維化為二維
y=x.pow(2)+0.2*torch.rand(x.size())   
#0.2*torch.rand(x.size())相當於給散點圖加噪聲
x,y=Variable(x),Variable(y)
plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()

輸出結果為:

技術分享圖片

代碼中,Variable是變量的意思。包含很多屬性,常用的是.data,還可以計算梯度。

torch畫散點圖