1. 程式人生 > 程式設計 >PyTorch中permute的用法詳解

PyTorch中permute的用法詳解

permute(dims)

將tensor的維度換位。

引數:引數是一系列的整數,代表原來張量的維度。比如三維就有0,1,2這些dimension。

例:

import torch
import numpy as np
a=np.array([[[1,2,3],[4,5,6]]])
unpermuted=torch.tensor(a)
print(unpermuted.size()) # ——> torch.Size([1,3])
permuted=unpermuted.permute(2,1)
print(permuted.size()) # ——> torch.Size([3,1,2])

再比如圖片img的size比如是(28,28,3)就可以利用img.permute(2,1)得到一個size為(3,28,28)的tensor。

利用這個函式permute(1,3,2)可以把Tensor([[[1,6]]]) 轉換成

tensor([[[1.,4.],[2.,5.],[3.,6.]]])

如果使用view(1,3,2),可以得到

tensor([[[1.,2.],[5.,6.]]])

以上這篇PyTorch中permute的用法詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。