基於PyTorch的CIFAR-10分類
作者:如縷清風
本文為博主原創,未經允許,請勿轉載:https://www.cnblogs.com/warren2123/articles/11823690.html
一、前言
本文基於Facebook的PyTorch框架,通過對VGGNet模型實現,對CIFAR-10資料集進行分類。
CIFAR-10資料集包含60000張 32x32的彩色圖片,共分為10種類別,每種類別6000張。其中訓練集包含50000張圖片,測試機包含10000張圖片。CIFAR-10的樣本圖如下所示。
二、基於PyTorch構建VGGNet模型
PyTorch與TensorFlow最大的不同是運用動態圖計算,並採用自動autograph的方法,大大方便了模型的構建。本文模型構建分為四個部分:資料讀取及預處理、構建VGGNet模型、定義模型超引數以及評估方法、引數優化。
1、資料讀取及預處理
本文采用GPU對PyTorch進行速度提升,如果不存在GPU,則會自動選擇CPU進行運算。
import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torchvision import torchvision.transforms as transforms device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
資料集的讀取利用PyTorch的torchvision庫,可以選擇將download引數改為True進行下載,由於本文已經下載好,所以定位為False。資料集提前採用正則化的方式進行預處理,分為訓練集和測試集,並採用生成器的方式載入資料,便於更好的處理大批量資料。classes為CIFAR-10資料集的10個標籤類別。
transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) trainset = torchvision.datasets.CIFAR10('./data', train=True, download=False, transform=transform) testset = torchvision.datasets.CIFAR10('./data', train=False, download=False, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=2) classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
2、構建VGGNet模型
VGGNet是通過對AlexNet的改進,進一步加深了卷積神經網路的深度,採用堆疊3 x 3的卷積層和2 x 2的降取樣層,實現11到19層的網路深度。VGG的結構圖如下所示。
VGGNet模型總的來說,分為VGG16和VGG19兩類,區別在於模型的層數不同,以下'M'引數代表池化層,資料代表各層濾波器的數量。
cfg = { 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'] }
本文中定義VGGNet模型的為全連線層,卷積層中都運用批量歸一化的方法,提升模型的訓練速度與收斂效率,並且可以一定的代替dropout的作用,有利於模型的泛化效果。
class VGG(nn.Module): def __init__(self, vgg_name): super(VGG, self).__init__() self.features = self._make_layers(cfg[vgg_name]) self.classifier = nn.Linear(512, 10) def forward(self, x): out = self.features(x) out = out.view(out.size(0), -1) out = self.classifier(out) return out def _make_layers(self, cfg): layers = [] in_channels = 3 for x in cfg: if x == 'M': layers += [nn.MaxPool2d(kernel_size=2, stride=2)] else: layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), nn.BatchNorm2d(x), nn.ReLU(inplace=True)] in_channels = x layers += [nn.AvgPool2d(kernel_size=1, stride=1)] return nn.Sequential(*layers)
3、定義模型超引數以及評估方法
模型的學習率、訓練次數、批次大小通過超引數的方式設定,優化函式採用Adam,損失函式採用交叉熵進行計算。
LR = 0.001 EPOCHES = 20 BATCHSIZE = 100 net4 = VGG('VGG16') mlps = [net4.to(device)] optimizer = torch.optim.Adam([{"params": mlp.parameters()} for mlp in mlps], lr=LR) loss_function = nn.CrossEntropyLoss()
4、引數優化
以下是通過定義的訓練次數進行模型的引數優化過程,每一次訓練輸出模型的測試正確率。
for ep in range(EPOCHES): for img, label in trainloader: img, label = img.to(device), label.to(device) optimizer.zero_grad() for mlp in mlps: mlp.train() out = mlp(img) loss = loss_function(out, label) loss.backward() optimizer.step() pre = [] vote_correct = 0 mlps_correct = [0 for i in range(len(mlps))] for img, label in testloader: img, label = img.to(device), label.to(device) for i, mlp in enumerate(mlps): mlp.eval() out = mlp(img) _, prediction = torch.max(out, 1) pre_num = prediction.cpu().numpy() mlps_correct[i] += (pre_num == label.cpu().numpy()).sum() pre.append(pre_num) arr = np.array(pre) pre.clear() result = [Counter(arr[:, i]).most_common(1)[0][0] for i in range(BATCHSIZE)] vote_correct += (result == label.cpu().numpy()).sum() for idx, correct in enumerate(mlps_correct): print("Epoch:" + str(ep) + "VGG的正確率為:" + str(correct/len(testloader)))
訓練輸出如下所示:
Epoch:0 VGG的正確率為:57.67 Epoch:1 VGG的正確率為:67.13 Epoch:2 VGG的正確率為:74.84 Epoch:3 VGG的正確率為:79.59 Epoch:4 VGG的正確率為:79.93 Epoch:5 VGG的正確率為:82.61 Epoch:6 VGG的正確率為:82.96 Epoch:7 VGG的正確率為:84.31 Epoch:8 VGG的正確率為:82.43 Epoch:9 VGG的正確率為:85.12 Epoch:10 VGG的正確率為:84.33 Epoch:11 VGG的正確率為:83.66 Epoch:12 VGG的正確率為:82.02 Epoch:13 VGG的正確率為:85.44 Epoch:14 VGG的正確率為:84.08 Epoch:15 VGG的正確率為:85.67 Epoch:16 VGG的正確率為:84.87 Epoch:17 VGG的正確率為:85.21 Epoch:18 VGG的正確率為:84.62 Epoch:19 VGG的正確率為:85.88 Epoch:20 VGG的正確率為:83.46 Epoch:21 VGG的正確率為:86.63 Epoch:22 VGG的正確率為:85.75 Epoch:23 VGG的正確率為:86.29 Epoch:24 VGG的正確率為:83.33 Epoch:25 VGG的正確率為:86.48 Epoch:26 VGG的正確率為:85.6 Epoch:27 VGG的正確率為:86.66 Epoch:28 VGG的正確率為:85.45 Epoch:29 VGG的正確率為:85.65 Epoch:30 VGG的正確率為:86.36 Epoch:31 VGG的正確率為:86.27 Epoch:32 VGG的正確率為:85.09 Epoch:33 VGG的正確率為:85.6 Epoch:34 VGG的正確率為:86.82 Epoch:35 VGG的正確率為:85.76 Epoch:36 VGG的正確率為:86.59 Epoch:37 VGG的正確率為:85.56 Epoch:38 VGG的正確率為:85.71 Epoch:39 VGG的正確率為:86.07 Epoch:40 VGG的正確率為:84.87 Epoch:41 VGG的正確率為:85.91 Epoch:42 VGG的正確率為:86.8 Epoch:43 VGG的正確率為:87.43 Epoch:44 VGG的正確率為:85.99 Epoch:45 VGG的正確率為:86.32 Epoch:46 VGG的正確率為:86.72 Epoch:47 VGG的正確率為:86.39 Epoch:48 VGG的正確率為:86.08 Epoch:49 VGG的正確率為:86.97
三、總結
本文基於PyTorch構建的VGG模型,在CIFAR-10中分類效果達到86.97%,最高達到87.43%的分類準確率,當然後續可以進一步調整超引數優化模型,也可以運用多模型架構。通過細分各類別的準確率,可以看出模型在dog類別準確率較低,在truck類別準確率較高。
Accuracy of airplane : 90 % Accuracy of automobile : 90 % Accuracy of bird : 90 % Accuracy of cat : 82 % Accuracy of deer : 88 % Accuracy of dog : 71 % Accuracy of frog : 93 % Accuracy of horse : 85 % Accuracy of ship : 81 % Accuracy of truck : 95 %
基於PyTorch的構建,能夠從中體會到Python之禪的哲學,簡潔、方便等。相信這將是深度學習的一大助力,當然這也是因人而異。