1. 程式人生 > 其它 >Pytorch CNN網路MNIST數字識別 [超詳細記錄] 學習筆記(三)

Pytorch CNN網路MNIST數字識別 [超詳細記錄] 學習筆記(三)

目錄

1. 準備資料集

1.1 MNIST資料集獲取:

  • torchvision.datasets介面直接下載,該介面可以直接構建資料集,推薦

  • 其他途徑下載後,編寫程式進行讀取,然後由Datasets構建自己的資料集

​ ​ 本文使用第一種方法獲取資料集,並使用Dataloader進行按批裝載。如果使用程式下載失敗,請將其他途徑下載的MNIST資料集 [檔案][解壓檔案] 放置在 <data/MNIST/raw/> 位置下,本文的程式及檔案結構圖如下:

​ ​ 其中,model資料夾用來儲存每個epoch訓練的模型引數,根資料夾下包含model.py用於訓練模型,test.py為測試集測試,show.py為展示部分

1.2 程式部分

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time

# 1. 準備資料集
## 1.1 使用torchvision自動下載MNIST資料集
train_data = datasets.MNIST(root='data\\',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)

## 1.2 構建資料集裝載器
train_loader = DataLoader(dataset=train_data,
                          batch_size=100,
                          shuffle=True,
                          drop_last=False,
                          num_workers=4)

if __name__ == "__main__":
    print("===============資料統計===============")
    print("訓練集樣本:",train_data.__len__(), train_data.data.shape)

​ ​ 【程式碼解析】

  • root為存放MNIST的路徑,trian=True代表下載的為訓練集和訓練集標籤,False則代表測試集和標籤

  • transforms.ToTensor()表示將shape為(H, W, C)的 numpy 陣列或 img 轉為shape為(C, H, W)的tensor,並將數值歸一化為[0,1]

  • download為True則代表自動下載,若該資料夾下已經下載,則直接跳過下載步驟

  • shuffle=True,表示對分好的batch進行洗牌操作,drop_last=True表示對最後不足batch大小的剩餘樣本捨去,False表示保留

  • num_works表示每次讀取的程序數,和核心數有關

​ ​ Dataset和Dataloader詳細說明,請移步:[Pytorch Dataset和Dataloader 學習筆記(二)]

2. 設計網路結構

2.1 網路設計

​ ​ 網路結構如上圖所示,輸入影象—>卷積1—>池化1—>卷積2—>池化2—>全連線1—>全連線2—>softmax,每次卷積通道數都增加一倍,最後送入全連線層實現分類

2.2 程式部分

# 2. Design model using class
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv_layer1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.max_pooling1 = nn.MaxPool2d(2)
        self.conv_layer2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.max_pooling2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(1568, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.max_pooling1(F.relu(self.conv_layer1(x)))
        x = self.max_pooling2(F.relu(self.conv_layer2(x)))
        x = x.view(-1, 32*7*7)
        x = F.relu(self.fc1(x))
        y_hat = self.fc2(x)     # CrossEntropyLoss會自動啟用最後一層的輸出以及softmax處理
        return y_hat

net = Net()

# 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

​ ​ 【程式碼解析】

  • fc1的1568維度是因為最後一次池化後的shape為32*7*7=1568

  • 在最後一層,並沒有進行relu啟用以及接入softmax,是因為,在CrossEntropyLoss中會自動啟用最後一層的輸出以及softmax處理

​ ​ CrossEntropyLoss圖參考:《PyTorch深度學習實踐》完結合集
​ ​ 詳細網路結構搭建說明,請移步:Pytorch線性規劃模型 學習筆記(一)

3. 迭代訓練

# 3. Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

# 4. Training
if __name__ == "__main__":
    print("Training...")
    for epoch in range(20):
        strat = time.time()
        total_correct = 0
        for x, y in train_loader:
            y_hat = net(x)
            y_pre = torch.argmax(y_hat, dim=1)
            total_correct += sum(torch.eq(y_pre, y))    # 統計當前epoch下的正確個數

            loss = criterion(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        acc = (float(total_correct) / train_data.__len__())*100
        save_path = "model/net" + str(epoch+1) + ".pth"
        torch.save(obj=net.state_dict(), f=save_path)
        print("epoch:", str(epoch + 1) + "/20",
              " \n time:", "%.1f" % (time.time() - strat) + "s"    
              " train_loss:", loss.item(),
              " acc:%.3f%%" % acc,)

    print("we are done!")

​ ​ 【程式碼解析】

  • total_correct變數用於統計每個epoch下正確預測值的個數,每進行epoch進行一次清零
  • torch.argmax(y_hat, dim=1)用於選取y_hat下每一行的最大值(每個樣本的最高得分),並返回與y相同維度的tensor
  • torch.eq(y_pre, y)用於比較兩個矩陣元素是否相同,相同則返回True,不同則返回False,用於判斷預測值與真實值是否相同
  • torch.save儲存了每個epoch的網路權重引數

4. 測試集預測部分

# 測試模型,測試集為test_data

import torch
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from model import Net

test_data = datasets.MNIST(root='data\\',
                           train=False,
                           transform=transforms.ToTensor(),
                           download=True)
test_loader = DataLoader(dataset=test_data,
                          batch_size=100,
                          shuffle=True,
                          drop_last=False,
                          num_workers=4)

if __name__ == "__main__":
    print("---------------預測分析---------------")
    print("測試集樣本:", test_data.__len__(), test_data.data.shape)
    model = Net()
    model.load_state_dict(torch.load("model/net20.pth"))
    model.eval()

    total_correct = 0
    for x, y in test_loader:
        y_hat = model(x)
        y_pre = torch.argmax(y_hat, dim=1)
        total_correct += sum(torch.eq(y_pre, y))

    acc = (float(total_correct) / test_data.__len__())*100
    print("total_test_samples:", test_data.__len__(),
          " test_acc:", "%.3f%%" % acc)

​ ​ 經過20個epoch的訓練,在測試集上達到了98.590%的準確率,部分batch真實值與預測值展示如下:

5. 全部程式碼

連結:連結:https://pan.baidu.com/s/1GGhG1Slw2Tlsgl13yzHUIw
提取碼:82l4

轉載請說明出處