API_Net官方程式碼之資料處理
阿新 • • 發佈:2021-02-16
技術標籤:API_Net
一、資料準備
總結:
RandomDataset :用於驗證 (val)
BatchDataset:用於訓練 (train)
BalancedBatchSampler:決定如何取樣樣本,不是簡單的在Dataloader中設定一個batch_size了
1)匯入的包類:
import torch
from PIL import Image
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data. sampler import BatchSampler#此處的BatchSampler相當於在Dataloader中設定的batch_size
2)讀取圖片函式
學習點:考慮讀取圖片失敗情況,使用try…except結構,並將讀取失敗的圖路徑儲存下來,並返回一個全為白色的同尺度新圖。
def default_loader(path):
'''
開啟圖片並轉化為RGB,開啟失敗,則記錄下來,並返回一個新圖
'''
try:
img = Image.open(path).convert('RGB')
except:
with open('read_error.txt', 'a') as fid:
fid.write(path+'\n')
return Image.new('RGB', (224,224), 'white')
return img
3)RandomDataset類
此類用於選取指定index的樣本,返回的是一張圖片以及對應的標籤。
批量獲取是在Dataloader中設定的。
class RandomDataset(Dataset):
def __init__(self, transform=None, dataloader=default_loader) :#此處dataloader用不上
self.transform = transform
self.dataloader = dataloader
with open('val.txt', 'r') as fid:#將圖片路徑以及標籤讀取出來
self.imglist = fid.readlines()
def __getitem__(self, index):
image_name, label = self.imglist[index].strip().split() #獲取對應的路徑以及標籤
image_path = image_name
img = self.dataloader(image_path)
img = self.transform(img)
label = int(label) #特別注意,要將label設定為int型別
label = torch.LongTensor([label])
return [img, label] #注意官網推薦使用字典{'image':img,'label':label}
def __len__(self):
return len(self.imglist)
4)BatchDataset類
此類與上述類類似。
class BatchDataset(Dataset):
def __init__(self, transform=None, dataloader=default_loader):
self.transform = transform
self.dataloader = dataloader
with open('train.txt', 'r') as fid:
self.imglist = fid.readlines()
self.labels = []
for line in self.imglist:
image_path, label = line.strip().split()
self.labels.append(int(label))
self.labels = np.array(self.labels)
self.labels = torch.LongTensor(self.labels)
def __getitem__(self, index):
image_name, label = self.imglist[index].strip().split()
image_path = image_name
img = self.dataloader(image_path) #載入資料
img = self.transform(img)
label = int(label)
label = torch.LongTensor([label])
return [img, label]
def __len__(self):
return len(self.imglist)
5)BalancedBatchSampler類
此程式碼沒有初始化父類,可能是用不到父類的變數。
- 獲取所有樣本的
class BalancedBatchSampler(BatchSampler):
def __init__(self, dataset, n_classes, n_samples):
'''
獲取每類樣本對應的索引,用字典儲存,並將每類的索引打亂,因為此取樣器返回的就是索引列表,用於在Dataset中獲取樣本的,相當於其中的index
'''
self.labels = dataset.labels #所有樣本的labels
self.labels_set = list(set(self.labels.numpy())) #0~199,如果是1~200會報錯的
self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0] #返回每類中的樣本對應的index,字典
for label in self.labels_set}
for l in self.labels_set:
np.random.shuffle(self.label_to_indices[l]) #將每類對應的索引打亂
self.used_label_indices_count = {label: 0 for label in self.labels_set} ##每類樣本使用過的數量
self.count = 0 #用過的圖片數量,用於統計看夠不夠下一個batch用
self.n_classes = n_classes
self.n_samples = n_samples
self.dataset = dataset
self.batch_size = self.n_samples * self.n_classes#此處儲存一個batch樣本的數量
def __iter__(self):
self.count = 0 #只用關心樣本使用一次的事情,也就是一個epoch後就歸零了
while self.count + self.batch_size < len(self.dataset): #也就是使用過的圖片數量再加上一個batch仍然小於總數,那麼可以繼續提供一個batch的圖片
classes = np.random.choice(self.labels_set, self.n_classes, replace=False) #選類別,不放回的抽,也就是抽出來的不能有重複
indices = []
for class_ in classes: # 1 , 3 ,4
indices.extend(self.label_to_indices[class_][ #label_to_indices是一個字典,{class:[index]}
self.used_label_indices_count[class_]:self.used_label_indices_count[
class_] + self.n_samples]) #順序獲取每類中的n個樣本
self.used_label_indices_count[class_] += self.n_samples #每類樣本使用過的數量增加此次使用的樣本數量
if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]): #使用過的加上下次的樣本溝用不夠,大於表示不夠下次使用了
np.random.shuffle(self.label_to_indices[class_])#不夠下次使用,那就將其重新打亂,並將每類的使用數量歸零
self.used_label_indices_count[class_] = 0
yield indices #每類都獲取到了後,就送出去,送出去的是樣本的索引
self.count += self.n_classes * self.n_samples #增加本批樣本數量
def __len__(self):
return len(self.dataset) // self.batch_size
6)具體應用:
train_dataset = BatchDataset(transform=transforms.Compose([
transforms.Resize([512, 512]),
transforms.RandomCrop([448, 448]),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225)
)]))
train_sampler = BalancedBatchSampler(train_dataset, args.n_classes, args.n_samples) #用於設定每批樣本的來源,其返回的是樣本的索引indices
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_sampler=train_sampler,#不再設定batch_size,使用batch_sampler
num_workers=args.workers, pin_memory=True) #num_works是執行緒,pin_memory不太懂,沒看懂
二、總結
1.讀取圖片,考慮讀取失敗的情況,並且要考慮失敗後進行記錄,使用try…except…的結構;
2.樣本的獲取分為取樣以及獲取例項兩步,正常情況下,通過設定batch_size即可不用設定取樣器,只需要設定Dataset資料集即可;
3.設定樣本取樣器,繼承自torch.utils.data.sampler.BatchSampler,之後實現三個函式,分別是__init__()、iter()、以及__len__()等,其中的iter函式中使用yield生成一個生成器,不斷送出取樣的索引列表;
4. 設定資料集Dataset,繼承自torch.utils.data.Dataset類,之後實現三個函式,分別是__init__()、getitem()、len(),getitem()函式是返回例項與標籤的字典。