1. 程式人生 > 其它 >殘差網路resnet理解與pytorch程式碼實現

殘差網路resnet理解與pytorch程式碼實現

寫在前面

​ 深度殘差網路(Deep residual network, ResNet)自提出起,一次次重新整理CNN模型在ImageNet中的成績,解決了CNN模型難訓練的問題。何凱明大神的工作令人佩服,模型簡單有效,思想超凡脫俗。

​ 直觀上,提到深度學習,我們第一反應是模型要足夠“深”,才可以提升模型的準確率。但事實往往不盡如人意,先看一個ResNet論文中提到的實驗,當用一個平原網路(plain network)構建很深層次的網路時,56層的網路的表現相比於20層的網路反而更差了。說明網路隨著深度的加深,會更加難以訓練。

​ 圖一:模型退化問題

​ 若模型隨著網路深度的增加,準確率先上升,然後達到飽和,深度增加準確率下降。那麼如果在模型達到飽和時,後面接上幾個恆等變換層,這樣可以保證誤差不會增加,resnet便是這種思想來解決網路退化問題。

第一部分

模型

假設網路的輸入是x, 期望輸出為H(x),我們轉化一下思路,把網路要學到的H(x)轉化為期望輸出H(x)與輸出x之間的差值F(x) = H(x) - x。當殘差接近為0時, 相當於網路在此層僅僅做了恆等變換,而不會使網路的效果下降。

​ 圖二:殘差結構

殘差為什麼容易學習?

此處參考一位知乎大佬的分析(原文在文末有連結),因為網路要學習的殘差項通常比較小:

其中 和 分別表示的是第 個殘差單元的輸入和輸出,注意每個殘差單元一般包含多層結構。 是殘差函式,表示學習到的殘差,而 表示恆等對映, 是ReLU啟用函式。基於上式,我們求得從淺層 到深層 的學習特徵為:

利用鏈式規則,可以求得反向過程的梯度:

式子的第一個因子 表示的損失函式到達 的梯度,小括號中的1表明短路機制可以無損地傳播梯度,而另外一項殘差梯度則需要經過帶有weights的層,梯度不是直接傳遞過來的。殘差梯度不會那麼巧全為-1,而且就算其比較小,有1的存在也不會導致梯度消失。所以殘差學習會更容易。要注意上面的推導並不是嚴格的證明。

深度殘差網路結構如下:

第二部分

pytorch程式碼實現

# -*- coding:utf-8 -*-
# handwritten digits recognition
# Data: MINIST
# model: resnet
# date: 2021.10.8 14:18

import math
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.utils.data as Data
import torch.optim as optim
import pandas as pd
import matplotlib.pyplot as plt

train_curve = []
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# param
batch_size = 100
n_class = 10
padding_size = 15
epoches = 10

train_dataset = torchvision.datasets.MNIST('./data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST('./data/', train=False, transform=transforms.ToTensor(), download=False)
train = Data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=5)
test = Data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=5)

def gelu(x):
  "Implementation of the gelu activation function by Hugging Face"
  return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class ResBlock(nn.Module):
  # 殘差塊
  def __init__(self, in_size, out_size1, out_size2):
    super(ResBlock, self).__init__()
    self.conv1 = nn.Conv2d(
        in_channels = in_size,
        out_channels = out_size1,
        kernel_size = 3,
        stride = 2,
        padding = padding_size
    )
    self.conv2 = nn.Conv2d(
        in_channels = out_size1,
        out_channels = out_size2,
        kernel_size = 3,
        stride = 2,
        padding = padding_size
    )
    self.batchnorm1 = nn.BatchNorm2d(out_size1)
    self.batchnorm2 = nn.BatchNorm2d(out_size2)
  
  def conv(self, x):
    # gelu效果比relu好呀哈哈
    x = gelu(self.batchnorm1(self.conv1(x)))
    x = gelu(self.batchnorm2(self.conv2(x)))
    return x
  
  def forward(self, x):
    # 殘差連線
    return x + self.conv(x)

# resnet
class Resnet(nn.Module):
  def __init__(self, n_class = n_class):
    super(Resnet, self).__init__()
    self.res1 = ResBlock(1, 8, 16)
    self.res2 = ResBlock(16, 32, 16)
    self.conv = nn.Conv2d(
        in_channels = 16,
        out_channels = n_class,
        kernel_size = 3,
        stride = 2,
        padding = padding_size
    )
    self.batchnorm = nn.BatchNorm2d(n_class)
    self.max_pooling = nn.AdaptiveAvgPool2d(1)

  def forward(self, x):
    # x: [bs, 1, h, w]
    # x = x.view(-1, 1, 28, 28)
    x = self.res1(x)
    x = self.res2(x)
    x = self.max_pooling(self.batchnorm(self.conv(x)))

    return x.view(x.size(0), -1)

resnet = Resnet().to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=resnet.parameters(), lr=1e-2, momentum=0.9)

# train
total_step = len(train)
sum_loss = 0
for epoch in range(epoches):
  for i, (images, targets) in enumerate(train):
    optimizer.zero_grad()
    images = images.to(device)
    targets = targets.to(device)
    preds = resnet(images)
    
    loss = loss_fn(preds, targets)
    sum_loss += loss.item()
    loss.backward()
    optimizer.step()
    if (i+1)%100==0:
      print('[{}|{}] step:{}/{} loss:{:.4f}'.format(epoch+1, epoches, i+1, total_step, loss.item()))
  train_curve.append(sum_loss)
  sum_loss = 0

# test
resnet.eval()
with torch.no_grad():
  correct = 0
  total = 0
  for images, labels in test:
    images = images.to(device)
    labels = labels.to(device)
    outputs = resnet(images)
    _, maxIndexes = torch.max(outputs, dim=1)
    correct += (maxIndexes==labels).sum().item()
    total += labels.size(0)
  
  print('in 1w test_data correct rate = {:.4f}'.format((correct/total)*100))

pd.DataFrame(train_curve).plot() # loss曲線

測試了1萬條測試集樣本結果:

程式碼連結:

jupyter版本:https://github.com/PouringRain/blog_code/blob/main/deeplearning/resnet.ipynb

py版本:https://github.com/PouringRain/blog_code/blob/main/deeplearning/resnet.py

喜歡的話,給萌新的github倉庫一顆小星星哦……^ _^

參考資料:

https://zhuanlan.zhihu.com/p/31852747

https://zhuanlan.zhihu.com/p/80226180