
Training and Testing on the MNIST Dataset with PyTorch


The complete code for training and testing:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
import argparse
import os


# Training
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    num_correct = 0
    for batch_index, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        # forward
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        # backward
        optimizer.zero_grad()  # clear the accumulated gradients
        loss.backward()        # backpropagate
        optimizer.step()       # update the parameters
        _, predicted = torch.max(outputs, dim=1)
        # number of correct predictions in this batch
        batch_correct = (predicted == labels).sum().item()
        # accuracy of this batch (use the actual batch size, the last batch may be smaller)
        batch_accuracy = batch_correct / labels.size(0)
        # running total of correct predictions in this epoch
        num_correct += batch_correct
        print(f'Epoch:{epoch}, Batch ID:{batch_index}/{len(train_loader)}, '
              f'loss:{loss.item()}, Batch accuracy:{batch_accuracy * 100}%')
    # accuracy over the whole epoch
    epoch_accuracy = num_correct / len(train_loader.dataset)
    print(f'Epoch Accuracy:{epoch_accuracy}')
    # save a checkpoint
    if epoch % args.checkpoint_interval == 0:
        torch.save(model.state_dict(), f"checkpoints/VGG16_MNIST_{epoch}.pth")


# Evaluation
def test(args, model, device, test_loader):
    model.eval()
    total_loss = 0
    num_correct = 0
    if args.pretrained_weights.endswith(".pth"):
        model.load_state_dict(torch.load(args.pretrained_weights))
    # no gradients are needed during evaluation, which saves memory and compute
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            output = model(images)
            # accumulate the summed loss; item() extracts the Python number from the tensor
            total_loss += F.cross_entropy(output, labels, reduction='sum').item()
            # torch.max() returns two values: the maximum value and the index it occurs at
            _, predicted = torch.max(output, dim=1)
            # total number of correct predictions
            num_correct += (predicted == labels).sum().item()
    # average loss per sample
    test_loss = total_loss / len(test_loader.dataset)
    # average accuracy
    accuracy = num_correct / len(test_loader.dataset)
    print(f'Average loss:{test_loss}\nTest Accuracy:{accuracy * 100}%')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Pytorch-MNIST_classification')
    parser.add_argument('--epochs', type=int, default=20, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='size of each image batch')
    parser.add_argument('--num_classes', type=int, default=10, help='number of classes')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
    parser.add_argument('--pretrained_weights', type=str, default='checkpoints/', help='pretrained weights')
    parser.add_argument('--img_size', type=int, default=224, help='size of each image dimension')
    parser.add_argument('--checkpoint_interval', type=int, default=1, help='interval between saving model weights')
    parser.add_argument('--train', action='store_true', help='train if set, otherwise test')
    args = parser.parse_args()
    print(args)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # os.makedirs() creates the directories recursively if they do not exist yet
    os.makedirs("output", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    # transforms: crop/resize to the input size expected by VGG16, then convert to a tensor
    data_transform = transforms.Compose([transforms.RandomResizedCrop(args.img_size),
                                         transforms.ToTensor()])

    # download the training set
    train_data = datasets.MNIST(root='data', train=True, transform=data_transform,
                                target_transform=None, download=True)
    # download the test set
    test_data = datasets.MNIST(root='data', train=False, transform=data_transform,
                               target_transform=None, download=True)
    # load the training set
    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True)
    # load the test set
    test_loader = DataLoader(dataset=test_data, batch_size=args.batch_size)

    # build the model
    model = models.vgg16(pretrained=True)
    # change the output dimension of vgg16 to the number of classes
    model.classifier[6] = nn.Linear(in_features=4096, out_features=args.num_classes, bias=True)
    # MNIST images are grayscale, so the first conv layer takes 1 input channel
    model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    print(model)
    model = model.to(device)

    # optimizer (other optimizers can be used as well)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    if args.train:
        # optionally resume from previously trained weights (load once, before the epoch loop)
        if args.pretrained_weights.endswith(".pth"):
            model.load_state_dict(torch.load(args.pretrained_weights))
        for epoch in range(1, args.epochs + 1):
            train(args, model, device, train_loader, optimizer, epoch)
    else:
        test(args, model, device, test_loader)
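Assuming the script is saved under a name such as mnist_vgg16.py (the post does not give a filename), training would be started with something like `python mnist_vgg16.py --train --epochs 20`, and a saved checkpoint could later be evaluated with `python mnist_vgg16.py --pretrained_weights checkpoints/VGG16_MNIST_20.pth`, since checkpoints are written under checkpoints/ every `--checkpoint_interval` epochs.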

Test results:
(Screenshot of the test output omitted.) I only trained for fewer than 10 epochs, so the results are not great yet; there is still room for improvement.

Notes:
The MNIST dataset can be downloaded via datasets.MNIST in torchvision, or downloaded manually (make sure the storage path matches the root argument). The model used here is the pretrained vgg16 network from torchvision.models, but you can also build your own network, as sketched below.
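For reference, a minimal hand-built network could replace the VGG16 above; the following is only an illustrative sketch (the class name SimpleCNN and its layer sizes are my own choices, not from the original post). Because it accepts 1-channel 28×28 inputs directly, it needs neither the channel nor the classifier modification, and the transform would keep MNIST's native 28×28 size instead of resizing to 224.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleCNN(nn.Module):
    """A small CNN for 1x28x28 MNIST images (illustrative only)."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)   # 1x28x28 -> 32x28x28
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 32x14x14 -> 64x14x14
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # -> 32x14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # -> 64x7x7
        x = torch.flatten(x, 1)                     # -> 64*7*7 features per image
        x = F.relu(self.fc1(x))
        return self.fc2(x)                          # raw logits, suitable for F.cross_entropy


# Usage in place of the VGG16 line, e.g.:
# model = SimpleCNN(num_classes=args.num_classes).to(device)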