PyTorch: VGG16 Network Model
Just a grade-schooler in the deep learning world.
Starting from scratch, so let's get to know VGG16 first, haha.
A diagram of the VGG16 architecture is easy to find with a quick Baidu search.
Here I mainly want to record the code walkthrough.
import torch
import torch.nn as nn
from torchsummary import summary


class VGG16(nn.Module):
    def __init__(self):
        super(VGG16, self).__init__()
        # Block 1: two 3x3 convolutions (3 -> 64), then 2x2 max pooling
        self.maxpool1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Block 2: two 3x3 convolutions (64 -> 128), then 2x2 max pooling
        self.maxpool2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Block 3: three 3x3 convolutions (128 -> 256), then 2x2 max pooling
        self.maxpool3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Block 4: three 3x3 convolutions (256 -> 512), then 2x2 max pooling
        self.maxpool4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Block 5: three 3x3 convolutions (512 -> 512), then 2x2 max pooling
        self.maxpool5 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Classifier: 512 * 5 * 5 comes from a 160x160 input halved by five pooling stages
        self.dense = nn.Sequential(
            nn.Linear(512 * 5 * 5, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, 1000)
        )

    def forward(self, x):
        pool1 = self.maxpool1(x)
        pool2 = self.maxpool2(pool1)
        pool3 = self.maxpool3(pool2)
        pool4 = self.maxpool4(pool3)
        pool5 = self.maxpool5(pool4)
        flat = pool5.view(pool5.size(0), -1)  # flatten to (batch, 512 * 5 * 5)
        class_ = self.dense(flat)
        return class_


if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vgg_model = VGG16().to(device)
    # Feed the model an input and check the output size
    # (the input must live on the same device as the model)
    x = torch.rand(size=(8, 3, 160, 160)).to(device)
    print(vgg_model(x).size())
    summary(vgg_model, (3, 160, 160))  # print the network structure
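As a sanity check, the hand-written layer layout can be compared against the reference VGG16 that ships with torchvision. The snippet below is a minimal sketch, assuming torchvision is installed; tv_model is just an illustrative name. Note that the built-in model adds an AdaptiveAvgPool2d layer and Dropout in its classifier, so its parameter count differs from the model above.

import torch
from torchvision.models import vgg16

# Randomly initialised reference model from torchvision (no pretrained weights).
tv_model = vgg16()
x = torch.rand(8, 3, 160, 160)
print(tv_model(x).size())  # torch.Size([8, 1000])
# Total parameter count, for comparison with the summary below.
print(sum(p.numel() for p in tv_model.parameters()))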
Output
D:\Anaconda3\envs\pytorch\python.exe E:/LearningCode/Torch/Torch-Learning/vgg16-learning-01.py
torch.Size([8, 1000])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 160, 160]           1,792
              ReLU-2         [-1, 64, 160, 160]               0
            Conv2d-3         [-1, 64, 160, 160]          36,928
              ReLU-4         [-1, 64, 160, 160]               0
         MaxPool2d-5           [-1, 64, 80, 80]               0
            Conv2d-6          [-1, 128, 80, 80]          73,856
              ReLU-7          [-1, 128, 80, 80]               0
            Conv2d-8          [-1, 128, 80, 80]         147,584
              ReLU-9          [-1, 128, 80, 80]               0
        MaxPool2d-10          [-1, 128, 40, 40]               0
           Conv2d-11          [-1, 256, 40, 40]         295,168
             ReLU-12          [-1, 256, 40, 40]               0
           Conv2d-13          [-1, 256, 40, 40]         590,080
             ReLU-14          [-1, 256, 40, 40]               0
           Conv2d-15          [-1, 256, 40, 40]         590,080
             ReLU-16          [-1, 256, 40, 40]               0
        MaxPool2d-17          [-1, 256, 20, 20]               0
           Conv2d-18          [-1, 512, 20, 20]       1,180,160
             ReLU-19          [-1, 512, 20, 20]               0
           Conv2d-20          [-1, 512, 20, 20]       2,359,808
             ReLU-21          [-1, 512, 20, 20]               0
           Conv2d-22          [-1, 512, 20, 20]       2,359,808
             ReLU-23          [-1, 512, 20, 20]               0
        MaxPool2d-24          [-1, 512, 10, 10]               0
           Conv2d-25          [-1, 512, 10, 10]       2,359,808
             ReLU-26          [-1, 512, 10, 10]               0
           Conv2d-27          [-1, 512, 10, 10]       2,359,808
             ReLU-28          [-1, 512, 10, 10]               0
           Conv2d-29          [-1, 512, 10, 10]       2,359,808
             ReLU-30          [-1, 512, 10, 10]               0
        MaxPool2d-31            [-1, 512, 5, 5]               0
           Linear-32                 [-1, 4096]      52,432,896
             ReLU-33                 [-1, 4096]               0
           Linear-34                 [-1, 4096]      16,781,312
             ReLU-35                 [-1, 4096]               0
           Linear-36                 [-1, 1000]       4,097,000
================================================================
Total params: 88,025,896
Trainable params: 88,025,896
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.29
Forward/backward pass size (MB): 111.56
Params size (MB): 335.79
Estimated Total Size (MB): 447.64
----------------------------------------------------------------

Process finished with exit code 0
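One detail worth pulling out of the summary: the first Linear layer expects 512 * 5 * 5 features because a 160x160 input is halved by five max-pooling stages (160 / 2^5 = 5). Feed the network a different resolution and the flatten size no longer matches, so the forward pass fails. A common workaround (the one torchvision uses, with a 7x7 target) is to pool the last feature map to a fixed spatial size before flattening. The sketch below is one way to bolt that on, assuming the VGG16 class above is in scope; the names pool, vgg, feat and flat are only illustrative.

import torch
import torch.nn as nn

pool = nn.AdaptiveAvgPool2d((5, 5))  # always outputs a 5x5 spatial map

vgg = VGG16()
x = torch.rand(2, 3, 224, 224)  # a size the original forward() would reject
# Run the five convolutional blocks, then pool to 5x5 before the dense layers.
feat = vgg.maxpool5(vgg.maxpool4(vgg.maxpool3(vgg.maxpool2(vgg.maxpool1(x)))))
flat = pool(feat).view(feat.size(0), -1)  # shape (2, 512 * 5 * 5)
print(vgg.dense(flat).size())  # torch.Size([2, 1000])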