pytorch torchvision study
阿新 • • 發佈:2018-12-13
torchvision 模組
torchvision是獨立於pytorch的關於影象操作的工具庫,主要包含了如下4個子模組或包:
- datasets
- utils
- transforms
- models
1、datasets
torchvision.datasets包含如下資料集,可以下載和載入
- MNIST
- COCO(用於影象標註和目標檢測)(Captioning and Detection)
- LSUN Classification
- ImageFolder
- Imagenet-12
- CIFAR10 and CIFAR100
- STL10
- SVHN
- PhotoTour
from torchvision import datasets
train_dataset = datasets.MNIST(root='./data', train=True,
transform=transforms.ToTensor(),
download=True)
此操作便可下載MNIST的訓練資料集,
資料集有 API: - __getitem__ - __len__ 他們都是 torch.utils.data.Dataset的子類。因此, 他們可以使用torch.utils.data.DataLoader裡的多執行緒 (python multithreading) 。
例如:
torch.utils.data.DataLoader(coco_cap, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)
2、utils
utils主要提供了兩個方法:
- make_grid
- save_image
torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False) 將輸入的minbatch_size圖片轉換成一張大的網格圖片 torchvision.utils.save_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale_each=False) 將輸入的圖片儲存,如果輸入的是minbatch_size圖片,先make_grid轉換成大的網格圖再儲存
3、transforms
了方便進行資料的操作,pytorch團隊提供了一個torchvision.transforms包,我們可以用transforms進行以下操作:
PIL.Image/numpy.ndarray與Tensor的相互轉化;
歸一化;
對PIL.Image進行裁剪、縮放等操作。
通常,在使用torchvision.transforms,我們通常使用transforms.Compose將transforms組合在一起。
transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ],
std = [ 0.229, 0.224, 0.225 ]),
])
- transforms.ToTensor() :把shape=(H x W x C)的畫素值範圍為[0, 255]的PIL.Image或者numpy.ndarray轉換成shape=(C x H x W)的畫素值範圍為[0.0, 1.0]的torch.FloatTensor。
- transforms.Normalize(mean,std) : 此轉換類作用於torch.tensor,給定均值(R, G, B)和標準差(R, G, B),用公式channel = (channel - mean) / std進行規範化。
4、models
torchvision.models包含下列常用網路結構,並且提供了預訓練模型,可以通過簡單呼叫來讀取網路結構和預訓練模型
- AlexNet: AlexNet variant from the “One weird trick” paper.
- VGG: VGG-11, VGG-13, VGG-16, VGG-19 (with and without batch normalization)
- ResNet: ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152
- SqueezeNet: SqueezeNet 1.0, and SqueezeNet 1.1
import torchvision
model = torchvision.models.resnet50(pretrained=True)
這樣就匯入了resnet50的預訓練模型了,
如果只需要網路結構,不需要用預訓練模型的引數來初始化
model = torchvision.models.resnet50(pretrained=False)
如果要匯入densenet模型也是同樣的道理,比如匯入densenet169,且不需要是預訓練的模型,
model = torchvision.models.densenet169(pretrained=False)
由於預訓練引數預設是假,所以等價於
model = torchvision.models.densenet169()