1. 程式人生 > >x = x.view(x.size(0), -1) 的理解

x = x.view(x.size(0), -1) 的理解

之前對於pytorch的網路程式設計學習都是大致理解每一層的概念,有些語法語句沒有從原理上弄清楚,就比如標題的x = x.view(x.size(0), -1)  。

這句話一般出現在model類的forward函式中,具體位置一般都是在呼叫分類器之前。分類器是一個簡單的nn.Linear()結構,輸入輸出都是維度為一的值,x = x.view(x.size(0), -1)  這句話的出現就是為了將前面多維度的tensor展平成一維。下面是個簡單的例子,我將會根據例子來對該語句進行解析。

class NET(nn.Module):
    def __init__(self,batch_size):
        super(NET,self).__init__()
        self.conv = nn.Conv2d(outchannels=3,in_channels=64,kernel_size=3,stride=1)
        self.fc = nn.Linear(64*batch_size,10)

    def forward(self,x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)  
        out = self.fc(x)

上面是個簡單的網路結構,包含一個卷積層和一個分類層。forward()函式中,input首先經過卷積層,此時的輸出x是包含batchsize維度為4的tensor,即(batchsize,channels,x,y),x.size(0)指batchsize的值。x = x.view(x.size(0), -1)簡化x = x.view(batchsize, -1)。

view()函式的功能根reshape類似,用來轉換size大小。x = x.view(batchsize, -1)中batchsize指轉換後有幾行,而-1指在不告訴函式有多少列的情況下,根據原tensor資料和batchsize自動分配列數。