【pytorch學習筆記3】pytorch實現手寫數字識別
阿新 • • 發佈:2021-01-19
技術標籤:pytorch學習筆記深度學習神經網路
前言
我們用手寫數字識別這個入門案例來熟悉一下 PyTorch。
正文
講解連結如下:
https://www.bilibili.com/video/BV1fA411e7ad?p=13
與視訊講解不同的是,本文程式碼將網路結構換成了 LeNet。
實現程式碼:
import numpy as np
import os
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Hyperparameters
Epochs = 3
learning_rate = 0.01
batch_size_train = 64
batch_size_test = 1000


# 1. Dataset preparation
def _make_mnist_loader(train, batch_size):
    """Return a shuffled DataLoader over the MNIST train or test split.

    The data is downloaded into ``./data/`` when it is not already there.
    """
    # 0.1307 and 0.3081 are the global mean and standard deviation of MNIST.
    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset = MNIST(root='./data/', train=train, download=True,
                    transform=preprocess)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


train_loader = _make_mnist_loader(True, batch_size_train)
test_loader = _make_mnist_loader(False, batch_size_test)

# Pull one batch from test_loader to inspect the tensor shape.
example = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(example)
print(example_data.shape)
# 2. Network definition: a LeNet-style CNN for handwritten digits
class MnistNet(nn.Module):
    """LeNet-style convolutional classifier producing 10 log-probabilities.

    The shape comments below assume a ``(batch, 1, 28, 28)`` input, which is
    the size that makes the flattened feature vector match ``16 * 5 * 5``.
    """

    def __init__(self):
        super(MnistNet, self).__init__()
        # conv(k=3, stride=1, pad=2): 1x28x28 -> 6x30x30; pool(2): -> 6x15x15
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 6, 3, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # conv(k=5): 6x15x15 -> 16x11x11; pool(2): -> 16x5x5
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.BatchNorm1d(120),
            nn.ReLU(),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, input):
        """Return per-class log-probabilities of shape ``(batch, 10)``."""
        x = self.conv1(input)
        x = self.conv2(x)
        # Flatten everything except the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        out = self.fc2(x)
        # Name the class dimension explicitly: the implicit-dim form is
        # deprecated and emits a warning on every call.
        return F.log_softmax(out, dim=1)
# Create the network and its optimizer
model = MnistNet()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Resume from a previous checkpoint when one exists
if os.path.exists("./model/model.pkl"):
    model.load_state_dict(torch.load("./model/model.pkl"))
    # The optimizer file is checked separately so that a missing
    # optimizer.pkl does not crash start-up; the optimizer then simply
    # starts from a fresh state.
    if os.path.exists("./model/optimizer.pkl"):
        optimizer.load_state_dict(torch.load("./model/optimizer.pkl"))
# 3. Training
def train(epoch):
    """Run one pass over ``train_loader``, updating ``model`` in place.

    The loss is printed every 50 batches, and the model and optimizer state
    dicts are written to ``./model/`` every 100 batches.

    Args:
        epoch: Epoch index; used only to label the progress output.
    """
    # torch.save does not create missing parent directories, so make sure the
    # checkpoint folder exists before the first save (which happens at idx 0).
    os.makedirs('./model', exist_ok=True)
    # Make sure BatchNorm layers are in training mode.
    model.train()
    for idx, (images, target) in enumerate(train_loader):
        # Gradients accumulate by default in PyTorch, so clear them each step.
        optimizer.zero_grad()
        output = model(images)
        # The model returns log-probabilities, so NLL gives cross-entropy.
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if idx % 50 == 0:
            print('Train Epoch: {} idx:{}\tLoss: {:.6f}'.format(
                epoch, idx, loss.item()))
        # Periodic checkpoint
        if idx % 100 == 0:
            torch.save(model.state_dict(), './model/model.pkl')
            torch.save(optimizer.state_dict(), './model/optimizer.pkl')
# 4. Evaluation
def test():
    """Evaluate ``model`` on ``test_loader`` and print mean accuracy and loss.

    Both metrics are averaged per sample over the whole loader. When all
    batches have the same size this equals the mean of per-batch means.
    """
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    # Switch BatchNorm to its running statistics for evaluation, and restore
    # the previous mode afterwards so later training is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():  # no gradients are needed for evaluation
            for images, target in test_loader:
                output = model(images)
                total_loss += F.nll_loss(
                    output, target, reduction='sum').item()
                predict = output.argmax(dim=-1)
                total_correct += predict.eq(target).sum().item()
                total_seen += target.size(0)
    finally:
        model.train(was_training)
    # An empty loader yields NaN rather than a division error.
    mean_acc = total_correct / total_seen if total_seen else float('nan')
    mean_loss = total_loss / total_seen if total_seen else float('nan')
    print("平均準確率:%.4f" % mean_acc, "平均損失率:%.4f" % mean_loss)
if __name__ == '__main__':
    # Use the Epochs hyperparameter instead of a second hard-coded count.
    for i in range(Epochs):
        train(i)
    # NOTE(review): the original indentation was lost; evaluation is assumed
    # to run once after all epochs rather than after each one. Confirm.
    test()
用訓練的模型進行評估(測試)