1. 程式人生 > >【Pytorch】CIFAR-10分類任務

【Pytorch】CIFAR-10分類任務

CIFAR-10資料集共有60000張32*32彩色圖片,分為10類,每類有6000張圖片。其中50000張用於訓練,構成5個訓練batch,每一批次10000張圖片,其餘10000張圖片用於測試。

CIFAR-10資料集下載地址:點選下載

資料讀取,這裡選擇下載python版本的資料集,解壓後得到如下檔案:

其中data_batch_1~data_batch_5為訓練集的5個批次,test_batch為測試集。

這些檔案是python的序列化模型,這裡使用python3,可以使用pickle模組讀取這些資料:

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

    每一個batch檔案包括一個字典,字典的元素是:
data:一個尺寸為10000*3072,資料格式為uint8的numpy array,每
一行資料儲存了一張32*32彩色圖片的資料,前1024位是影象的紅色
通道資料,接著是綠色通道和藍色通道。

label:一個包含10000個0-9數字的列表,對應data裡每張圖片的標籤。

       

    此外,資料集中還有一個batches.meta檔案,它儲存了一個python字典,
該字典對標籤的10個數字0-9所代表的意義做了解釋,比如0代表airplane,
1代表automobile。

這次使用Pytorch框架來進行實驗,總體流程是,建立網路(這次小demo用Lenet),自定義資料集讀取框架,雖然pytorch已經有關於cifar10的Dataset例項,但還是自己實現了一遍,接著用DataLoader分批讀取資料集,定義損失函式和優化器,進行批次訓練。

import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

#預設引數
CLASS_NUM = 10
BATCH_SIZE = 128
EPOCH = 30

#Lenet網路程式碼
class Lenet(nn.Module):
    def __init__(self):
        super(Lenet,self).__init__()
        #定義網路層
        #入通道數,出通道數,卷積尺寸
        self.conv1 = nn.Conv2d(3,6,5)
        self.conv2 = nn.Conv2d(6,16,5)
        self.fc1 = nn.Linear(16*5*5,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)
    
    #將二維資料展開成一維資料以輸入到全連線層
    def num_flat_features(self,x):
        #size為[batch_size,num_channels,height,width]
        #除去batch_size,num_channels*height*width就是展開後維度
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features = num_features*s
        return num_features
    
    def forward(self,x):
        #定義前向傳播
        #輸入 和 視窗尺寸
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

#從原始檔讀取資料
#返回 train_data[50000,3072]和labels[50000]
#    test_data[10000,3072]和labels[10000]
def get_data(train=False):
    data = None
    labels = None
    if train == True:
        for i in range(1,6):    
            batch = unpickle('data/cifar-10-batches-py/data_batch_'+str(i))
            if i == 1:
                data = batch[b'data']
            else:
                data = np.concatenate([data,batch[b'data']])

            if i == 1:
                labels = batch[b'labels']
            else:
                labels = np.concatenate([labels,batch[b'labels']])
    else:
        batch = unpickle('data/cifar-10-batches-py/test_batch')
        data = batch[b'data']
        labels = batch[b'labels']
    return data,labels

#影象預處理函式,Compose會將多個transform操作包在一起
#對於彩色影象,色彩通道不存在平穩特性
transform = transforms.Compose([
    # ToTensor是指把PIL.Image(RGB) 或者numpy.ndarray(H x W x C) 
    # 從0到255的值對映到0到1的範圍內,並轉化成Tensor格式。
    transforms.ToTensor(), 
    #Normalize函式將影象資料歸一化到[-1,1]
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
    ]
)

#將標籤轉換為torch.LongTensor
def target_transform(label):
    label = np.array(label)
    target = torch.from_numpy(label).long()
    return target

'''
自定義資料集讀取框架來載入cifar10資料集
需要繼承data.Dataset
'''
class Cifar10_Dataset(Data.Dataset):
    def __init__(self,train=True,transform=None,target_transform=None):
        #初始化檔案路徑
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        #載入訓練資料集
        if self.train:
            self.train_data,self.train_labels = get_data(train)
            self.train_data = self.train_data.reshape((50000, 3, 32, 32))
            # 將影象資料格式轉換為[height,width,channels]方便預處理
            self.train_data = self.train_data.transpose((0, 2, 3, 1)) 
        #載入測試資料集
        else:
            self.test_data,self.test_labels = get_data()
            self.test_data = self.test_data.reshape((10000, 3, 32, 32))
            self.test_data = self.test_data.transpose((0, 2, 3, 1))
        pass
    def __getitem__(self, index):
        #從資料集中讀取一個數據並對資料進行
        #預處理返回一個數據對,如(data,label)
        if self.train:
            img, label = self.train_data[index], self.train_labels[index]
        else:
            img, label = self.test_data[index], self.test_labels[index]
        
        img = Image.fromarray(img)
        #影象預處理
        if self.transform is not None:
            img = self.transform(img)
        #標籤預處理
        if self.target_transform is not None:
            target = self.target_transform(label)
 
        return img, target
    def __len__(self):
        #返回資料集的size
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

if __name__ == '__main__':
    #讀取訓練集和測試集
    train_data = Cifar10_Dataset(True,transform,target_transform)
    print('size of train_data:{}'.format(train_data.__len__()))
    test_data = Cifar10_Dataset(False,transform,target_transform)
    print('size of test_data:{}'.format(test_data.__len__()))
    train_loader = Data.DataLoader(dataset=train_data, batch_size = BATCH_SIZE, shuffle=True)

    net = Lenet()
    optimizer = optim.Adam(net.parameters(), lr = 0.001, betas=(0.9, 0.99))
    #在使用CrossEntropyLoss時target直接使用類別索引,不適用one-hot
    loss_fn = nn.CrossEntropyLoss()

    loss_list = []
    for epoch in range(1,EPOCH+1):
        #訓練部分
        for step,(x,y) in enumerate(train_loader):
            b_x = Variable(x)
            b_y = Variable(y)
            output = net(b_x)
            loss = loss_fn(output,b_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #記錄loss
            if step%50 == 0:
                loss_list.append(loss)
        #每完成一個epoch進行一次測試觀察效果
        pre_correct = 0.0
        test_loader = Data.DataLoader(dataset=test_data, batch_size = 100, shuffle=True)
        for (x,y) in (test_loader):
            b_x = Variable(x)
            b_y = Variable(y)
            output = net(b_x)
            pre = torch.max(output,1)[1]
            pre_correct = pre_correct+float(torch.sum(pre==b_y))
        
        print('EPOCH:{epoch},ACC:{acc}%'.format(epoch=epoch,acc=(pre_correct/float(10000))*100))
 
    #儲存網路模型
    torch.save(net,'lenet_cifar_10.model')
    #繪製loss變化曲線
    plt.plot(loss_list)
    plt.show()

第一個pytorch demo跑通了,但是訓練模型效果很不好,應該是Lenet作用於Cifar10有些過於力不從心了,剛開始接觸深度學習的影象領域還不怎麼懂,下次換一個更強大的網路。