pytorch中 max()、view()、 squeeze()、 unsqueeze()
阿新 • • 發佈:2018-12-13
查了好多部落格都似懂非懂,後來寫了幾個小例子,瞬間一目瞭然。
目錄
一、torch.max()
import torch a=torch.randn(3) print("a:\n",a) print('max(a):',torch.max(a)) b=torch.randn(3,4) print("b:\n",b) print('max(b,0):',torch.max(b,0)) print('max(b,1):',torch.max(b,1))
輸出:
a: tensor([ 0.9558, 1.1242, 1.9503]) max(a): tensor(1.9503) b: tensor([[ 0.2765, 0.0726, -0.7753, 1.5334], [ 0.0201, -0.0005, 0.2616, -1.1912], [-0.6225, 0.6477, 0.8259, 0.3526]]) max(b,0): (tensor([ 0.2765, 0.6477, 0.8259, 1.5334]), tensor([ 0, 2, 2, 0])) max(b,1): (tensor([ 1.5334, 0.2616, 0.8259]), tensor([ 3, 2, 2]))
max(a),用於一維資料,求出最大值。
max(a,0),計算出資料中一列的最大值,並輸出最大值所在的行號。
max(a,0),計算出資料中一行的最大值,並輸出最大值所在的列號。
print('max(b,1):',torch.max(b,1)[1])
輸出:只輸出行最大值所在的列號
max(b,1): tensor([ 3, 2, 2])
torch.max(b,1)[0], 只返回最大值的每個數
二、view()
a.view(i,j)表示將原矩陣轉化為i行j列的形式
i為-1表示不限制行數,輸出1列
a=torch.randn(3,4)
print(a)
輸出:
tensor([[-0.8146, -0.6592, 1.5100, 0.7615],
[ 1.3021, 1.8362, -0.3590, 0.3028],
[ 0.0848, 0.7700, 1.0572, 0.6383]])
b=a.view(-1,1)
print(b)
輸出:
tensor([[-0.8146],
[-0.6592],
[ 1.5100],
[ 0.7615],
[ 1.3021],
[ 1.8362],
[-0.3590],
[ 0.3028],
[ 0.0848],
[ 0.7700],
[ 1.0572],
[ 0.6383]])
i為1,j為-1表示不限制列數,輸出1行
b=a.view(1,-1)
print(b)
輸出:
tensor([[-0.8146, -0.6592, 1.5100, 0.7615, 1.3021, 1.8362, -0.3590,
0.3028, 0.0848, 0.7700, 1.0572, 0.6383]])
i為-1,j為2表示不限制行數,輸出2列
b=a.view(-1,2)
print(b)
輸出:
tensor([[-0.8146, -0.6592],
[ 1.5100, 0.7615],
[ 1.3021, 1.8362],
[-0.3590, 0.3028],
[ 0.0848, 0.7700],
[ 1.0572, 0.6383]])
i為-1,j為3表示不限制行數,輸出3列
i為4,j為3表示輸出4行3列
b=a.view(-1,3)
print(b)
b=a.view(4,3)
print(b)
輸出:
tensor([[-0.8146, -0.6592, 1.5100],
[ 0.7615, 1.3021, 1.8362],
[-0.3590, 0.3028, 0.0848],
[ 0.7700, 1.0572, 0.6383]])
tensor([[-0.8146, -0.6592, 1.5100],
[ 0.7615, 1.3021, 1.8362],
[-0.3590, 0.3028, 0.0848],
[ 0.7700, 1.0572, 0.6383]])
三、
1.torch.squeeze()
壓縮矩陣,我理解為降維
a.squeeze(i) 壓縮第i維,如果這一維維數是1,則這一維可有可無,便可以壓縮
import torch
a=torch.randn(1,3,4)
print(a)
b=a.squeeze(0)
print(b)
c=a.squeeze(1)
print(c
輸出:
tensor([[[ 0.4627, 1.6447, 0.1320, 2.0946],
[-0.0080, 0.1794, 1.1898, -1.2525],
[ 0.8281, -0.8166, 1.8846, 0.9008]]])
一頁三行4列的矩陣
第0維為1,則可以通過squeeze(0)刪掉,轉化為三行4列的矩陣
tensor([[ 0.4627, 1.6447, 0.1320, 2.0946],
[-0.0080, 0.1794, 1.1898, -1.2525],
[ 0.8281, -0.8166, 1.8846, 0.9008]])
第1維不為1,則不可以壓縮
tensor([[[ 0.4627, 1.6447, 0.1320, 2.0946],
[-0.0080, 0.1794, 1.1898, -1.2525],
[ 0.8281, -0.8166, 1.8846, 0.9008]]])
2.torch.unsqueeze()
unsqueeze(i) 表示將第i維設定為1
對壓縮為3行4列後的矩陣b進行操作,將第0維設定為1
c=b.unsqueeze(0)
print(c)
輸出一個一頁三行四列的矩陣
tensor([[[ 0.0661, -0.2386, -0.6610, 1.5774],
[ 1.2210, -0.1084, -0.1166, -0.2379],
[-1.0012, -0.4363, 1.0057, -1.5180]]])
將第一維設定為1
c=b.unsqueeze(1)
print(c)
輸出一個3頁,一行,4列的矩陣
tensor([[[-1.0067, -1.1477, -0.3213, -1.0633]],
[[-2.3976, 0.9857, -0.3462, -0.3648]],
[[ 1.1012, -0.4659, -0.0858, 1.6631]]])
另外,squeeze、unsqueeze操作不改變原矩陣
ok!!!