[PyTorch 學習筆記] 2.2 圖片預處理 transforms 模組機制
阿新 • • 發佈:2020-08-27
# PyTorch 的資料增強
我們在安裝`PyTorch`時,還安裝了`torchvision`,這是一個計算機視覺工具包。有 3 個主要的模組:
- `torchvision.transforms`: 裡面包括常用的影象預處理方法
- `torchvision.datasets`: 裡面包括常用資料集如 mnist、CIFAR-10、Image-Net 等
- `torchvision.models`: 裡面包括常用的預訓練好的模型,如 AlexNet、VGG、ResNet、GoogleNet 等
深度學習模型是由資料驅動的,資料的數量和分佈對模型訓練的結果起到決定性作用。所以我們需要對資料進行預處理和資料增強。下面是用資料增強,從一張圖片經過各種變換生成 64 張圖片,增加了資料的多樣性,這可以提高模型的泛化能力。
常用的影象預處理方法有: - 資料中心化 - 資料標準化 - 縮放 - 裁剪 - 旋轉 - 翻轉 - 填充 - 噪聲新增 - 灰度變換 - 線性變換 - 仿射變換 - 亮度、飽和度以及對比度變換。 在[人民幣圖片二分類實驗]()中,我們對資料進行了一定的增強。 ``` # 設定訓練集的資料增強和轉化 train_transform = transforms.Compose([ transforms.Resize((32, 32)),# 縮放 transforms.RandomCrop(32, padding=4), #裁剪 transforms.ToTensor(), # 轉為張量,同時歸一化 transforms.Normalize(norm_mean, norm_std),# 標準化 ]) # 設定驗證集的資料增強和轉化,不需要 RandomCrop valid_transform = transforms.Compose([ transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std), ]) ``` 當我們需要多個`transforms`操作時,需要作為一個`list`放在`transforms.Compose`中。需要注意的是`transforms.ToTensor()`是把圖片轉換為張量,同時進行歸一化操作,把每個通道 0~255 的值歸一化為 0~1。在驗證集的資料增強中,不再需要`transforms.RandomCrop()`操作。然後把這兩個`transform`操作作為引數傳給`Dataset`,在`Dataset`的`__getitem__()`方法中做影象增強。 ``` def __getitem__(self, index): # 通過 index 讀取樣本 path_img, label = self.data_info[index] # 注意這裡需要 convert('RGB') img = Image.open(path_img).convert('RGB') # 0~255 if self.transform is not None: img = self.transform(img) # 在這裡做transform,轉為tensor等等 # 返回是樣本和標籤 return img, label ``` 其中`self.transform(img)`會呼叫`Compose`的`__call__()`函式: ``` def __call__(self, img): for t in self.transforms: img = t(img) return img ``` 可以看到,這裡是遍歷`transforms`中的函式,按順序應用到 img 中。 # transforms.Normalize ``` torchvision.transforms.Normalize(mean, std, inplace=False) ``` 功能:逐 channel 地對影象進行標準化 output = ( input - mean ) / std - mean: 各通道的均值 - std: 各通道的標準差 - inplace: 是否原地操作 該方法呼叫的是`F.normalize(tensor, self.mean, self.std, self.inplace)` 而``F.normalize()`方法如下: ``` def normalize(tensor, mean, std, inplace=False): if not _is_tensor_image(tensor): raise TypeError('tensor is not a torch image.') if not inplace: tensor = tensor.clone() dtype = tensor.dtype mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) std = torch.as_tensor(std, dtype=dtype, device=tensor.device) tensor.sub_(mean[:, None, None]).div_(std[:, None, None]) return tensor ``` 首先判斷是否為 tensor,如果不是 tensor 則丟擲異常。然後根據`inplace`是否為 true 進行 clone,接著把 mean 和 std 都轉換為 tensor (原本是 list),最後減去均值除以方差:`tensor.sub_(mean[:, None, None]).div_(std[:, None, None])` 對資料進行均值為 0,標準差為 1 的標準化,可以加快模型的收斂。 在[邏輯迴歸的實驗]()中,我們的資料生成程式碼如下: ``` sample_nums = 100 mean_value = 1.7 bias = 1 n_data = torch.ones(sample_nums, 2) # 使用正態分佈隨機生成樣本,均值為張量,方差為標量 x0 = torch.normal(mean_value * n_data, 1) + bias # 類別0 資料 shape=(100, 2) # 生成對應標籤 y0 = torch.zeros(sample_nums) # 類別0 標籤 shape=(100, 1) # 使用正態分佈隨機生成樣本,均值為張量,方差為標量 x1 = torch.normal(-mean_value * n_data, 1) + bias # 類別1 資料 shape=(100, 2) # 生成對應標籤 y1 = torch.ones(sample_nums) # 類別1 標籤 shape=(100, 1) train_x = torch.cat((x0, x1), 0) train_y = torch.cat((y0, y1), 0) ``` 生成的資料均值是`mean_value+bias=1.7+1=2.7`,比較靠近 0 均值。模型在 380 次迭代時,準確率就超過了 99.5%。 如果我們把 bias 修改為 5。那麼資料的均值變成了 6.7,偏離 0 均值較遠,這時模型訓練需要更多次才能收斂 (準確率達到 99.5%)。
**參考資料** - [深度之眼 PyTorch 框架班](https://ai.deepshare.net/detail/p_5df0ad9a09d37_qYqVmt85/6)
如果你覺得這篇文章對你有幫助,不妨點個贊,讓我有更多動力寫出好文章
常用的影象預處理方法有: - 資料中心化 - 資料標準化 - 縮放 - 裁剪 - 旋轉 - 翻轉 - 填充 - 噪聲新增 - 灰度變換 - 線性變換 - 仿射變換 - 亮度、飽和度以及對比度變換。 在[人民幣圖片二分類實驗]()中,我們對資料進行了一定的增強。 ``` # 設定訓練集的資料增強和轉化 train_transform = transforms.Compose([ transforms.Resize((32, 32)),# 縮放 transforms.RandomCrop(32, padding=4), #裁剪 transforms.ToTensor(), # 轉為張量,同時歸一化 transforms.Normalize(norm_mean, norm_std),# 標準化 ]) # 設定驗證集的資料增強和轉化,不需要 RandomCrop valid_transform = transforms.Compose([ transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std), ]) ``` 當我們需要多個`transforms`操作時,需要作為一個`list`放在`transforms.Compose`中。需要注意的是`transforms.ToTensor()`是把圖片轉換為張量,同時進行歸一化操作,把每個通道 0~255 的值歸一化為 0~1。在驗證集的資料增強中,不再需要`transforms.RandomCrop()`操作。然後把這兩個`transform`操作作為引數傳給`Dataset`,在`Dataset`的`__getitem__()`方法中做影象增強。 ``` def __getitem__(self, index): # 通過 index 讀取樣本 path_img, label = self.data_info[index] # 注意這裡需要 convert('RGB') img = Image.open(path_img).convert('RGB') # 0~255 if self.transform is not None: img = self.transform(img) # 在這裡做transform,轉為tensor等等 # 返回是樣本和標籤 return img, label ``` 其中`self.transform(img)`會呼叫`Compose`的`__call__()`函式: ``` def __call__(self, img): for t in self.transforms: img = t(img) return img ``` 可以看到,這裡是遍歷`transforms`中的函式,按順序應用到 img 中。 # transforms.Normalize ``` torchvision.transforms.Normalize(mean, std, inplace=False) ``` 功能:逐 channel 地對影象進行標準化 output = ( input - mean ) / std - mean: 各通道的均值 - std: 各通道的標準差 - inplace: 是否原地操作 該方法呼叫的是`F.normalize(tensor, self.mean, self.std, self.inplace)` 而``F.normalize()`方法如下: ``` def normalize(tensor, mean, std, inplace=False): if not _is_tensor_image(tensor): raise TypeError('tensor is not a torch image.') if not inplace: tensor = tensor.clone() dtype = tensor.dtype mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) std = torch.as_tensor(std, dtype=dtype, device=tensor.device) tensor.sub_(mean[:, None, None]).div_(std[:, None, None]) return tensor ``` 首先判斷是否為 tensor,如果不是 tensor 則丟擲異常。然後根據`inplace`是否為 true 進行 clone,接著把 mean 和 std 都轉換為 tensor (原本是 list),最後減去均值除以方差:`tensor.sub_(mean[:, None, None]).div_(std[:, None, None])` 對資料進行均值為 0,標準差為 1 的標準化,可以加快模型的收斂。 在[邏輯迴歸的實驗]()中,我們的資料生成程式碼如下: ``` sample_nums = 100 mean_value = 1.7 bias = 1 n_data = torch.ones(sample_nums, 2) # 使用正態分佈隨機生成樣本,均值為張量,方差為標量 x0 = torch.normal(mean_value * n_data, 1) + bias # 類別0 資料 shape=(100, 2) # 生成對應標籤 y0 = torch.zeros(sample_nums) # 類別0 標籤 shape=(100, 1) # 使用正態分佈隨機生成樣本,均值為張量,方差為標量 x1 = torch.normal(-mean_value * n_data, 1) + bias # 類別1 資料 shape=(100, 2) # 生成對應標籤 y1 = torch.ones(sample_nums) # 類別1 標籤 shape=(100, 1) train_x = torch.cat((x0, x1), 0) train_y = torch.cat((y0, y1), 0) ``` 生成的資料均值是`mean_value+bias=1.7+1=2.7`,比較靠近 0 均值。模型在 380 次迭代時,準確率就超過了 99.5%。 如果我們把 bias 修改為 5。那麼資料的均值變成了 6.7,偏離 0 均值較遠,這時模型訓練需要更多次才能收斂 (準確率達到 99.5%)。
**參考資料** - [深度之眼 PyTorch 框架班](https://ai.deepshare.net/detail/p_5df0ad9a09d37_qYqVmt85/6)
如果你覺得這篇文章對你有幫助,不妨點個贊,讓我有更多動力寫出好文章