x = x.view(x.size(0), -1) 的理解
阿新 • • 發佈:2019-01-08
之前對於pytorch的網路程式設計學習都是大致理解每一層的概念,有些語法語句沒有從原理上弄清楚,就比如標題的x = x.view(x.size(0), -1) 。
這句話一般出現在model類的forward函式中,具體位置一般都是在呼叫分類器之前。分類器是一個簡單的nn.Linear()結構,輸入輸出都是維度為一的值,x = x.view(x.size(0), -1) 這句話的出現就是為了將前面多維度的tensor展平成一維。下面是個簡單的例子,我將會根據例子來對該語句進行解析。
class NET(nn.Module): def __init__(self,batch_size): super(NET,self).__init__() self.conv = nn.Conv2d(outchannels=3,in_channels=64,kernel_size=3,stride=1) self.fc = nn.Linear(64*batch_size,10) def forward(self,x): x = self.conv(x) x = x.view(x.size(0), -1) out = self.fc(x)
上面是個簡單的網路結構,包含一個卷積層和一個分類層。forward()函式中,input首先經過卷積層,此時的輸出x是包含batchsize維度為4的tensor,即(batchsize,channels,x,y),x.size(0)指batchsize的值。x = x.view(x.size(0), -1)簡化x = x.view(batchsize, -1)。
view()函式的功能根reshape類似,用來轉換size大小。x = x.view(batchsize, -1)中batchsize指轉換後有幾行,而-1指在不告訴函式有多少列的情況下,根據原tensor資料和batchsize自動分配列數。