1. 程式人生 > >pytorch torchvision study

pytorch torchvision study

torchvision 模組

torchvision是獨立於pytorch的關於影象操作的工具庫,主要包含了如下4個子模組或包:

  • datasets
  • utils
  • transforms
  • models

1、datasets

torchvision.datasets包含如下資料集,可以下載和載入

  • MNIST
  • COCO(用於影象標註和目標檢測)(Captioning and Detection)
  • LSUN Classification
  • ImageFolder
  • Imagenet-12
  • CIFAR10 and CIFAR100
  • STL10
  • SVHN 
  • PhotoTour
from torchvision import datasets
​​​​​​​train_dataset = datasets.MNIST(root='./data', train=True,
                               transform=transforms.ToTensor(),
                               download=True)

此操作便可下載MNIST的訓練資料集,

資料集有 API: - __getitem__ - __len__ 他們都是 torch.utils.data.Dataset的子類。因此, 他們可以使用torch.utils.data.DataLoader裡的多執行緒 (python multithreading) 。

例如:

torch.utils.data.DataLoader(coco_cap, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)

 

2、utils

utils主要提供了兩個方法:

  • make_grid  
  • save_image
torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False) 
將輸入的minbatch_size圖片轉換成一張大的網格圖片

torchvision.utils.save_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale_each=False) 
將輸入的圖片儲存,如果輸入的是minbatch_size圖片,先make_grid轉換成大的網格圖再儲存

3、transforms

了方便進行資料的操作,pytorch團隊提供了一個torchvision.transforms包,我們可以用transforms進行以下操作:

PIL.Image/numpy.ndarray與Tensor的相互轉化;

歸一化;

對PIL.Image進行裁剪、縮放等操作。

通常,在使用torchvision.transforms,我們通常使用transforms.Compose將transforms組合在一起。

transform = transforms.Compose([
    transforms.RandomSizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ],
                          std = [ 0.229, 0.224, 0.225 ]),
])
  • transforms.ToTensor() :把shape=(H x W x C)的畫素值範圍為[0, 255]的PIL.Image或者numpy.ndarray轉換成shape=(C x H x W)的畫素值範圍為[0.0, 1.0]的torch.FloatTensor。
  • transforms.Normalize(mean,std) : 此轉換類作用於torch.tensor,給定均值(R, G, B)和標準差(R, G, B),用公式channel = (channel - mean) / std進行規範化。

4、models

torchvision.models包含下列常用網路結構,並且提供了預訓練模型,可以通過簡單呼叫來讀取網路結構和預訓練模型

  • AlexNet: AlexNet variant from the “One weird trick” paper.
  • VGG: VGG-11, VGG-13, VGG-16, VGG-19 (with and without batch normalization)
  • ResNet: ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152
  • SqueezeNet: SqueezeNet 1.0, and SqueezeNet 1.1
import torchvision
model = torchvision.models.resnet50(pretrained=True)

這樣就匯入了resnet50的預訓練模型了,

如果只需要網路結構,不需要用預訓練模型的引數來初始化

model = torchvision.models.resnet50(pretrained=False)

如果要匯入densenet模型也是同樣的道理,比如匯入densenet169,且不需要是預訓練的模型,

model = torchvision.models.densenet169(pretrained=False)

由於預訓練引數預設是假,所以等價於

model = torchvision.models.densenet169()