1. 程式人生 > 實用技巧 >神經網路量化入門--量化感知訓練

神經網路量化入門--量化感知訓練

上一篇文章介紹了後訓練量化的基本流程,並用 pytorch 演示了最簡單的後訓練量化演算法。

後訓練量化雖然操作簡單,並且大部分推理框架都提供了這類離線量化演算法 (如 tensorrtncnnSNPE 等),但有時候這種方法並不能保證足夠的精度,因此本文介紹另一種比後訓練量化更有效地量化方法——量化感知訓練。

量化感知訓練,顧名思義,就是在量化的過程中,對網路進行訓練,從而讓網路引數能更好地適應量化帶來的資訊損失。這種方式更加靈活,因此準確性普遍比後訓練量化要高。當然,它的一大缺點是操作起來不方便,這一點後面會詳談。

同樣地,這篇文章會講解最簡單的量化訓練演算法流程,並沿用之前文章的程式碼框架,用 pytorch 從零構建量化訓練演算法的流程。

量化訓練的困難

要理解量化訓練的困難之處,需要了解量化訓練相比普通的全精度訓練有什麼區別。為了看清這一點,我們回顧一下上一篇文章中卷積量化的程式碼:

class QConv2d(QModule):

    def forward(self, x):
        if hasattr(self, 'qi'):
            self.qi.update(x)

        self.qw.update(self.conv_module.weight.data)

        self.conv_module.weight.data = self.qw.quantize_tensor(self.conv_module.weight.data)
        self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)

        x = self.conv_module(x)

        if hasattr(self, 'qo'):
            self.qo.update(x)

        return x

這裡面區別於全精度模型的地方在於,我們在卷積運算前先對 weight 做了一遍量化,然後又再反量化成 float。這一步在後訓練量化中其實可有可無,但量化感知訓練中卻是需要的「之前為了程式碼上的一致,我提前把這一步加上去了」

那這一步有什麼特別嗎?可以回顧一下量化的具體操作:

def quantize_tensor(x, scale, zero_point, num_bits=8, signed=False):
    if signed:
        qmin = - 2. ** (num_bits - 1)
        qmax = 2. ** (num_bits - 1) - 1
    else:
        qmin = 0.
        qmax = 2.**num_bits - 1.
 
    q_x = zero_point + x / scale
    q_x.clamp_(qmin, qmax).round_()
    
    return q_x.float()

這裡面有個 round 函式,而這個函式是沒法訓練的。它的函式影象如下:

這個函式幾乎每一處的梯度都是 0,如果網路中存在該函式,會導致反向傳播的梯度也變成 0。

可以看個例子:

conv = nn.Conv2d(3, 1, 3, 1)

def quantize(weight):
    w = weight.round()
    return w

class QuantConv(nn.Module):

    def __init__(self, conv_module):
        super(QuantConv, self).__init__()
        self.conv_module = conv_module

    def forward(self, x):
        return F.conv2d(x, quantize(self.conv_module.weight), self.conv_module.bias, 3, 1)


x = torch.randn((1, 3, 4, 4))

quantconv = QuantConv(conv)

a = quantconv(x).sum().backward()

print(quantconv.conv_module.weight.grad)

這個例子裡面,我將權重 weight 做了一遍 round 操作後,再進行卷積運算,但返回的梯度全是 0:

tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])

換言之,這個函式是沒法學習的,從而導致量化訓練進行不下去。

Straight Through Estimator

那要怎麼解決這個問題呢?

一個很容易想到的方法是,直接跳過偽量化的過程,避開 round。直接把卷積層的梯度回傳到偽量化之前的 weight 上。這樣一來,由於卷積中用的 weight 是經過偽量化操作的,因此可以模擬量化誤差,把這些誤差的梯度回傳到原來的 weight,又可以更新權重,使其適應量化產生的誤差,量化訓練就可以正常進行下去了。

這個方法就叫做 Straight Through Estimator(STE)。

pytorch實現

本文的相關程式碼都可以在 https://github.com/Jermmy/pytorch-quantization-demo 上找到。

偽量化節點實現

上面講完量化訓練最基本的思路,下面我們繼續沿用前文的程式碼框架,加入量化訓練的部分。

首先,我們需要修改偽量化的寫法,之前的程式碼是直接對 weight 的數值做了偽量化:

self.conv_module.weight.data = self.qw.quantize_tensor(self.conv_module.weight.data)
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)

這在後訓練量化裡面沒有問題,但在 pytorch 中,這種寫法是沒法回傳梯度的,因此量化訓練裡面,需要重新修改偽量化節點的寫法。

另外,STE 需要我們重新定義反向傳播的梯度。因此,需要藉助 pytorch 中的 Function 介面來重新定義偽量化的過程:

from torch.autograd import Function

class FakeQuantize(Function):

    @staticmethod
    def forward(ctx, x, qparam):
        x = qparam.quantize_tensor(x)
        x = qparam.dequantize_tensor(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

這裡面的 forward 函式,和之前的寫法是類似的,就是把數值量化之後再反量化回去。但在 backward 中,我們直接返回了後一層傳過來的梯度 grad_output,相當於直接跳過了偽量化這一層的梯度計算,讓梯度直接流到前一層 (Straight Through)。

pytorch 定義 backward 函式的返回變數需要與 forward 的輸入引數對應,分別表示對應輸入的梯度。由於 qparam 只是統計 min、max,不需要梯度,因此返回給它的梯度是 None

量化卷積程式碼

量化卷積層的程式碼除了 forward 中需要修改偽量化節點外,其餘的和之前的文章基本一致:

class QConv2d(QModule):

    def forward(self, x):
        if hasattr(self, 'qi'):
            self.qi.update(x)
            x = FakeQuantize.apply(x, self.qi)

        self.qw.update(self.conv_module.weight.data)

        x = F.conv2d(x, FakeQuantize.apply(self.conv_module.weight, self.qw),
                     self.conv_module.bias, 
                     stride=self.conv_module.stride,
                     padding=self.conv_module.padding, dilation=self.conv_module.dilation, 
                     groups=self.conv_module.groups)

        if hasattr(self, 'qo'):
            self.qo.update(x)
            x = FakeQuantize.apply(x, self.qo)

        return x

由於我們需要先對 weight 做一些偽量化的操作,根據 pytorch 中的規則,在做卷積運算的時候,不能像之前一樣用 x = self.conv_module(x) 的寫法,而要用 F.conv2d 來呼叫。另外,之前的程式碼中輸入輸出沒有加偽量化節點,這在後訓練量化中沒有問題,但在量化訓練中最好加上,方便網路更好地感知量化帶來的損失。

由於上一篇文章中做量化推理的時候,我發現精度損失不算太重,3 個 bit 的情況下,準確率依然能達到 96%。為了更好地體會量化訓練帶來的收益,我們把量化推理的程式碼再細緻一點,加大量化損失:

class QConv2d(QModule):

    def quantize_inference(self, x):
        x = x - self.qi.zero_point
        x = self.conv_module(x)
        x = self.M * x
        x.round_()      # 多加一個round操作
        x = x + self.qo.zero_point        
        x.clamp_(0., 2.**self.num_bits-1.).round_()
        return x

相比之前的程式碼,其實就是多加了個 round,讓量化推理更接近真實的推理過程。

量化訓練的收益

這裡仍然沿用之前文章裡的小網路,在 mnist 上測試分類準確率。由於量化推理有修改,為了方便對比,我重新跑了一遍後訓練量化的準確率:

bit 1 2 3 4 5 6 7 8
accuracy 10% 47% 83% 96% 98% 98% 98% 98%

接下來,測試一下量化訓練的效果,下面是 bit=3 時輸出的 log:

Test set: Full Model Accuracy: 98%

Quantization bit: 3
Quantize Aware Training Epoch: 1 [3200/60000]   Loss: 0.087867
Quantize Aware Training Epoch: 1 [6400/60000]   Loss: 0.219696
Quantize Aware Training Epoch: 1 [9600/60000]   Loss: 0.283124
Quantize Aware Training Epoch: 1 [12800/60000]  Loss: 0.172751
Quantize Aware Training Epoch: 1 [16000/60000]  Loss: 0.315173
Quantize Aware Training Epoch: 1 [19200/60000]  Loss: 0.302261
Quantize Aware Training Epoch: 1 [22400/60000]  Loss: 0.218039
Quantize Aware Training Epoch: 1 [25600/60000]  Loss: 0.301568
Quantize Aware Training Epoch: 1 [28800/60000]  Loss: 0.252994
Quantize Aware Training Epoch: 1 [32000/60000]  Loss: 0.138346
Quantize Aware Training Epoch: 1 [35200/60000]  Loss: 0.203350

...

Test set: Quant Model Accuracy: 90%

總的實驗結果如下:

bit 1 2 3 4 5 6 7 8
accuracy 10% 63% 90% 97% 98% 98% 98% 98%

用曲線把它們 plot 在一起:

灰色線是量化訓練,橙色線是後訓練量化,可以看到,在 bit = 2、3 的時候,量化訓練能帶來很明顯的提升。

在 bit = 1 的時候,我發現量化訓練回傳的梯度為 0,訓練基本失敗了。這是因為 bit = 1 的時候,整個網路已經退化成一個二值網路了,而低位元量化訓練本身不是一件容易的事情,雖然我們前面用 STE 解決了梯度的問題,但由於低位元會使得網路的資訊損失巨大,因此通常的訓練方式很難起到作用。

另外,量化訓練本身存在很多 trick,在這個實驗中我發現,學習率對結果的影響非常顯著,尤其是低位元量化的時候,學習率太高容易導致梯度變為 0,導致量化訓練完全不起作用「一度以為程式碼出錯」。

量化訓練部署

前面說過,量化訓練雖然收益明顯,但實際應用起來卻比後訓練量化麻煩得多。

目前大部分主流推理框架在處理後訓練量化時,只需要使用者把模型和資料扔進去,就可以得到量化模型,然後直接部署。但很少有框架支援量化訓練。目前量化訓練缺少統一的規範,各家推理引擎的量化演算法雖然本質一樣,但很多細節處很難做到一致。而目前大家做模型訓練的前端框架是不統一的「當然主流還是 tf 和 pytorch」,如果各家的推理引擎需要支援不同前端的量化訓練,就需要針對不同的前端框架,按照後端部署的實現規則「比如哪些層的量化需要合併、weight 是否採用對稱量化等」,從頭再搭一套量化訓練框架,這個工作量想想就嚇人。

總結

這篇文章主要介紹了量化訓練的基本方法,並用 pytorch 構建了一個簡單的量化訓練例項。下一篇文章會介紹這系列教程的最後一篇文章——關於 fold BatchNorm 相關的知識。

參考