1. 程式人生 > >PyTorch ResNet 使用與原始碼解析

PyTorch ResNet 使用與原始碼解析

> 本章程式碼:[https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson8/resnet_inference.py](https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson8/resnet_inference.py) 這篇文章首先會簡單介紹一下 `PyTorch` 中提供的影象分類的網路,然後重點介紹 `ResNet` 的使用,以及 `ResNet` 的原始碼。 # 模型概覽 在`torchvision.model`中,有很多封裝好的模型。
可以分類 3 類: - 經典網路 - alexnet - vgg - resnet - inception - densenet - googlenet - 輕量化網路 - squeezenet - mobilenet - shufflenetv2 - 自動神經結構搜尋方法的網路 - mnasnet # ResNet18 使用 以 `ResNet 18` 為例。 首先載入訓練好的模型引數: ``` resnet18 = models.resnet18() # 修改全連線層的輸出 num_ftrs = resnet18.fc.in_features resnet18.fc = nn.Linear(num_ftrs, 2) # 載入模型引數 checkpoint = torch.load(m_path) resnet18.load_state_dict(checkpoint['model_state_dict']) ``` 然後比較重要的是把模型放到 GPU 上,並且轉換到`eval`模式: ``` resnet18.to(device) resnet18.eval() ``` 在 inference 時,主要流程如下: - 程式碼要放在`with torch.no_grad():`下。`torch.no_grad()`會關閉反向傳播,可以減少記憶體、加快速度。 - 根據路徑讀取圖片,把圖片轉換為 tensor,然後使用`unsqueeze_(0)`方法把形狀擴大為 $B \times C \times H \times W$,再把 tensor 放到 GPU 上 。 - 模型的輸出資料`outputs`的形狀是 $1 \times 2$,表示 `batch_size` 為 1,分類數量為 2。`torch.max(outputs,0)`是返回`outputs`中**每一列**最大的元素和索引,`torch.max(outputs,1)`是返回`outputs`中**每一行**最大的元素和索引。 這裡使用`_, pred_int = torch.max(outputs.data, 1)`返回最大元素的索引,然後根據索引獲得 label:`pred_str = classes[int(pred_int)]`。 關鍵程式碼如下: ``` with torch.no_grad(): for idx, img_name in enumerate(img_names): path_img = os.path.join(img_dir, img_name) # step 1/4 : path --> img img_rgb = Image.open(path_img).convert('RGB') # step 2/4 : img --> tensor img_tensor = img_transform(img_rgb, inference_transform) img_tensor.unsqueeze_(0) img_tensor = img_tensor.to(device) # step 3/4 : tensor --> vector outputs = resnet18(img_tensor) # step 4/4 : get label _, pred_int = torch.max(outputs.data, 1) pred_str = classes[int(pred_int)] ``` 全部程式碼如下所示: ``` import os import time import torch.nn as nn import torch import torchvision.transforms as transforms from PIL import Image from matplotlib import pyplot as plt import torchvision.models as models import enviroments BASE_DIR = os.path.dirname(os.path.abspath(__file__)) # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cpu") # config vis = True # vis = False vis_row = 4 norm_mean = [0.485, 0.456, 0.406] norm_std = [0.229, 0.224, 0.225] inference_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std), ]) classes = ["ants", "bees"] def img_transform(img_rgb, transform=None): """ 將資料轉換為模型讀取的形式 :param img_rgb: PIL Image :param transform: torchvision.transform :return: tensor """ if transform is None: raise ValueError("找不到transform!必須有transform對img進行處理") img_t = transform(img_rgb) return img_t def get_img_name(img_dir, format="jpg"): """ 獲取資料夾下format格式的檔名 :param img_dir: str :param format: str :return: list """ file_names = os.listdir(img_dir) # 使用 list(filter(lambda())) 篩選出 jpg 字尾的檔案 img_names = list(filter(lambda x: x.endswith(format), file_names)) if len(img_names) < 1: raise ValueError("{}下找不到{}格式資料".format(img_dir, format)) return img_names def get_model(m_path, vis_model=False): resnet18 = models.resnet18() # 修改全連線層的輸出 num_ftrs = resnet18.fc.in_features resnet18.fc = nn.Linear(num_ftrs, 2) # 載入模型引數 checkpoint = torch.load(m_path) resnet18.load_state_dict(checkpoint['model_state_dict']) if vis_model: from torchsummary import summary summary(resnet18, input_size=(3, 224, 224), device="cpu") return resnet18 if __name__ == "__main__": img_dir = os.path.join(enviroments.hymenoptera_data_dir,"val/bees") model_path = "./checkpoint_14_epoch.pkl" time_total = 0 img_list, img_pred = list(), list() # 1. data img_names = get_img_name(img_dir) num_img = len(img_names) # 2. model resnet18 = get_model(model_path, True) resnet18.to(device) resnet18.eval() with torch.no_grad(): for idx, img_name in enumerate(img_names): path_img = os.path.join(img_dir, img_name) # step 1/4 : path --> img img_rgb = Image.open(path_img).convert('RGB') # step 2/4 : img --> tensor img_tensor = img_transform(img_rgb, inference_transform) img_tensor.unsqueeze_(0) img_tensor = img_tensor.to(device) # step 3/4 : tensor --> vector time_tic = time.time() outputs = resnet18(img_tensor) time_toc = time.time() # step 4/4 : visualization _, pred_int = torch.max(outputs.data, 1) pred_str = classes[int(pred_int)] if vis: img_list.append(img_rgb) img_pred.append(pred_str) if (idx+1) % (vis_row*vis_row) == 0 or num_img == idx+1: for i in range(len(img_list)): plt.subplot(vis_row, vis_row, i+1).imshow(img_list[i]) plt.title("predict:{}".format(img_pred[i])) plt.show() plt.close() img_list, img_pred = list(), list() time_s = time_toc-time_tic time_total += time_s print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s)) print("\ndevice:{} total time:{:.1f}s mean:{:.3f}s". format(device, time_total, time_total/num_img)) if torch.cuda.is_available(): print("GPU name:{}".format(torch.cuda.get_device_name())) ``` 總結一下 inference 階段需要注意的事項: - 確保 model 處於 eval 狀態,而非 trainning 狀態 - 設定 torch.no_grad(),減少記憶體消耗,加快運算速度 - 資料預處理需要保持一致,比如 RGB 或者 rBGR # 殘差連線 以 ResNet 為例:

一個殘差塊有2條路徑 $F(x)$ 和 $x$,$F(x)$ 路徑擬合殘差,不妨稱之為殘差路徑;$x$ 路徑為`identity mapping`恆等對映,稱之為`shortcut`。圖中的⊕為`element-wise addition`,要求參與運算的 $F(x)$ 和 $x$ 的尺寸要相同。 `shortcut` 路徑大致可以分成 2 種,取決於殘差路徑是否改變了`feature map`數量和尺寸。 - 一種是將輸入`x`原封不動地輸出。 - 另一種則需要經過 $1×1$ 卷積來升維或者降取樣,主要作用是將輸出與 $F(x)$ 路徑的輸出保持`shape`一致,對網路效能的提升並不明顯。 兩種結構如下圖所示:

`ResNet` 中,使用了上面 2 種 `shortcut`。 # 網路結構 ResNet 有很多變種,包括 `ResNet 18`、`ResNet 34`、`ResNet 50`、`ResNet 101`、`ResNet 152`,網路結構對比如下:
`ResNet` 的各個變種,資料處理大致流程如下: - 輸入的圖片形狀是 $3 \times 224 \times 224$。 - 圖片經過 `conv1` 層,輸出圖片大小為 $ 64 \times 112 \times 112$。 - 圖片經過 `max pool` 層,輸出圖片大小為 $ 64 \times 56 \times 56 $。 - 圖片經過 `conv2` 層,輸出圖片大小為 $ 64 \times 56 \times 56$。**(注意,圖片經過這個 `layer`, 大小是不變的)** - 圖片經過 `conv3` 層,輸出圖片大小為 $ 128 \times 28 \times 28$。 - 圖片經過 `conv4` 層,輸出圖片大小為 $ 256 \times 14 \times 14$。 - 圖片經過 `conv5` 層,輸出圖片大小為 $ 512 \times 7 \times 7$。 - 圖片經過 `avg pool` 層,輸出大小為 $ 512 \times 1 \times 1$。 - 圖片經過 `fc` 層,輸出維度為 $ num_classes$,表示每個分類的 `logits`。 下面,我們稱每個 `conv` 層為一個 `layer`(第一個 `conv` 層就是一個卷積層,因此第一個 `conv` 層除外)。 其中 `ResNet 18`、`ResNet 34` 的每個 `layer` 由多個 `BasicBlock` 組成,只是每個 `layer` 裡堆疊的 `BasicBlock` 數量不一樣。 而 `ResNet 50`、`ResNet 101`、`ResNet 152` 的每個 `layer` 由多個 `Bottleneck` 組成,只是每個 `layer` 裡堆疊的 `Bottleneck` 數量不一樣。 # 原始碼分析 我們來看看各個 `ResNet` 的原始碼,首先從建構函式開始。 ## 建構函式 ### ResNet 18 `resnet18` 的建構函式如下。 `[2, 2, 2, 2]` 表示有 4 個 `layer`,每個 layer 中有 2 個 `BasicBlock`。 `conv1`為 1 層,`conv2`、`conv3`、`conv4`、`conv5`均為 4 層(每個 `layer` 有 2 個 `BasicBlock`,每個 `BasicBlock` 有 2 個卷積層),總共為 16 層,最後一層全連線層,$ 總層數 = 1+ 4 \times 4 + 1 = 18$,依此類推。 ``` def resnet18(pretrained=False, progress=True, **kwargs): r"""ResNet-18 model from `"Deep Residual Learning for Image Recog