1. 程式人生 > 其它 >pytorch與numpy中維度順序轉換的比較

pytorch與numpy中維度順序轉換的比較

技術標籤:pytorchpytorchnumpypermutetranpose

pytorch

tensor.permute(dims)

將tensor的維度換位。

引數:引數是一系列的整數,代表原來張量的維度。比如三維就有0,1,2這些dimension。

例:

import torch

a = torch.rand(8,256,256,3)   #---> n,h,w,c
print(a.shape)

b = a.permute(0,3,1,2)  # ---> n,c,h,w
print(b.shape)

輸出:

torch.Size([8, 256, 256, 3])
torch.Size([8, 3, 256, 256])


numpy

numy.transpose(a,axis=None)

引數 a: 輸入陣列
axis: int型別的列表,這個引數是可選的。預設情況下,反轉的輸入陣列的維度,當給定這個引數時,按照這個引數所定的值進行陣列變換。
返回值 p: ndarray 返回轉置過後的原陣列的檢視。

import numpy as  np

x = np.random.randn(8,256,256,3)  # ---> n,h,w,c
print(x.shape)
y=x.transpose((0,3,1,2))   #  ----> n,c,h,w
print(y.shape)

輸出:

(8, 256, 256, 3)
(8, 3, 256, 256)

transpose 的原理其實是根據維度(shape)索引決定的

例:

import numpy as  np

x = np.arange(4).reshape((2,2)) #生成一個2x2的陣列
print(x)

[[0 1]
 [2 3]]

生成了一個維度為二維陣列,其中有兩個索引值(矩陣的行與列)。

transpose()函式的作用就是調換陣列的行列值的索引值,類似於求矩陣的轉置:

x = np.arange(4).reshape((2,2))

x = np.transpose(x)

print(x)

[[0 2]

[1 3]]

我們可以直觀的看到,陣列的行列索引值對換,1的位置從x(0,1)跑到了x(1,0)。

那麼三維陣列呢?

我們繼續生成一個三維的陣列:

x = np.arange(12).reshape((2,2,3)) //生成一個2x2x3的陣列

print(x)

[[[ 0 1 2]

 [ 3 4 5]]

[[ 6 7 8]

 [ 9 10 11]]]

從高中數學知道三維由x軸、y軸以及z軸組成。

假設三維陣列當中的索引值為x,y,z

transpose()函式的作用就是調換x,y,z的位置,也就是陣列的索引值。

所以正常的陣列索引值為(0,1,2),等於(x,y,z)

我們來看例項程式碼:

x = np.arange(12).reshape((2,2,3))

print(x)

[[[ 0 1 2]

[ 3 4 5]]

[[ 6 7 8]

[ 9 10 11]]]


x = np.transpose(x,(1,0,2)) //transpose()函式的第二個引數就是改變索引值的地方

print(x)

[[[ 0 1 2]

[ 6 7 8]]

[[ 3 4 5]

[ 9 10 11]]]

通過transpose()函式改變了x的索引值為(1,0,2),對應(y,x,z)

索引改變後原本y的值和x的值對換了。

有上面程式碼的數字7為例,原本的7的位置索引為(1,0,1),通過transpose(x,(1,0,2))索引改變為(0,1,1)

無論四維、五維……都可以用這個原理分析。