PyTorch(一)——資料處理
PyTorch學習和使用(一)
要使學會用一個框架,只會執行其測試實驗是不行的,所以現在打算把caffe中的Siamese模型使用PyTorch實現,來鞏固自己對PyTroch的熟練使用。
資料預處理
首先是資料處理這一塊,PyTorch使用了torchvision來完成資料的處理,其只實現了一些資料集的處理,如果處理自己的工程則需要修改增加內容。
把原始資料處理為模型使用的資料需要3步:transforms.Compose()
torchvision.datasets
torch.utils.data.DataLoader()
分別可以理解為資料處理格式的定義、資料處理和資料載入。
Compose()transforms.Compose([transform.CenterCrop(10), transofrms.ToTensor()])
又比如先轉換為張量,然後正則化,程式碼為:`transforms.Compose([transofrms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]), 其具體的引數呼叫在原始碼中可以看到,在此不多說了。還要注意的是Compose的程式碼是:
def __call_(self, img):
for t in self.transforms:
img = t(img)
return img
這就是把輸入到Compose的操作按順序進行執行。先執行第一個,然後第二個……。如果需要處理自己的資料,可以把具體的操縱放在這個類中實現。
torchvision.datasets ;裡實現了不同針對資料集的處理方法,主要用來載入資料和處理資料。比如在mnist.py 和cifar.py 中用來處理mnist和cifar資料集。類的實現需要繼承父類data.Dataset,其主要方法有2個:
-
__init__(self, root, train=Ture, transform=None, traget_transform=None, download=False):
-
--getitem__(self, index):
該方法是把讀入的輸出傳給PyTorch(迭代器的方式)。**注意:**上面定義的transform.Compose在次數進行呼叫,通過index確定需要訪問的資料,然後對其格式進行轉換,最後返回處理後的資料。也就是說資料在定義時只是定義了一個類,其具體的資料傳出在需要使用時使用該方法完成。
至此,對資料進行載入,然後處理傳給PyTorch已經完成,如果需要對自己的資料進行處理,也是通過修改和增加此部分完成。接下來需要對訓練的資料進行處理,比如分批次的大小,十分隨機處理等等。
torch.utils.data.DataLoader() Data loder, Combines a dataset and and a sampler, and provides single, or multi-process iterators over the dataset. 就是把合成數據並且提供迭代訪問。輸入引數有:
-
dataset(Dataset)。輸入載入的資料,就是上面的torchvision.datasets.myData()的實現,所以需要繼承data.Dataset,滿足此介面。
-
**batch-size, shuffle, sampler, num_workers, collate_fn, pin_memory, drop_last.**這些引數比較好理解,看名字就知道其作用了。分別為:
- batch-size。樣本每個batch的大小,預設為1。
- shuffle。是否打亂資料,預設為False。
- sampler。定義一個方法來繪製樣本資料,如果定義該方法,則不能使用shuffle。
- num_workers。資料分為幾批處理(對於大資料)。
- collate_fn。整理資料,把每個batch資料整理為tensor。(一般使用預設呼叫default_collate(batch))。
- pin_memory。針對不同型別的batch進行處理。比如為Map或者Squence等型別,需要處理為tensor型別。
- drop_last。用於處理最後一個batch的資料。因為最後一個可能不能夠被整除,如果設定為True,則捨棄最後一個,為False則保留最後一個,但是最後一個可能很小。
迭代器(DataLoaderIter)的具體處理就是根據這些引數的設定,分別進行不同的處理。
補充2017/8/10:
torch.utils.data.DataLoader
類主要使用torch.utils.data.sampler
實現,sampler
是所有采樣器的基礎類,提供了迭代器的迭代(__iter__
)和長度(__len__
)介面實現,同時sampler
也是通過索引對資料進行洗牌(shuffle)等操作。因此,如果DataLoader
不適用於你的資料,需要重新設計資料的分批次,可以充分使用所提供的smapler
。