1. 程式人生 > 其它 >1、PyTorch基本操作

1、PyTorch基本操作

一、簡介

  簡單介紹PyTorch框架,基本使用和安裝方法。Torch是什麼?一個火炬!其實跟Tensorflow中Tensor是一個意思,就是說,有一批資料,無論是影象資料還是文字資料或數值資料,都需要把資料轉換成矩陣,接下來在建模操作過程中,都需要對當前資料即矩陣,做各種各樣變換,做各種各樣計算,一系列流程做完之後得到我們想要的結果。PyTorch可以說是做這樣一件事,把所有矩陣計算的東西傳入GPU中,因為GPU中做矩陣運算比較快,在GPU中幫我們實現了所有的計算功能,整體的計算,從前向傳播到反向傳播,有可能會涉及到非常複雜的計算,這些計算統統由框架幫我們實現,我們需要去做的,就是設計整個任務的流程,整個網路架構就可以了。深度學習框架,說白了,就是一個計算的工具,幫我們實現由前到後整體的計算流程。
一個框架該怎麼學?
  學框架不要去看基本的操作, 直接看一個實際的例子,一步步怎麼走的即可,遇到一些基本的點,可能遇到某些函式,儘量去查,查的過程其實也是學習的過程。
二、例項

1、匯入torch包

1 #匯入PyTorch包
2 import torch
3 print(torch.__version__)

輸出結果:

 2、建立一個空的張量

 1 #基本使用方法
 2 #建立一個全零的5行3列的矩陣,格式是一個tensor
 3 #tensor即張量,理解為矩陣即可,一維是向量,二維是矩陣,不管多少維,統一叫
 4 #做tensor,是深度學習中最基本的計算單元,
 5 #也可以說是框架的底層。之前用過的其他結構,如numpy、pandas之類的,會得到
 6 #ndarry或DataFrame類似的結構,
 7 #看起來也是矩陣,但是不能在我們這裡做,要用PyTorch框架,①把所有的資料轉換
8 #成tensor的格式,tensor是底層所支援的 9 #格式,所有的輸入,所有的計算,都是對tensor所執行的。 10 x=torch.empty(5,3) 11 print("x:\n",x)

輸出結果:

 3、建立一個隨機的53列的矩陣

1 #建立一個隨機的5行3列的矩陣
2 x1=torch.rand(5,3)
3 print("x1:\n",x1)

輸出結果:

 4、初始化一個全零的矩陣

1 #初始化一個全零的矩陣
2 x2=torch.zeros(5,3,dtype=torch.long)
3 print("x2:\n",x2)

輸出結果:

 5、直接傳入資料

1 #直接傳入資料
2 x3=torch.tensor([5.5,3])
3 print("x3:\n",x3)

輸出結果:

6、生成全為1的矩陣並隨機初始化

1 x4=x2.new_ones(5,3,dtype=torch.double)  #生成全為1的矩陣
2 x5=torch.randn_like(x4,dtype=torch.float)
3 print("x4:\n",x4)
4 print("x5:\n",x5)
5 #建議執行完每次操作之後,列印維度看一看
6 print(x5.size())

輸出結果:

 7、基本計算方法:加法操作

1 #基本計算方法
2 y=torch.rand(5,3)
3 print("y+x5=",y+x5)
4 print(torch.add(y,x5)) #一樣的操作

輸出結果:

 8、索引

1 #索引
2 print("x5[:,1]:\n",x5[:,1])

輸出結果:

 9、view操作可以改變矩陣維度

1 #view操作可以改變矩陣維度
2 x6=torch.randn(4,4)
3 print("x6:\n",x6)
4 y1=x6.view(16) #將x6拉成一行向量
5 print("y1:\n",y1)
6 z=x6.view(-1,8) #-1代表自動做計算,第二個維度有8個元素,第一個維度自動計算
7 print("z:\n",z)
8 print("x6.size:\n",x6.size(),"\n","y1.size:\n",y1.size(),"\n","z.size:\n",z.size())

輸出結果:

 10、與numpy的協同操作

1 #與numpy的協同操作
2 a=torch.ones(5)
3 b=a.numpy()
4 print("b:\n",b)
5 print("type(b):\n",type(b))

輸出結果:

11、numpytensor

1 #numpy轉tensor
2 a1=np.ones(5)
3 b1=torch.from_numpy(a1)
4 print("b1:\n",b1)

輸出結果: