1. 程式人生 > >pytorch中 max()、view()、 squeeze()、 unsqueeze()

pytorch中 max()、view()、 squeeze()、 unsqueeze()

查了好多部落格都似懂非懂,後來寫了幾個小例子,瞬間一目瞭然。

目錄

一、torch.max()

二、torch.view()

三、

1.torch.unsqueeze()

2.squeeze()


一、torch.max()

import torch 

a=torch.randn(3)
print("a:\n",a)
print('max(a):',torch.max(a))

b=torch.randn(3,4)
print("b:\n",b)
print('max(b,0):',torch.max(b,0))
print('max(b,1):',torch.max(b,1))

   輸出:

a:
 tensor([ 0.9558,  1.1242,  1.9503])
max(a): tensor(1.9503)
b:
 tensor([[ 0.2765,  0.0726, -0.7753,  1.5334],
        [ 0.0201, -0.0005,  0.2616, -1.1912],
        [-0.6225,  0.6477,  0.8259,  0.3526]])
max(b,0): (tensor([ 0.2765,  0.6477,  0.8259,  1.5334]), tensor([ 0,  2,  2,  0]))
max(b,1): (tensor([ 1.5334,  0.2616,  0.8259]), tensor([ 3,  2,  2]))

  max(a),用於一維資料,求出最大值。

  max(a,0),計算出資料中一列的最大值,並輸出最大值所在的行號。

  max(a,0),計算出資料中一行的最大值,並輸出最大值所在的列號。

print('max(b,1):',torch.max(b,1)[1])

  輸出:只輸出行最大值所在的列號

max(b,1): tensor([ 3,  2,  2])

  torch.max(b,1)[0], 只返回最大值的每個數

二、view()

a.view(i,j)表示將原矩陣轉化為i行j列的形式   

  i為-1表示不限制行數,輸出1列

a=torch.randn(3,4)
print(a)

輸出:
tensor([[-0.8146, -0.6592,  1.5100,  0.7615],
        [ 1.3021,  1.8362, -0.3590,  0.3028],
        [ 0.0848,  0.7700,  1.0572,  0.6383]])

b=a.view(-1,1)
print(b)

輸出:
tensor([[-0.8146],
        [-0.6592],
        [ 1.5100],
        [ 0.7615],
        [ 1.3021],
        [ 1.8362],
        [-0.3590],
        [ 0.3028],
        [ 0.0848],
        [ 0.7700],
        [ 1.0572],
        [ 0.6383]])

i為1,j為-1表示不限制列數,輸出1行

b=a.view(1,-1)
print(b)

輸出:
tensor([[-0.8146, -0.6592,  1.5100,  0.7615,  1.3021,  1.8362, -0.3590,
          0.3028,  0.0848,  0.7700,  1.0572,  0.6383]])

 i為-1,j為2表示不限制行數,輸出2列

b=a.view(-1,2)
print(b)

輸出:
tensor([[-0.8146, -0.6592],
        [ 1.5100,  0.7615],
        [ 1.3021,  1.8362],
        [-0.3590,  0.3028],
        [ 0.0848,  0.7700],
        [ 1.0572,  0.6383]])

 i為-1,j為3表示不限制行數,輸出3列 

 i為4,j為3表示輸出4行3列 

b=a.view(-1,3)
print(b)
b=a.view(4,3)
print(b)

輸出:
tensor([[-0.8146, -0.6592,  1.5100],
        [ 0.7615,  1.3021,  1.8362],
        [-0.3590,  0.3028,  0.0848],
        [ 0.7700,  1.0572,  0.6383]])
tensor([[-0.8146, -0.6592,  1.5100],
        [ 0.7615,  1.3021,  1.8362],
        [-0.3590,  0.3028,  0.0848],
        [ 0.7700,  1.0572,  0.6383]])

三、

1.torch.squeeze()

壓縮矩陣,我理解為降維

a.squeeze(i)   壓縮第i維,如果這一維維數是1,則這一維可有可無,便可以壓縮

import torch 

a=torch.randn(1,3,4)
print(a)
b=a.squeeze(0)
print(b)
c=a.squeeze(1)
print(c

輸出:

tensor([[[ 0.4627,  1.6447,  0.1320,  2.0946],
         [-0.0080,  0.1794,  1.1898, -1.2525],
         [ 0.8281, -0.8166,  1.8846,  0.9008]]])
一頁三行4列的矩陣

第0維為1,則可以通過squeeze(0)刪掉,轉化為三行4列的矩陣
tensor([[ 0.4627,  1.6447,  0.1320,  2.0946],
        [-0.0080,  0.1794,  1.1898, -1.2525],
        [ 0.8281, -0.8166,  1.8846,  0.9008]])
第1維不為1,則不可以壓縮
tensor([[[ 0.4627,  1.6447,  0.1320,  2.0946],
         [-0.0080,  0.1794,  1.1898, -1.2525],
         [ 0.8281, -0.8166,  1.8846,  0.9008]]])

 

2.torch.unsqueeze()

 unsqueeze(i) 表示將第i維設定為1

對壓縮為3行4列後的矩陣b進行操作,將第0維設定為1

c=b.unsqueeze(0)
print(c)

輸出一個一頁三行四列的矩陣
tensor([[[ 0.0661, -0.2386, -0.6610,  1.5774],
         [ 1.2210, -0.1084, -0.1166, -0.2379],
         [-1.0012, -0.4363,  1.0057, -1.5180]]])

將第一維設定為1 

c=b.unsqueeze(1)
print(c)

輸出一個3頁,一行,4列的矩陣
tensor([[[-1.0067, -1.1477, -0.3213, -1.0633]],

        [[-2.3976,  0.9857, -0.3462, -0.3648]],

        [[ 1.1012, -0.4659, -0.0858,  1.6631]]])

另外,squeeze、unsqueeze操作不改變原矩陣

ok!!!