1. 程式人生 > 其它 >API_Net官方程式碼之資料處理

API_Net官方程式碼之資料處理

技術標籤:API_Net

在這裡插入圖片描述

一、資料準備

總結:
RandomDataset :用於驗證 (val)
BatchDataset:用於訓練 (train)
BalancedBatchSampler:決定如何取樣樣本,不是簡單的在Dataloader中設定一個batch_size了

1)匯入的包類:

import torch
from PIL import  Image
import numpy as np
from torchvision import transforms

from torch.utils.data import Dataset
from torch.utils.data.
sampler import BatchSampler#此處的BatchSampler相當於在Dataloader中設定的batch_size

2)讀取圖片函式

學習點:考慮讀取圖片失敗情況,使用try…except結構,並將讀取失敗的圖路徑儲存下來,並返回一個全為白色的同尺度新圖。

def default_loader(path):
    '''
    開啟圖片並轉化為RGB,開啟失敗,則記錄下來,並返回一個新圖
    '''
    try:
        img = Image.open(path).convert('RGB')
    except:
        with
open('read_error.txt', 'a') as fid: fid.write(path+'\n') return Image.new('RGB', (224,224), 'white') return img

3)RandomDataset類

此類用於選取指定index的樣本,返回的是一張圖片以及對應的標籤。
批量獲取是在Dataloader中設定的。

class RandomDataset(Dataset):
    def __init__(self, transform=None, dataloader=default_loader)
:#此處dataloader用不上 self.transform = transform self.dataloader = dataloader with open('val.txt', 'r') as fid:#將圖片路徑以及標籤讀取出來 self.imglist = fid.readlines() def __getitem__(self, index): image_name, label = self.imglist[index].strip().split() #獲取對應的路徑以及標籤 image_path = image_name img = self.dataloader(image_path) img = self.transform(img) label = int(label) #特別注意,要將label設定為int型別 label = torch.LongTensor([label]) return [img, label] #注意官網推薦使用字典{'image':img,'label':label} def __len__(self): return len(self.imglist)

4)BatchDataset類

此類與上述類類似。

class BatchDataset(Dataset):
    def __init__(self, transform=None, dataloader=default_loader):
        self.transform = transform
        self.dataloader = dataloader
        with open('train.txt', 'r') as fid:
            self.imglist = fid.readlines()
        self.labels = []
        for line in self.imglist:
            image_path, label = line.strip().split()
            self.labels.append(int(label))
        self.labels = np.array(self.labels)
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, index):
        image_name, label = self.imglist[index].strip().split()
        image_path = image_name
        img = self.dataloader(image_path) #載入資料
        img = self.transform(img)
        label = int(label)
        label = torch.LongTensor([label])
        return [img, label]

    def __len__(self):
        return len(self.imglist)

5)BalancedBatchSampler類
此程式碼沒有初始化父類,可能是用不到父類的變數。

  • 獲取所有樣本的
class BalancedBatchSampler(BatchSampler):
    def __init__(self, dataset, n_classes, n_samples):
    '''
	獲取每類樣本對應的索引,用字典儲存,並將每類的索引打亂,因為此取樣器返回的就是索引列表,用於在Dataset中獲取樣本的,相當於其中的index
	'''
        self.labels = dataset.labels #所有樣本的labels
        self.labels_set = list(set(self.labels.numpy())) #0~199,如果是1~200會報錯的
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0] #返回每類中的樣本對應的index,字典
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l]) #將每類對應的索引打亂

        self.used_label_indices_count = {label: 0 for label in self.labels_set}  ##每類樣本使用過的數量
        self.count = 0 #用過的圖片數量,用於統計看夠不夠下一個batch用

        self.n_classes = n_classes
        self.n_samples = n_samples
        self.dataset = dataset

        self.batch_size = self.n_samples * self.n_classes#此處儲存一個batch樣本的數量

    def __iter__(self):
        self.count = 0 #只用關心樣本使用一次的事情,也就是一個epoch後就歸零了
        while self.count + self.batch_size < len(self.dataset): #也就是使用過的圖片數量再加上一個batch仍然小於總數,那麼可以繼續提供一個batch的圖片
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False) #選類別,不放回的抽,也就是抽出來的不能有重複
            indices = []
            for class_ in classes: # 1 , 3 ,4
                indices.extend(self.label_to_indices[class_][ #label_to_indices是一個字典,{class:[index]}
                               self.used_label_indices_count[class_]:self.used_label_indices_count[
                                                                         class_] + self.n_samples]) #順序獲取每類中的n個樣本
                self.used_label_indices_count[class_] += self.n_samples #每類樣本使用過的數量增加此次使用的樣本數量
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]): #使用過的加上下次的樣本溝用不夠,大於表示不夠下次使用了
                    np.random.shuffle(self.label_to_indices[class_])#不夠下次使用,那就將其重新打亂,並將每類的使用數量歸零
                    self.used_label_indices_count[class_] = 0
            yield indices #每類都獲取到了後,就送出去,送出去的是樣本的索引
            self.count += self.n_classes * self.n_samples #增加本批樣本數量

    def __len__(self):
        return len(self.dataset) // self.batch_size

6)具體應用:

train_dataset = BatchDataset(transform=transforms.Compose([
        transforms.Resize([512, 512]),
        transforms.RandomCrop([448, 448]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225)
        )]))

    train_sampler = BalancedBatchSampler(train_dataset, args.n_classes, args.n_samples) #用於設定每批樣本的來源,其返回的是樣本的索引indices
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_sampler=train_sampler,#不再設定batch_size,使用batch_sampler
        num_workers=args.workers, pin_memory=True) #num_works是執行緒,pin_memory不太懂,沒看懂

二、總結

1.讀取圖片,考慮讀取失敗的情況,並且要考慮失敗後進行記錄,使用try…except…的結構;
2.樣本的獲取分為取樣以及獲取例項兩步,正常情況下,通過設定batch_size即可不用設定取樣器,只需要設定Dataset資料集即可;
3.設定樣本取樣器,繼承自torch.utils.data.sampler.BatchSampler,之後實現三個函式,分別是__init__()、iter()、以及__len__()等,其中的iter函式中使用yield生成一個生成器,不斷送出取樣的索引列表;
4. 設定資料集Dataset,繼承自torch.utils.data.Dataset類,之後實現三個函式,分別是__init__()、getitem()、len(),getitem()函式是返回例項與標籤的字典。