1. 程式人生 > 其它 >python-pytorch—squeeze()和unsqueeze()函式

python-pytorch—squeeze()和unsqueeze()函式

技術標籤:pytorch學習筆記python

1.增加維度——unsqueeze()函式

a = torch.randn((3,4))
print("a.shape:",a.shape)

b1 = a.unsqueeze(0)
print("b1.shape:",b1.shape)
b2 = a.unsqueeze(-3)
print("b2.shape:",b2.shape)
print("****************************")

c1 = a.unsqueeze(1)
print(
"c1.shape:",c1.shape) c2 = a.unsqueeze(-2) print("c2.shape:",c2.shape) print("****************************") d1 = a.unsqueeze(2) print("d1.shape:",d1.shape) d2 = a.unsqueeze(-1) print("d2.shape:",d2.shape) print("****************************"
)

輸出:>>

a.shape: torch.Size([3, 4])


b1.shape: torch.Size([1, 3, 4])
b2.shape: torch.Size([1, 3, 4])


c1.shape: torch.Size([3, 1, 4])
c2.shape: torch.Size([3, 1, 4])


d1.shape: torch.Size([3, 4, 1])
d2.shape: torch.Size([3, 4, 1])


再注意,1個錯誤例子:

a = torch.randn((3,4))
e = a.unsqueeze(3)
print
("e.shape:",e.shape)

報錯:因為超出索引範圍[-3,2]
在這裡插入圖片描述

2.減少維度——squeeze()函式

2.1.

a = torch.randn((1,3,4))
print("a.shape:",a.shape)

b1 = a.squeeze(0)
print("b1.shape:",b1.shape)
b2 = a.squeeze(-3)
print("b2.shape:",b2.shape)
print("****************************")

輸出:


a.shape: torch.Size([1, 3, 4])


b1.shape: torch.Size([3, 4])
b2.shape: torch.Size([3, 4])

2.2.

a = torch.randn((1,3,4))
print("a.shape:",a.shape)

c1 = a.squeeze(1)
print("c1.shape:",c1.shape)
c2 = a.squeeze(-2)
print("c2.shape:",c2.shape)

輸出:


a.shape: torch.Size([1, 3, 4])


c1.shape: torch.Size([1, 3, 4])
c2.shape: torch.Size([1, 3, 4])

分析:

  • 維度並未發生變化,因為,只有當壓縮的對應維度是1時才可以,
    在本案例中,要壓縮第1維(也是第-2維),它的數是3,不能壓縮成功
    而壓縮第0維時,它的數是1,可以成功

3.ceil()、floor()函式——小數整數

print("np.math.ceil(1.8):",np.math.ceil(1.8))
print("np.math.floor(1.8)",np.math.floor(1.8))

np.math.ceil(1.8): 2
np.math.floor(1.8) 1

4.slice()——切片函式函式

arr: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]

1.切去陣列中前n的資料----slice(n)

slice_a = slice(5)
print(arr[slice_a])

輸出: [1, 2, 3, 4, 5]

2.slice(start, stop)

  • 取出陣列起始索引為start,結束索引為(stop-1)的資料
arr = (0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15)
slice_ = slice(0,5)
print(arr[slice_])

輸出:(0, 1, 2, 3, 4)

3.slice(start, stop, step)

  • 設定起止位置,取出間距為step的資料
slice_b = slice(1,19,8)
print(arr[slice_b])

輸出: [2, 10, 18]

5.pytorch中torch.max和F.softmax函式

input = [[-0.3712,  1.3154, -1.4527],
         [ 0.4230, -0.0256,  1.2595]]

1.按列SoftMax,列和為1——F.softmax(input,dim=0)

b = F.softmax(input,dim=0)
*********************************
b = [[0.3113, 0.7927, 0.0623],
     [0.6887, 0.2073, 0.9377]]

2.按行SoftMax,行和為1——F.softmax(input,dim=1)

c = F.softmax(input,dim=1)
*********************************
c = [[0.1484, 0.8013, 0.0503],
     [0.2534, 0.1618, 0.5849]]

3.按列取max——torch.max(input,dim=0)

d = torch.max(input,dim=0)
******************************
values=[0.4230, 1.3154, 1.2595]
indices=[1, 0, 1]

4.按行取max——torch.max(input,dim=1)

e = torch.max(input,dim=1)
*****************************
values=[1.3154, 1.2595]
indices=[1, 2]