1. 程式人生 > 程式設計 >pytorch隨機取樣操作SubsetRandomSampler()

pytorch隨機取樣操作SubsetRandomSampler()

這篇文章記錄一個取樣器都隨機地從原始的資料集中抽樣資料。抽樣資料採用permutation。 生成任意一個下標重排,從而利用下標來提取dataset中的資料的方法

需要的庫

import torch

使用方法

這裡以MNIST舉例

train_dataset = dsets.MNIST(root='./data',#檔案存放路徑
              train=True,#提取訓練集
              transform=transforms.ToTensor(),#將影象轉化為Tensor
              download=True)

sample_size = len(train_dataset)
sampler1 = torch.utils.data.sampler.SubsetRandomSampler(
  np.random.choice(range(len(train_dataset)),sample_size))

程式碼詳解

np.random.choice()

#numpy.random.choice(a,size=None,replace=True,p=None)
#從a(只要是ndarray都可以,但必須是一維的)中隨機抽取數字,並組成指定大小(size)的陣列
#replace:True表示可以取相同數字,False表示不可以取相同數字
#陣列p:與陣列a相對應,表示取陣列a中每個元素的概率,預設為選取每個元素的概率相同。

那麼這裡就相當於抽取了一個全排列

torch.utils.data.sampler.SubsetRandomSampler

# 會根據後面給的列表從資料集中按照下標取元素
# class torch.utils.data.SubsetRandomSampler(indices):無放回地按照給定的索引列表取樣樣本元素。

所以就可以了。

補充知識:Pytorch學習之torch----隨機抽樣、序列化、並行化

1. torch.manual_seed(seed)

說明:設定生成隨機數的種子,返回一個torch._C.Generator物件。使用隨機數種子之後,生成的隨機數是相同的。

引數:

seed(int or long) -- 種子

>>> import torch
>>> torch.manual_seed(1)
<torch._C.Generator object at 0x0000019684586350>
>>> a = torch.rand(2,3)
>>> a
tensor([[0.7576,0.2793,0.4031],[0.7347,0.0293,0.7999]])
>>> torch.manual_seed(1)
<torch._C.Generator object at 0x0000019684586350>
>>> b = torch.rand(2,3)
>>> b
tensor([[0.7576,0.7999]])
>>> a == b
tensor([[1,1,1],[1,1]],dtype=torch.uint8)

2. torch.initial_seed()

說明:返回生成隨機數的原始種子值

>>> torch.manual_seed(4)
<torch._C.Generator object at 0x0000019684586350>
>>> torch.initial_seed()
4

3. torch.get_rng_state()

說明:返回隨機生成器狀態(ByteTensor)

>>> torch.initial_seed()
4
>>> torch.get_rng_state()
tensor([4,...,0],dtype=torch.uint8)

4. torch.set_rng_state()

說明:設定隨機生成器狀態

引數:

new_state(ByteTensor) -- 期望的狀態

5. torch.default_generator

說明:預設的隨機生成器。等於<torch._C.Generator object>

6. torch.bernoulli(input,out=None)

說明:從伯努利分佈中抽取二元隨機數(0或1)。輸入張量包含用於抽取二元值的概率。因此,輸入中的所有值都必須在[0,1]區間內。輸出張量的第i個元素值,將會以輸入張量的第i個概率值等於1。返回值將會是與輸入相同大小的張量,每個值為0或者1.

引數:

input(Tensor) -- 輸入為伯努利分佈的概率值

out(Tensor,可選) -- 輸出張量

>>> a = torch.Tensor(3,3).uniform_(0,1)
>>> a
tensor([[0.5596,0.5591,0.0915],[0.2100,0.0072,0.0390],[0.9929,0.9131,0.6186]])
>>> torch.bernoulli(a)
tensor([[0.,1.,0.],[0.,0.,[1.,1.]])

7. torch.multinomial(input,num_samples,replacement=False,out=None)

說明:返回一個張量,每行包含從input相應行中定義的多項分佈中抽取的num_samples個樣本。要求輸入input每行的值不需要總和為1,但是必須非負且總和不能為0。當抽取樣本時,依次從左到右排列(第一個樣本對應第一列)。如果輸入input是一個向量,輸出out也是一個相同長度num_samples的向量。如果輸入input是m行的矩陣,輸出out是形如m x n的矩陣。並且如果引數replacement為True,則樣本抽取可以重複。否則,一個樣本在每行不能被重複。

引數:

input(Tensor) -- 包含概率的張量

num_samples(int) -- 抽取的樣本數

replacement(bool) -- 布林值,決定是否能重複抽取

out(Tensor) -- 結果張量

>>> weights = torch.Tensor([0,10,3,0])
>>> weights
tensor([ 0.,10.,3.,0.])
>>> torch.multinomial(weights,4,replacement=True)
tensor([1,1])

8. torch.normal(means,std,out=None)

說明:返回一個張量,包含從給定引數means,std的離散正態分佈中抽取隨機數。均值means是一個張量,包含每個輸出元素相關的正態分佈的均值。std是一個張量。包含每個輸出元素相關的正態分佈的標準差。均值和標準差的形狀不須匹配,但每個張量的元素個數必須想聽。

引數:

means(Tensor) -- 均值

std(Tensor) -- 標準差

out(Tensor) -- 輸出張量

>>> n_data = torch.ones(5,2)
>>> n_data
tensor([[1.,1.],1.]])
>>> x0 = torch.normal(2 * n_data,1)
>>> x0
tensor([[1.6544,0.9805],[2.1114,2.7113],[1.0646,1.9675],[2.7652,3.2138],[1.1204,2.0293]])

9. torch.save(obj,f,pickle_module=<module 'pickle' from '/home/lzjs/...)

說明:儲存一個物件到一個硬碟檔案上。

引數:

obj -- 儲存物件

f -- 類檔案物件或一個儲存檔名的字串

pickle_module -- 用於pickling源資料和物件的模組

pickle_protocol -- 指定pickle protocal可以覆蓋預設引數

10. torch.load(f,map_location=None,pickle_module=<module 'pickle' from '/home/lzjs/...)

說明:從磁碟檔案中讀取一個通過torch.save()儲存的物件。torch.load()可通過引數map_location動態地進行記憶體重對映,使其能從不動裝置中讀取檔案。一般呼叫時,需兩個引數:storage和location tag。返回不同地址中的storage,或者返回None。如果這個引數是字典的話,意味著從檔案的地址標記到當前系統的地址標記的對映。

引數:

f -- l類檔案物件或一個儲存檔名的字串

map_location -- 一個函式或字典規定如何remap儲存位置

pickle_module -- 用於unpickling元資料和物件的模組

torch.load('tensors.pt')
# 載入所有的張量到CPU
torch.load('tensor.pt',map_location=lambda storage,loc:storage)
# 載入張量到GPU
torch.load('tensors.pt',map_location={'cuda:1':'cuda:0'})

11. torch.get_num_threads()

說明:獲得用於並行化CPU操作的OpenMP執行緒數

12. torch.set_num_threads()

說明:設定用於並行化CPU操作的OpenMP執行緒數

以上這篇pytorch隨機取樣操作SubsetRandomSampler()就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。