1. 程式人生 > >Pytorch tutorials 實戰教程(1)——訓練自己的資料集(程式碼詳解)

Pytorch tutorials 實戰教程(1)——訓練自己的資料集(程式碼詳解)

最開始入坑的時候使用的是caffe,前一段時間換了使用主流框架的keras(Tensorflow as backward),但是keras確實封裝得太好了,一個高階的API對於我這種程式設計渣渣來說反而上手有些不習慣,在寫了一段時間的程式碼以後開始使用pytorch(反正老闆要求了兩個框架都要熟練那就都學啦),對於原始碼部分確實友好了很多,儘管需要自己定義前向過程但是也很簡單啦~

**

一、訓練torchvision自帶資料集:

**

搭建網路、訓練torchvision裡面自帶的資料集都是easy stuff,這個tutorials儘量記錄我在實際程式碼中遇到的稍微要費點兒精力的事情。

首先是如何訓練資料集,如果訓練torchvision裡自帶的資料集非常簡單,只需要使用torchvision.datasets直接進行讀取,再例項化torch.utils.data.DataLoader(規定好batch_size以及是否進行shuffle),在訓練時使用enumerate列舉函式匯入資料,也可以用以下程式碼檢視是否匯入資料成功顯示圖片:

for i, data in enumerate(dataLoader, 0):  
    print(data[i][0])  
    # PIL  
    img = transforms.ToPILImage()(data[i][0])  
img.show() break

完整的程式碼如下:

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from logger import Logger

# 定義超引數
batch_size = 128
learning_rate = 1e-2 num_epoches = 20 def to_np(x): return x.cpu().data.numpy() # download datasets train_dataset = datasets.CIFAR10( root='./cifar_data', train=True, transform=transforms.ToTensor(), download=True) test_dataset = datasets.CIFAR10( root='./cifar_data', train=False, transform=transforms.ToTensor()) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) #define model class slice_ssc(nn.Module): def __init__(self,in_channel,n_class): super(slice_ssc,self).__init__() self.conv1 = nn.Sequential( nn.Conv2d(in_channel,32,3,1,1), nn.ReLU(True), nn.MaxPool2d(2)) self.conv2 = nn.Sequential( nn.Conv2d(32,64,3,1,1), nn.ReLU(True), nn.MaxPool2d(2)) self.fc = nn.Sequential( nn.Linear(64*8*8,128), nn.Linear(128,64), nn.Linear(64,n_class)) def forward(self,x): conv1_out = self.conv1(x) conv2_out = self.conv2(conv1_out) conv2_out = conv2_out.view(conv2_out.size(0),-1) out = self.fc(conv2_out) return out model = slice_ssc(1,10) print model use_gpu = torch.cuda.is_available() # 判斷是否有GPU加速 if use_gpu: model = model.cuda() # 定義loss和optimizer criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=learning_rate) logger = Logger('./logs') #training for epoch in range(num_epoches): print 'epoch {}'.format(epoch+1) train_loss=0.0 train_acc=0.0 #==========training============ for i,data in enumerate(train_loader,1): img,label=data img=img.view(img.size(0)*3,1,32,32) label = torch.cat((label,label,label),0) #print img.size() #print label.size() if use_gpu: img = img.cuda() label = label.cuda() img = Variable(img) label = Variable(label) #forward out = model(img) loss = criterion(out,label) train_loss += loss.data[0] #*label.size(0) _, pred = torch.max(out,1) train_correct = (pred == label).sum() accuracy = (pred == label).float().mean() train_acc += train_correct.data[0] #backward optimizer.zero_grad() loss.backward() optimizer.step() #=============log=============== step = epoch*len(train_loader)+i info = {'loss':loss.data[0],'accuracy':accuracy.data[0]} for tag, value in info.items(): logger.scalar_summary(tag, value, step) for tag, value in model.named_parameters(): tag = tag.replace('.', '/') logger.histo_summary(tag, to_np(value), step) logger.histo_summary(tag + '/grad', to_np(value.grad), step) info = {'images': to_np(img.view(-1, 32, 32)[:10])} for tag, images in info.items(): logger.image_summary(tag, images, step) if i % 300 == 0: print '[{}/{}] Loss: {:.6f}, Acc: {:.6f}'.format( epoch + 1, num_epoches, train_loss / (batch_size * i), train_acc / (batch_size * i)) print 'Finish {} epoch, Loss: {:.6f}, Acc: {:.6f}'.format( epoch + 1, train_loss / (len(train_dataset)), train_acc / (len( train_dataset))) #============testing============= model.eval() eval_loss = 0.0 eval_acc = 0.0 for data in test_loader: img,label = data img=img.view(img.size(0)*3,1,32,32) label = torch.cat((label,label,label),0) if use_gpu: img = Variable(img,volatile=True).cuda() label = Variable(label,volatile=True).cuda() else: img = Variable(img, volatile=True) label = Variable(label, volatile=True) out = model(img) loss = criterion(out, label) eval_loss += loss.data[0] * label.size(0) _, pred = torch.max(out, 1) num_correct = (pred == label).sum() eval_acc += num_correct.data[0] print 'Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len( test_dataset)), eval_acc / (len(test_dataset))) # 儲存模型 torch.save(model.state_dict(), './cnn.pth')

其中儲存log日誌的logger.py程式碼為:

import tensorflow as tf
import numpy as np
import scipy.misc
try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO         # Python 3.x


class Logger(object):

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag,
                                                     simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images."""

        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to a string
            try:
                s = StringIO()
            except:
                s = BytesIO()
            scipy.misc.toimage(img).save(s, format="png")

            # Create an Image object
            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            # Create a Summary value
            img_summaries.append(
                tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        # Create and write Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""

        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values**2))

        # Drop the start of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()

**

二、訓練自己的資料集:

**

1、Dataset class:

**
torch.utils.data.Dataset : 是一個表達dataset的抽象類,需要繼承Dataset類,並進行override,最重要的複寫類中的幾個函式如下:

(1) __init__ : 讀各種格式的資料集、路徑等,控制傳入引數
(2) __getitem__ : 使dataset[i]能夠獲得第i個樣本資料,即匯入具體資料
(3) __len__ : len(dataset) returns the size of the dataset

完整程式碼例項如下:

def default_loader(path):
    return Image.open(path).convert('RGB')

############# Dataset ############
class myImageFloder(data.Dataset):
    def __init__(self,root,image_path,label_path,transform = None,target_transform = None,loader = default_loader):
        f_img = open(image_path)
        f_label = open(label_path)

        #c = 0
        imgs = []
        img_names = []
        label_names = []

        for line in f_img.readlines():
            cls = line.split()
            img_name = cls.pop(1)
            img_names.append(img_name)

            #read image
            if os.path.isfile(os.path.join(root,img_name)):
                imgs.append((img_name,tuple([float(v) for v in cls])))

        for line in f_label.readlines():
            cls = line.split()
            label_name = cls.pop(1)
            label_names.append(label_name)     

        self.root = root
        self.imgs = imgs
        self.img_names = img_names
        self.lable_names = label_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self,index):
        img_name,label_name = self.imgs[index]
        img = self.loader(os.path.join(self.root,img_name))
        if self.transform is not None:
            img = self.transform(img)
        return img,torch.Tensor(label)

    def __len__(self):
        return len(self.imgs)

**

2.Transform:

**
需要用一些轉化函式對輸入的影象對做轉換變化,常用函式如下:

rescale:scale the image
randomcrop:crop from image randomly,for data augmentation
ToTensor:convert the numpy image to torch image

例如如下完整程式碼定義:

########### Transform ############
mytransform = transforms.Compose([
    transforms.ToTensor()
    ]
)

**

3.例項化DataLoader:

**
這一步是為了將上面得到的資料做處理:Batch the data、Shuffle the data、load the data in parallel using multiprocessing workers.並且對trainloader、testloader單獨進行例項化。
完整程式碼例項如下:

########## Dataloader ############
trainloader = torch.utils.data.DataLoader(
    myFloder.myImageFloder(root = '/home/zzq/Distillation/Datasets/bird_classification-master/data/images',
                           image_path = '/home/zzq/Distillation/Datasets/bird_classification-master/data/images_train.txt',
                           label_path = '/home/zzq/Distillation/Datasets/bird_classification-master/data/image_class_labels_train.txt',
                           transform = mytransform),
    batch_size = 24,shuffle = True,num_workers = 2)
print("TrainLoader success...")

testloader = torch.utils.data.DataLoader(
    myFloder.myImageFloder(root = '/home/zzq/Distillation/Datasets/bird_classification-master/data/images',
                          image_path = '/home/zzq/Distillation/Datasets/bird_classification-master/data/images_test.txt',
                          label_path = '/home/zzq/Distillation/Datasets/bird_classification-master/data/image_class_labels_test.txt',
                          transform = mytransform),
    batch_size = 24,shuffle = False,num_workers = 2)

print("TestLoader success...")

相關推薦

FCN訓練自己資料person-segmentation、SIFT-FLOW、SBD和VOC實驗總結

最近花了將近一週的時間,基於提供的原始碼,通過參考網上的部落格,跑通了FCN在三個資料集上的訓練以及測試。在這裡寫下總結,即是記錄,又希望能夠對其他剛剛接觸FCN的人有所幫助。 FCN的原始碼地址:https://github.com/shelhamer/fcn.berkeleyvision.o

FastRCNN 訓練自己資料——修改讀寫介面

這裡樓主講解了如何修改Fast RCNN訓練自己的資料集,首先請確保你已經安裝好了Fast RCNN的環境,具體的編配編制操作請參考我的上一篇文章。首先可以看到fast rcnn的工程目錄下有個Lib目錄這裡下面存在3個目錄分別是:datasetsfast_rcnnroi_d

FastRCNN 訓練自己資料——編譯配置

FastRCNN是Ross Girshick在RCNN的基礎上增加了Multi task training整個的訓練過程和測試過程比RCNN快了許多。別的一些細節不展開,過幾天會上傳Fast RCNN的論文筆記。FastRCNN mAP效能上略有上升。Fast RCNN中,提取OP的過程和訓練過程仍

使用deeplabv3+訓練自己資料遷移學習

# 概述 在前邊一篇文章,我們講了如何復現論文程式碼,使用pascal voc 2012資料集進行訓練和驗證,具體內容可以參考[《deeplab v3+在pascal_voc 2012資料集上進行訓練》](https://www.vcjmhg.top/train-deeplabv3-puls-with-pa

Pytorch tutorials 實戰教程1——訓練自己資料程式碼

最開始入坑的時候使用的是caffe,前一段時間換了使用主流框架的keras(Tensorflow as backward),但是keras確實封裝得太好了,一個高階的API對於我這種程式設計渣渣來說反而上手有些不習慣,在寫了一段時間的程式碼以後開始使用py

caffe練習例項1——訓練mnist資料

1.簡介 這是一個非常簡單的例項,主要是為了這個簡單的例項瞭解caffe的工作流程。 2.操作流程 1.獲取資料 在caffe-master/data/mnist資料夾中只有一

SSD: Single Shot MultiBox Detector 訓練KITTI資料1

前言 之前介紹了SSD的基本用法和檢測單張圖片的方法,那麼本篇部落格將詳細記錄如何使用SSD檢測框架訓練KITTI資料集。SSD專案中自帶了用於訓練PASCAL VOC資料集的指令碼,基本不用做修改就可以輕鬆完成訓練;但是想要訓練其他資料集比如KITTI,則

TensorFlow訓練MNIST資料3 —— 卷積神經網路

  前面兩篇隨筆實現的單層神經網路 和多層神經網路, 在MNIST測試集上的正確率分別約為90%和96%。在換用多層神經網路後,正確率已有很大的提升。這次將採用卷積神經網路繼續進行測試。 1、模型基本結構   如下圖所示,本次採用的模型共有8層(包含dropout層)。其中卷積層和池化層各有兩層。   在

YOLOv2目標檢測_單目標_訓練自己資料全過程自用

1.    製作符合要求的VOC資料集 目標:製作如下格式的資料夾 格式: --VOC2017(大寫字母+數字) --Annotations(存放儲存標註資訊的xml) --ImageSets --Main(存放儲存圖片名的train.txttest.txt) --Layo

搜索引擎系列八:solr-部署solr兩種部署模式介紹、獨立服務器模式、SolrCloud分布式群模式

nod 為什麽 用途 serve creat 復制 stand 數據 變量名 一、solr兩種部署模式介紹 Standalone Server 獨立服務器模式:適用於數據規模不大的場景 SolrCloud 分布式集群模式:適用於數據規模大,高可靠、高可用、高並發的場景 二

擷取拼接成新的字串System.arraycopy()如何一分鐘快速掌握示例程式碼

//該示例程式碼直接執行即可,喜歡我的文章請關注我,你們是我動力的源泉,謝謝 public static void main(String[] args) { //宣告一個字串型別的變數,在實際開發中變數為獲取的引數 String signDate="AAAAAAAAAAAAAAAAAAAA

機器學習 深度學習資料彙總含文件,資料程式碼

分享一下我老師大神的人工智慧教程!零基礎,通俗易懂!http://blog.csdn.net/jiangjunshow 也歡迎大家轉載本篇文章。分享知識,造福人民,實現我們中華民族偉大復興!        

使用pytorch版faster-rcnn訓練自己資料

使用pytorch版faster-rcnn訓練自己資料集 引言 faster-rcnn pytorch程式碼下載 訓練自己資料集 接下來工作 參考文獻 引言 最近在復現目標檢測程式碼(師兄強烈推薦F

TensorFlow函式之tf.nn.conv2d()程式碼

tf.nn.conv2d是TensorFlow裡面實現卷積的函式,是搭建卷積神經網路比較核心的一個方法。 函式格式: tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu =  Noen, name = Non

led指示燈電路圖大全八款led指示燈電路設計原理圖

led指示燈電路圖大全(八款led指示燈電路設計原理圖詳解) led指示燈電路圖(一) 圖1所示電路中只有兩個元件,R選用1/6--1/8W碳膜電阻或金屬膜電阻,阻值在1--300K之間。 Ne為氖泡,也選取用普通日光燈啟輝器中的氖泡,若想用體積小且在60V左右即能啟

《OpenCV3程式設計入門》——5.2.3 addWeighted()函式線性混合程式碼

addWeighted()函式用來計算兩個陣列(影象陣列)的加權和。 格式如下: void addWeighted(InputArray src1, double alpha, InputArray src2, double beta, double gamma, OutputArray

OpenCV中copyTo()函式及Mask程式碼

copyTo函式有兩種重構方式: 第一種:A.copyTo(B),表示將A矩陣複製到B中 第二種:A.copyTo(B, mask),表示得到一個附加掩膜mask的矩陣B。 第一種方法就不多贅述,這裡主要詳細敘述第二種使用方法。  對於第二種mask引數的格

語音識別——基於深度學習的中文語音識別系統實現程式碼

文章目錄 利用thchs30為例建立一個語音識別系統 1. 特徵提取 2. 模型搭建 搭建cnn+dnn+ctc的聲學模型 3. 訓練準備 下載資料

php垃圾回收機制PHP新的垃圾回收機制:Zend GC

概述     在5.2及更早版本的PHP中,沒有專門的垃圾回收器GC(Garbage Collection),引擎在判斷一個變數空間是否能夠被釋放的時候是依據這個變數的zval的refcount的值,如果refcount為0,那麼變數的空間可以被釋放,否則就不釋放,這是一種

Fast RCNN 訓練自己資料 (1編譯配置)

FastRCNN 訓練自己資料集 (1編譯配置) FastRCNN是Ross Girshick在RCNN的基礎上增加了Multi task training整個的訓練過程和測試過程比RCNN快了許多。別的一些細節不展開,過幾天會上傳Fast RCNN的論文筆記。FastRCNN mAP效能上略有上升。Fa