1. 程式人生 > >Pytorch中的squeeze()和unsqueeze()函式

Pytorch中的squeeze()和unsqueeze()函式

在numpy庫中,經常會出現“秩為1的一維陣列”(come from 吳恩達的深度學習,目前還沒有搞清楚numpy中如此設計的意圖)。比如:

        

注意這裡的a的shape是[3] ,既不是 [1,3] 也不是 [3,1]。這就說明它既不是行向量也不是列向量,只是一個數組。

但是我們可以用squeeze()和unsqueeze()對其進行操作,比如:

       

在對這兩個函式講解之前,我們先統一一下定義:

In  : x.shape
Out: torch.Size( [ a , b , c ] )

這裡我們說a是第一個維度(表示第一個維度的數量是a),b是第二個維度(表示第一個維度的數量是b),c是第三個維度(表示第三個維度的數量是c)

簡而言之,unsqueeze(arg)是增添第arg個維度為1,以插入的形式填充。比如:

       

       

相反,squeeze(arg)是刪除第arg個維度(如果當前維度不為1,則不會進行刪除),比如: