1. 程式人生 > 其它 >【pytorch學習筆記3】pytorch實現手寫數字識別

【pytorch學習筆記3】pytorch實現手寫數字識別

技術標籤:pytorch學習筆記深度學習神經網路

前言

我們來用手寫數字這個入門案例,拿它來熟悉一下pytorch

正文

講解連結如下:
https://www.bilibili.com/video/BV1fA411e7ad?p=13

本文程式碼的網路結構與視訊講解結構不同的是,本文將網路換成了LeNet網路

實現程式碼:

import numpy as np
import os
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import
torch.nn as nn import torch.nn.functional as F import torch.optim as optim # 定義超引數 Epochs = 3 learning_rate = 0.01 batch_size_train = 64 batch_size_test = 1000 # 1、準備資料集(其中0.1307和0.3081是MNIST資料集的全域性平均值和標準偏差) train_loader = DataLoader(MNIST(root='./data/', train=True, download=True, transform=
torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=batch_size_train, shuffle=True) test_loader = DataLoader(MNIST(
root='./data/', train=False, download=True, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=batch_size_test, shuffle=True) # 檢視一下test_loader的size與target example = enumerate(test_loader) batch_idx, (example_data, example_targets) = next(example) #讀取 print(example_data.shape) # print(example_targets) # 2、構建網路,此處用手寫數字的LeNet網路結構 class MnistNet(nn.Module): def __init__(self): super(MnistNet, self).__init__() self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(), nn.MaxPool2d(2, 2)) self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2)) self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120), nn.BatchNorm1d(120), nn.ReLU()) self.fc2 = nn.Sequential( nn.Linear(120, 84), nn.BatchNorm1d(84), nn.ReLU(), nn.Linear(84, 10)) def forward(self, input): x = self.conv1(input) x = self.conv2(x) x = x.view(x.size()[0], -1) #對引數實現扁平化(便於後面全連線層輸入) x = self.fc1(x) out = self.fc2(x) return F.log_softmax(out) # 初始化網路與優化器 model = MnistNet() optimizer = optim.Adam(model.parameters(), lr=learning_rate) # 已存在模型則載入模型 if os.path.exists("./model/model.pkl"): model.load_state_dict(torch.load("./model/model.pkl")) optimizer.load_state_dict(torch.load("./model/optimizer.pkl")) # 3、訓練 def train(epoch): for idx, (input, target) in enumerate(train_loader): optimizer.zero_grad() # 手動將梯度設定為零,因為PyTorch在預設情況下會累積梯度 output = model(input) loss = F.nll_loss(output, target) # 得到交叉熵損失 loss.backward() optimizer.step() #梯度更新 if idx % 50 == 0: print('Train Epoch: {} idx:{}\tLoss: {:.6f}'.format( epoch, idx, loss.item())) # 模型的儲存 if idx % 100 == 0: torch.save(model.state_dict(), './model/model.pkl') torch.save(optimizer.state_dict(), './model/optimizer.pkl') # 4、模型評估(測試) def test(): loss_list = [] acc_list = [] for idx, (input, target) in enumerate(test_loader): with torch.no_grad(): #測試無需進行梯度下降操作 output = model(input) cur_loss = F.nll_loss(output, target) loss_list.append(cur_loss) predict = output.max(dim=-1)[-1] cur_acc = predict.eq(target).float().mean() acc_list.append(cur_acc) print("平均準確率:%.4f" % np.mean(acc_list), "平均損失率:%.4f" % np.mean(loss_list)) if __name__ == '__main__': for i in range(3): train(i) test() # 評估

用訓練的模型進行評估(測試)
在這裡插入圖片描述