Pytorch錯誤Expected input batch_size (324) to match target batch_size (4) Log In
阿新 • • 發佈:2021-11-18
參考連結:
Pytorch Error: ValueError: Expected input batch_size (324) to match target batch_size (4) Log In
1.ERROR原因
使用pytorch訓練一個自定義的模型,參照網上的部落格直接照搬網路,但是在修改自定義資料集時,出現這個錯誤。很明顯是一個影象引數不匹配問題,自定義資料集的圖片大小規格不統一且與網路接受的大小不匹配。
ValueError: Expected input batch_size (324) to match target batch_size (4) Log In
2.解決思路
首先,在錯誤的網路結構處前後加入print來檢視網路結構。
# 構建CNN模型 class CNNNet(nn.Module): def __init__(self): super(CNNNet, self).__init__() self.conv1 = nn.Conv2d(3, 64, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(64, 128, 5) self.fc1 = nn.Linear(128*53*53, 1024) self.fc2 = nn.Linear(1024, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) # print(x.shape) x = x.view(-1, 128*53*53) # print(x.shape) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x
即我註釋的這個地方,可以得到輸入前的資料格式。
torch.Size([4, 128, 53, 53])
根據輸出的形狀來更改view裡的引數。
x = x.view(-1, 128 * 53 * 53)
後面的Linear層也需要對應修改,使其與資料輸入匹配:
self.fc1 = nn.Linear(128 * 53 * 53, 1024)