python-pytorch—squeeze()和unsqueeze()函式
技術標籤:pytorch學習筆記python
1.增加維度——unsqueeze()函式
a = torch.randn((3,4))
print("a.shape:",a.shape)
b1 = a.unsqueeze(0)
print("b1.shape:",b1.shape)
b2 = a.unsqueeze(-3)
print("b2.shape:",b2.shape)
print("****************************")
c1 = a.unsqueeze(1)
print( "c1.shape:",c1.shape)
c2 = a.unsqueeze(-2)
print("c2.shape:",c2.shape)
print("****************************")
d1 = a.unsqueeze(2)
print("d1.shape:",d1.shape)
d2 = a.unsqueeze(-1)
print("d2.shape:",d2.shape)
print("****************************" )
輸出:>>
a.shape: torch.Size([3, 4])
b1.shape: torch.Size([1, 3, 4])
b2.shape: torch.Size([1, 3, 4])
c1.shape: torch.Size([3, 1, 4])
c2.shape: torch.Size([3, 1, 4])
d1.shape: torch.Size([3, 4, 1])
d2.shape: torch.Size([3, 4, 1])
再注意,1個錯誤例子:
a = torch.randn((3,4))
e = a.unsqueeze(3)
print ("e.shape:",e.shape)
報錯:因為超出索引範圍[-3,2]
2.減少維度——squeeze()函式
2.1.
a = torch.randn((1,3,4))
print("a.shape:",a.shape)
b1 = a.squeeze(0)
print("b1.shape:",b1.shape)
b2 = a.squeeze(-3)
print("b2.shape:",b2.shape)
print("****************************")
輸出:
a.shape: torch.Size([1, 3, 4])
b1.shape: torch.Size([3, 4])
b2.shape: torch.Size([3, 4])
2.2.
a = torch.randn((1,3,4))
print("a.shape:",a.shape)
c1 = a.squeeze(1)
print("c1.shape:",c1.shape)
c2 = a.squeeze(-2)
print("c2.shape:",c2.shape)
輸出:
a.shape: torch.Size([1, 3, 4])
c1.shape: torch.Size([1, 3, 4])
c2.shape: torch.Size([1, 3, 4])
分析:
- 維度並未發生變化,因為,只有當壓縮的對應維度是1時才可以,
在本案例中,要壓縮第1維(也是第-2維),它的數是3,不能壓縮成功。
而壓縮第0維時,它的數是1,可以成功。
3.ceil()、floor()函式——小數變整數
print("np.math.ceil(1.8):",np.math.ceil(1.8))
print("np.math.floor(1.8)",np.math.floor(1.8))
np.math.ceil(1.8): 2
np.math.floor(1.8) 1
4.slice()——切片函式函式
arr: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
1.切去陣列中前n的資料----slice(n)
slice_a = slice(5)
print(arr[slice_a])
輸出: [1, 2, 3, 4, 5]
2.slice(start, stop)
- 取出陣列起始索引為start,結束索引為(stop-1)的資料
arr = (0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15)
slice_ = slice(0,5)
print(arr[slice_])
輸出:(0, 1, 2, 3, 4)
3.slice(start, stop, step)
- 設定起止位置,取出間距為step的資料
slice_b = slice(1,19,8)
print(arr[slice_b])
輸出: [2, 10, 18]
5.pytorch中torch.max和F.softmax函式
input = [[-0.3712, 1.3154, -1.4527],
[ 0.4230, -0.0256, 1.2595]]
1.按列SoftMax,列和為1——F.softmax(input,dim=0)
b = F.softmax(input,dim=0)
*********************************
b = [[0.3113, 0.7927, 0.0623],
[0.6887, 0.2073, 0.9377]]
2.按行SoftMax,行和為1——F.softmax(input,dim=1)
c = F.softmax(input,dim=1)
*********************************
c = [[0.1484, 0.8013, 0.0503],
[0.2534, 0.1618, 0.5849]]
3.按列取max——torch.max(input,dim=0)
d = torch.max(input,dim=0)
******************************
values=[0.4230, 1.3154, 1.2595]
indices=[1, 0, 1]
4.按行取max——torch.max(input,dim=1)
e = torch.max(input,dim=1)
*****************************
values=[1.3154, 1.2595]
indices=[1, 2]