pytorch sampler對資料進行取樣的實現
阿新 • • 發佈:2020-01-09
PyTorch中還單獨提供了一個sampler模組,用來對資料進行取樣。常用的有隨機取樣器:RandomSampler,當dataloader的shuffle引數為True時,系統會自動呼叫這個取樣器,實現打亂資料。預設的是採用SequentialSampler,它會按順序一個一個進行取樣。這裡介紹另外一個很有用的取樣方法: WeightedRandomSampler,它會根據每個樣本的權重選取資料,在樣本比例不均衡的問題中,可用它來進行重取樣。
構建WeightedRandomSampler時需提供兩個引數:每個樣本的權重weights、共選取的樣本總數num_samples,以及一個可選引數replacement。權重越大的樣本被選中的概率越大,待選取的樣本數目一般小於全部的樣本數目。replacement用於指定是否可以重複選取某一個樣本,預設為True,即允許在一個epoch中重複取樣某一個數據。如果設為False,則當某一類的樣本被全部選取完,但其樣本數目仍未達到num_samples時,sampler將不會再從該類中選擇資料,此時可能導致weights引數失效。
下面舉例說明。
from dataSet import * dataset = DogCat('data/dogcat/',transform=transform) from torch.utils.data import DataLoader # 狗的圖片被取出的概率是貓的概率的兩倍 # 兩類圖片被取出的概率與weights的絕對大小無關,只和比值有關 weights = [2 if label == 1 else 1 for data,label in dataset] print(weights) from torch.utils.data.sampler import WeightedRandomSampler sampler = WeightedRandomSampler(weights,\ num_samples=9,\ replacement=True) dataloader = DataLoader(dataset,batch_size=3,sampler=sampler) for datas,labels in dataloader: print(labels.tolist())
輸出:
[2,2,1,2] [1,0] [1,0] [0,1]
github 地址:
https://github.com/WebLearning17/CommonTool
以上這篇pytorch sampler對資料進行取樣的實現就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。