MNIST 資料集、資料載入
MNIST 資料集
機器學習的入門就是MNIST。
MNIST 資料集來自美國國家標準與技術研究所,是NIST(National Institute of Standards and Technology)的縮小版,訓練集(training set)由來自 250 個不同人手寫的數字構成,其中 50% 是高中學生,50% 來自人口普查局(the Census Bureau)的工作人員,測試集(test set)也是同樣比例的手寫數字資料。
獲取MNIST
MNIST 資料集可在http://yann.lecun.com/exdb/mnist/獲取,圖片是以位元組的形式進行儲存,它包含了四個部分:
Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解壓後 47 MB, 包含 60,000 個樣本) Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解壓後 60 KB, 包含 60,000 個標籤) Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解壓後 7.8 MB, 包含 10,000 個樣本) Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解壓後 10 KB, 包含 10,000 個標籤)
此資料集中,訓練樣本:共60000個,其中55000個用於訓練,另外5000個用於驗證。測試樣本:共10000個,驗證資料比例相同。
from torchvision.datasets import MNIST
mnist_train = MNIST(root='./MNIST_data', train=True, download=True, transform=transforms.PILToTensor())
資料載入
from torch.utils.data import DataLoader from torchvision.utils import make_grid dataloader = DataLoader(dataset=mnist_train, batch_size=2, shuffle=True, num_workers=2) for (images, labels) in dataloader: print(labels) image = make_grid(images).permute(1, 2, 0).numpy() plt.imshow(image) plt.show() exit()
其中引數含義:
- dataset:提前定義的dataset的例項
- batch_size:傳入資料的batch的大小,常用128,256等等
- shuffle:bool型別,表示是否在每次獲取資料的時候提前打亂資料
-
num_workers
:載入資料的執行緒數
transforms
由於 DataLoader 這個載入器只能載入 tensors, numpy arrays, numbers, dicts or lists
但是 found <class 'PIL.Image.Image'>,所以就很尷尬,我們需要將圖片轉換一下
transforms 用於圖形變換,在使用時我們還可以使用 transforms.Compose
將一系列的transforms操作連結起來。
-
torchvision.transforms.Compose([ ts,ts,ts... ])
ts為transforms操作
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
大多數情況下我們不會只transforms 一下,所以可以用如下方案
from torchvision import transforms
transforms.Compose(
[ #文件 https://pytorch.org/vision/stable/transforms.html
transforms.ToPILImage(), # 轉成PIL圖片
# transforms.Resize(size), # 縮放
transforms.ToTensor(), # 變張量
transforms.Normalize(mean=(0.1307, ), std=(0.3081, )) ]
)
介紹一個概念:
transforms 處理過後,會把通道移到最前邊。比如 MNIST h*w*c
為:28281
tensor處理完,通道數會提前,並且做了軸交換,變為了 c*h*w
為:12828
至於為什麼要這麼設計?聽傳言是做矩陣加減乘除以及卷積等運算是需要呼叫cuda和cudnn的函式的,而這些介面都設成成 chw 格式了