1. 程式人生 > 其它 >基於PyTorch的CIFAR-10分類

基於PyTorch的CIFAR-10分類

作者:如縷清風

本文為博主原創,未經允許,請勿轉載:https://www.cnblogs.com/warren2123/articles/11823690.html


一、前言

本文基於Facebook的PyTorch框架,通過對VGGNet模型實現,對CIFAR-10資料集進行分類。

CIFAR-10資料集包含60000張 32x32的彩色圖片,共分為10種類別,每種類別6000張。其中訓練集包含50000張圖片,測試機包含10000張圖片。CIFAR-10的樣本圖如下所示。

二、基於PyTorch構建VGGNet模型

PyTorch與TensorFlow最大的不同是運用動態圖計算,並採用自動autograph的方法,大大方便了模型的構建。本文模型構建分為四個部分:資料讀取及預處理、構建VGGNet模型、定義模型超引數以及評估方法、引數優化。

1、資料讀取及預處理

本文采用GPU對PyTorch進行速度提升,如果不存在GPU,則會自動選擇CPU進行運算。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

資料集的讀取利用PyTorch的torchvision庫,可以選擇將download引數改為True進行下載,由於本文已經下載好,所以定位為False。資料集提前採用正則化的方式進行預處理,分為訓練集和測試集,並採用生成器的方式載入資料,便於更好的處理大批量資料。classes為CIFAR-10資料集的10個標籤類別。

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10('./data', train=True, download=False, transform=transform)
testset = torchvision.datasets.CIFAR10('./data', train=False, download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=2)

classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

2、構建VGGNet模型

VGGNet是通過對AlexNet的改進,進一步加深了卷積神經網路的深度,採用堆疊3 x 3的卷積層和2 x 2的降取樣層,實現11到19層的網路深度。VGG的結構圖如下所示。

VGGNet模型總的來說,分為VGG16和VGG19兩類,區別在於模型的層數不同,以下'M'引數代表池化層,資料代表各層濾波器的數量。

cfg = {
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}

本文中定義VGGNet模型的為全連線層,卷積層中都運用批量歸一化的方法,提升模型的訓練速度與收斂效率,並且可以一定的代替dropout的作用,有利於模型的泛化效果。

class VGG(nn.Module):
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, 10)
    
    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out
    
    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

3、定義模型超引數以及評估方法

模型的學習率、訓練次數、批次大小通過超引數的方式設定,優化函式採用Adam,損失函式採用交叉熵進行計算。

LR = 0.001
EPOCHES = 20
BATCHSIZE = 100

net4 = VGG('VGG16')
mlps = [net4.to(device)]
optimizer = torch.optim.Adam([{"params": mlp.parameters()} for mlp in mlps], lr=LR)
loss_function = nn.CrossEntropyLoss()

4、引數優化

以下是通過定義的訓練次數進行模型的引數優化過程,每一次訓練輸出模型的測試正確率。

for ep in range(EPOCHES):
    for img, label in trainloader:
        img, label = img.to(device), label.to(device)
        optimizer.zero_grad()
        for mlp in mlps:
            mlp.train()
            out = mlp(img)
            loss = loss_function(out, label)
            loss.backward()
        optimizer.step()
    
    pre = []
    vote_correct = 0
    mlps_correct = [0 for i in range(len(mlps))]
    for img, label in testloader:
        img, label = img.to(device), label.to(device)
        for i, mlp in enumerate(mlps):
            mlp.eval()
            out = mlp(img)
            _, prediction = torch.max(out, 1)
            pre_num = prediction.cpu().numpy()
            mlps_correct[i] += (pre_num == label.cpu().numpy()).sum()
            pre.append(pre_num)
        arr = np.array(pre)
        pre.clear()
        result = [Counter(arr[:, i]).most_common(1)[0][0] for i in range(BATCHSIZE)]
        vote_correct += (result == label.cpu().numpy()).sum()
    
    for idx, correct in enumerate(mlps_correct):
        print("Epoch:" + str(ep) + "VGG的正確率為:" + str(correct/len(testloader)))

訓練輸出如下所示:

Epoch:0 VGG的正確率為:57.67
Epoch:1 VGG的正確率為:67.13
Epoch:2 VGG的正確率為:74.84
Epoch:3 VGG的正確率為:79.59
Epoch:4 VGG的正確率為:79.93
Epoch:5 VGG的正確率為:82.61
Epoch:6 VGG的正確率為:82.96
Epoch:7 VGG的正確率為:84.31
Epoch:8 VGG的正確率為:82.43
Epoch:9 VGG的正確率為:85.12
Epoch:10 VGG的正確率為:84.33
Epoch:11 VGG的正確率為:83.66
Epoch:12 VGG的正確率為:82.02
Epoch:13 VGG的正確率為:85.44
Epoch:14 VGG的正確率為:84.08
Epoch:15 VGG的正確率為:85.67
Epoch:16 VGG的正確率為:84.87
Epoch:17 VGG的正確率為:85.21
Epoch:18 VGG的正確率為:84.62
Epoch:19 VGG的正確率為:85.88
Epoch:20 VGG的正確率為:83.46
Epoch:21 VGG的正確率為:86.63
Epoch:22 VGG的正確率為:85.75
Epoch:23 VGG的正確率為:86.29
Epoch:24 VGG的正確率為:83.33
Epoch:25 VGG的正確率為:86.48
Epoch:26 VGG的正確率為:85.6
Epoch:27 VGG的正確率為:86.66
Epoch:28 VGG的正確率為:85.45
Epoch:29 VGG的正確率為:85.65
Epoch:30 VGG的正確率為:86.36
Epoch:31 VGG的正確率為:86.27
Epoch:32 VGG的正確率為:85.09
Epoch:33 VGG的正確率為:85.6
Epoch:34 VGG的正確率為:86.82
Epoch:35 VGG的正確率為:85.76
Epoch:36 VGG的正確率為:86.59
Epoch:37 VGG的正確率為:85.56
Epoch:38 VGG的正確率為:85.71
Epoch:39 VGG的正確率為:86.07
Epoch:40 VGG的正確率為:84.87
Epoch:41 VGG的正確率為:85.91
Epoch:42 VGG的正確率為:86.8
Epoch:43 VGG的正確率為:87.43
Epoch:44 VGG的正確率為:85.99
Epoch:45 VGG的正確率為:86.32
Epoch:46 VGG的正確率為:86.72
Epoch:47 VGG的正確率為:86.39
Epoch:48 VGG的正確率為:86.08
Epoch:49 VGG的正確率為:86.97

三、總結

本文基於PyTorch構建的VGG模型,在CIFAR-10中分類效果達到86.97%,最高達到87.43%的分類準確率,當然後續可以進一步調整超引數優化模型,也可以運用多模型架構。通過細分各類別的準確率,可以看出模型在dog類別準確率較低,在truck類別準確率較高。

Accuracy of   airplane : 90 %
Accuracy of automobile : 90 %
Accuracy of       bird : 90 %
Accuracy of        cat : 82 %
Accuracy of       deer : 88 %
Accuracy of        dog : 71 %
Accuracy of       frog : 93 %
Accuracy of      horse : 85 %
Accuracy of       ship : 81 %
Accuracy of      truck : 95 %

基於PyTorch的構建,能夠從中體會到Python之禪的哲學,簡潔、方便等。相信這將是深度學習的一大助力,當然這也是因人而異。