1. 程式人生 > 實用技巧 >詳解卷積中的Winograd加速演算法

詳解卷積中的Winograd加速演算法

6. 何時開啟WinoGrad卷積

和Sgemm用於卷積一樣,我們也需要思考WinoGrad在何種情況下是適用的,或者說是有明顯加速的。這篇文章介紹的WinoGrad卷積是針對NCHW這種記憶體排布的,然後我們來看一下NCNN在基於NCHW這種記憶體排佈下,是在何種情況下啟用WinoGrad()?

通過檢視NCNN的原始碼(https://github.com/Tencent/ncnn/blob/master/src/layer/arm/convolution_arm.cpp)可以發現,只有在輸入輸出通道均>=16,並且特徵圖長寬均小於等於120的條件下才會啟用WinoGrad卷積。

那麼這個條件是如何得出的,除了和手工優化的conv3x3s1

(https://github.com/msnh2012/Msnhnet/blob/master/src/layers/arm/MsnhConvolution3x3s1.cpp)在不同條件下做速度對比測試之外,我們也可以感性的分析一下。

第一,WinoGrad演算法設計到幾個矩陣變換,如果計算量不大,這幾個矩陣變換的成本佔計算總成本的比例就越大,所以WinoGrad應當是在計算量比較大時才能有效,如VGG16。

第二,當計算量比較大的時候,又要考慮到Cache命中率的問題,這個時候WinoGrad訪存可能會比直接手動優化更差,導致速度上不去。


詳解卷積中的Winograd加速演算法

「GiantPandaCV導語」:這篇文章為大家介紹一下用來加速卷積運算的WinoGrad演算法的原理,工程實現以及相關優化思路,如果你對卷積加速演算法感興趣可以看看這篇文章。演算法的完整實現請到MsnhNet的github倉庫檢視,地址為:https://github.com/msnh2012/Msnhnet

1. 為什麼會引入WinoGrad?

做過ACM/OI的朋友大家應該對FFT並不陌生,我們知道對於兩個序列的乘法通過FFT可以從原始O(n^2)複雜度變成O(nlogn),所以我們就會想著FFT這個演算法是否可以應用到我們計算卷積中來呢?當然是可以的,但是FFT的計算有個問題哦,會引入複數。而移動端是不好處理複數的,對於小卷積核可能減少的計算量和複數運算帶來的降速效果是不好說誰會主導的。所以在這種情況下,針對卷積的WinoGrad演算法出現了,它不僅可以類似FFT一樣降低計算量,它還不會引入複數,使得卷積的運算加速成為了可能。因此,本文嘗試從工程實現的角度來看一下WinoGrad,希望對從事演算法加速的小夥伴有一些幫助。

2. 為什麼會有這篇文章?

最近嘗試給MsnhNet做卷積的WinoGrad實現,然後開始瞭解這個演算法,並嘗試參考著NCNN來理解和動手寫一下。參考了多篇優秀的講解文章和NCNN原始碼,感覺算是對這個演算法有了較為清楚的認識,這篇文章就記錄一下我在實現並且步長為的WinoGrad卷積時的一些理解。這篇文章的重點是WinoGrad卷積的實現,關於WinoGrad卷積裡面的變化矩陣如何推導可以看樑德澎作者的文章:詳解Winograd變換矩陣生成原理 (聽說後續他會做個視訊來仔細講講QAQ),現在就假設我們知道了WinoGrad的幾個變換矩陣。如果你不知道也沒關係,因為有一個Python工具包可以直接幫我們計算,地址為:https://github.com/andravin/wincnn 。然後現在我們就要用拿到的這幾個矩陣來實現WinoGrad演算法,聽起來比較簡單,但我們還是得一步步理清楚是不。

3. WinoGrad演算法原理

WinoGrad演算法起源於1980年,是Shmuel Winograd提出用來減少FIR濾波器計算量的一個演算法。它指出,對於輸出個數為,引數個數為的FIR濾波器,不需要次乘法計算,而只需要次乘法計算即可。

下面是一個經典例子,以1維卷積為例,輸入訊號,卷積核,則卷積可以寫成如下矩陣乘法形式:

式子1

如果這個計算過程使用普通的矩陣乘法,則一共需要「6次乘法和4次加法」 。

但是,我們仔細觀察一下,卷積運算中輸入訊號轉換得到的矩陣不是任意矩陣,其有規律的分佈著大量的重複元素,例如第一行的和,卷積轉換成的矩陣乘法比一般乘法的問題域更小,所以這就讓優化存為了可能。

然後WinoGrad的做法就是:

式子2

其中,

等式3

我們知道,在CNN的推理階段,卷積核上的元素是固定的,所以上式中和相關的式子可以提前算好,在預測階段只用計算一次,可以忽略。所以這裡一共需要「4次乘法加4次加法」。

相比於普通的矩陣乘法,使用WinoGrad演算法之後乘法次數減少了,這樣就可以達到加速的目的了。

這個例子實際上是「1D的WinoGrad演算法」,我們將上面的計算過程寫成矩陣的形式如下:

式子4

其中,表示element-wise multiplication(Hadamard product)對應位置相乘。其中,

相關矩陣解釋
  • :表示卷積核
  • :表示輸入訊號
  • :卷積核變換矩陣,尺寸為
  • :輸入變換矩陣,尺寸
  • :輸出變換矩陣,尺寸

所以整個計算過程可以分為4步:

  • 輸入變換
  • 卷積核變換
  • 外積
  • 輸出變換

然後我們將1D的WinoGrad擴充套件到2D,就可以實現卷積的加速了,那麼如何從1維擴充套件到2維呢?公式如下:

式子5

其中,為的卷積核,為的影象塊,我們把上面的擴充套件到,先寫成矩陣乘法的方式:

F(2x2,3x3) 圖片來自https://www.cnblogs.com/shine-lee/p/10906535.html

上圖表示我們將卷積核的元素拉成了一列,將輸入訊號每個滑動視窗中的元素拉成了一行。注意圖中紅線分成的矩陣塊,每個矩陣塊中重複元素的位置與一維相同,即:

二維和一維的WinoGrad矩陣關係

然後,令,即影象視窗中的第0行元素,然後表示第行,,然後可以推導:

2D WinoGrad矩陣形式計算推導

在上面的推導中,表示長度為4的和長度為的卷積結果,結果為長度為2的列向量,其中和均為長度為4的列向量。

進一步,可以看成3對長度為4的列向量兩兩對應位置相乘再相加,結果為長度為4的列向量,也可以看成是4組長度為3的行向量的點積運算。

同樣,也是3對長度為4的列向量的內積運算。

然後類似1D WinoGrad演算法,我們考慮兩者的重疊部分和,剛好對應1D WinoGrad中的每一行在的對應行上進行1維卷積,基於上面推導的1D WinoGrad公式,行向量的卷積只需要將所有左乘的變換矩陣轉置後變成右乘即可。

然後上面的推導就做完了。

下圖表示2D WinoGrad的示意圖:

2D WinoGrad示意圖

這個時候,WinoGrad演算法的乘法次數為,而如果直接卷積乘法次數為,「降低了2.25倍的乘法計算複雜度」。

4. 從工程角度來看WinoGrad

下面我們就從一個實際例子來說,如何利用WinoGrad來實現並且步長為1的卷積運算。基於上面介紹的2D WinoGrad的原理,我們現在只需要分4步即可實現WnoGrad演算法:

  • 第一步就是對輸入卷積核的變換:
  • 第二步就是對輸入資料的變換:
  • 第三步就是對M矩陣的計算:
  • 最後一步就是結果的計算:

接下來我們就以WinoGrad實現並且步長為1的卷積計算為例子,來理解一下WinoGrad的工程實現。

4.1 對輸入卷積核進行變換

這一步就是對卷積核進行變化,公式為:,其中表示輸出通道標號,表示輸入通道標號,一個對應卷積核的一個。由於我們要實現的是,因此是一個的矩陣,我們不難寫出這部分程式碼(其中,矩陣可以通過https://github.com/andravin/wincnn 這個工具進行計算):

// 矩陣G
        const float ktm[8][3] = {
            {1.0f,      0.0f,      0.0f},
            {-2.0f / 9, -2.0f / 9, -2.0f / 9},
            {-2.0f / 9, 2.0f / 9, -2.0f / 9},
            {1.0f / 90, 1.0f / 45, 2.0f / 45},
            {1.0f / 90, -1.0f / 45, 2.0f / 45},
            {1.0f / 45, 1.0f / 90, 1.0f / 180},
            {1.0f / 45, -1.0f / 90, 1.0f / 180},
            {0.0f, 0.0f, 1.0f}
        };

        const int kernelTmSize = inChannel * 8 * 8;

#if USE_OMP
    #pragma omp parallel for num_threads(OMP_THREAD)
#endif
        for(int outc = 0; outc < outChannel; outc++){
            for(int inc = 0; inc < inChannel; inc++){
                const float* kernel0 = (const float*)kernel + outc * inChannel * 9 + inc * 9;
                float *kernel_tm0 = kernel_tm + outc * kernelTmSize + inc * 64;

                //需要變換的卷積核
                const float* k0 = kernel0;
                const float* k1 = kernel0 + 3;
                const float* k2 = kernel0 + 6;

                float tmpG[8][3];    // tmp = G*g
                for(int i = 0; i < 8; i++){
                    tmpG[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
                    tmpG[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
                    tmpG[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
                }

                //U = kernel_tm0 = G*g*G^T
                //[8*3] x [3*8]
                for(int i = 0; i < 8; i++){
                    float *tmpPtr = &tmpG[i][0];
                    for(int j = 0; j < 8; j++){
                        kernel_tm0[i * 8 + j] = tmpPtr[0] * ktm[j][0] + tmpPtr[1] * ktm[j][1] + tmpPtr[2] * ktm[j][2];
                    }
                }

            }
        }

通過這段程式碼,所有的卷積核都被轉換成了U,存放在了kernel_tm上,一行代表一個,kernel_tm的記憶體排布如下圖所示:

U_{k,c}的記憶體排布

其中W=64的原因是因為F(6x6,3x3)需要每一個輸入影象塊(tile)的大小為,權重塊也對應,這樣才可以做卷積運算(eltwise_mult)。

然後上次我們講到資料Pack的優勢詳解Im2Col+Pack+Sgemm策略更好的優化卷積運算,所以這裡仍然使用NCNN的Pack策略來獲得更好的訪存,即將上面的kernel_tm進行一次重排,將維度全部壓到維度上,另外再對維度做一個額外的4倍壓縮,來獲得更好的訪存。

將H的維度全部壓到維度上示意圖:

將kernel_tm的H維度全部壓到W維度變成一個扁平的Blob

然後在這個基礎上,將C維度進行進一步壓縮,這個時候還需要注意的是對於每一個輸出通道,我們在這個平面上是同時拿出了2行,也就是拿出了128個數據,然後進行交織排列,最後獲得kernel_tm2。這裡以輸出通道的前4個為例,即剛好處理8個U矩陣之後結果矩陣kernel_tm2應該是長什麼樣子,如下圖所示:

![Pack策略之後的矩陣kernel_tm2就長這個樣子

這部分的程式碼實現如下:

int nnOutchannel = outChannel >> 2;
        int remainOutChannel = nnOutchannel << 2;
        
        int packOutChannel = nnOutchannel + (outChannel % 4 + 3) / 4;
        int packOutH = 1;
        int packOutW = (8 * 8 * inChannel * 4);

        //float *kernel_tm2 = new float[packOutChannel * packOutH * packOutW];

#if USE_OMP
    #pragma omp parallel for num_threads(OMP_THREAD)
#endif       
        for(int cc = 0; cc < nnOutchannel; cc++){
            int c = cc << 2;
            float *ktm2 = kernel_tm2 + cc * packOutH * packOutW;
            
            const float *kernel0_tm = kernel_tm + c * kernelTmSize;
            const float *kernel1_tm = kernel_tm + (c + 1) * kernelTmSize;
            const float *kernel2_tm = kernel_tm + (c + 2) * kernelTmSize;
            const float *kernel3_tm = kernel_tm + (c + 3) * kernelTmSize;

            int q = 0;

            for(; q + 1 < inChannel; q += 2){
                const float *k00 = kernel0_tm + q * 64;
                const float *k01 = kernel0_tm + (q + 1) * 64;
                const float *k10 = kernel1_tm + q * 64;
                const float *k11 = kernel1_tm + (q + 1) * 64;
                const float *k20 = kernel2_tm + q * 64;
                const float *k21 = kernel2_tm + (q + 1) * 64;
                const float *k30 = kernel3_tm + q * 64;
                const float *k31 = kernel3_tm + (q + 1) * 64;

                for(int i = 0; i < 16; i++){
                    for(int j = 0; j < 4; j++){
                        ktm2[0 + j] = k00[j];
                        ktm2[4 + j] = k01[j];
                        ktm2[8 + j] = k10[j];
                        ktm2[12 + j] = k11[j];
                        ktm2[16 + j] = k20[j];
                        ktm2[20 + j] = k21[j];
                        ktm2[24 + j] = k30[j];
                        ktm2[28 + j] = k31[j];
                    }

                    k00 += 4;
                    k01 += 4;
                    k10 += 4;
                    k11 += 4;
                    k20 += 4;
                    k21 += 4;
                    k30 += 4;
                    k31 += 4;
                    ktm2 += 32;

                }
            }

            //inChannel方向的拖尾部分
            for(; q < inChannel; q++){
                const float *k00 = kernel0_tm + q * 64;
                const float *k10 = kernel1_tm + q * 64;
                const float *k20 = kernel2_tm + q * 64;
                const float *k30 = kernel3_tm + q * 64;

                for(int i = 0; i < 16; i++){
                    for(int j = 0; j < 4; j++){
                        ktm2[0 + j] = k00[j];
                        ktm2[4 + j] = k10[j];
                        ktm2[8 + j] = k20[j];
                        ktm2[12 + j] = k30[j];
                    }

                    k00 += 4;
                    k10 += 4;
                    k20 += 4;
                    k30 += 4;
                    ktm2 += 16;
                }
            }

        }

#if USE_OMP
    #pragma omp parallel for num_threads(OMP_THREAD)
#endif      
        for(int cc = remainOutChannel; cc < outChannel; cc++){
            float *ktm2 = kernel_tm2  + nnOutchannel * packOutH * packOutW + 8 * 8 * inChannel * (cc - remainOutChannel);

            const float* kernel0_tm = kernel_tm + cc * kernelTmSize;

            int q = 0;

            for(; q < inChannel; q++){
                const float* k00 = kernel0_tm + q * 64;
                for(int i = 0; i < 16; i++){
                    for(int j = 0; j < 4; j++){
                        ktm2[j] = k00[j];
                    }
                    k00 += 4;
                    ktm2 += 4;
                }
            }
        }        

4.2 對輸入資料進行變換

對卷積核進行變換之後,接下來就輪到對輸入矩陣進行變換了,即對V矩陣進行計算,。上面我們已經提到過,對於卷積核獲得的每一個,我們都需要一個對應的的影象塊(tile)和它做卷積運算(eltwise_multiply)。所以這裡我們首先需要確定輸入資料可以被拆成多少個影象塊,並且我們需要為變換矩陣V申請空間,從第三節可知:輸入變換矩陣,尺寸為,即每個小塊的變換矩陣都為,但是輸入特徵圖長寬不一定會被8整除,這個時候就需要對輸入特徵圖進行擴充套件(padding),這部分預處理的程式碼實現如下:

// Vc,b = B^Td_{c,b}B
        
        // 輸出特徵圖如果長寬不夠需要Padding
        int outW = (outWidth + 5) / 6 * 6;
        int outH = (outHeight + 5) / 6 * 6;

        int W = outW + 2;
        int H = outH + 2;
        int Top = 0;
        int Left = 0;
        int Bottom = H;
        int Right = W;
        int PadHeight = Bottom - Top;
        int PadWidth = Right - Left;
        int PadSize = PadHeight * PadWidth;
        float *srcPadding = new float[PadHeight * PadWidth * inChannel];
        PaddingLayerArm now;
        now.padding(src, inWidth, inHeight, inChannel, srcPadding, 0, H - inHeight, 0, W - inWidth, 0);
        
        int w_tm = outW / 6 * 8;
        int h_tm = outH / 6 * 8;
        int tiles = w_tm / 8 * h_tm / 8;

        int src_tm_channel = inChannel;
        int src_tm_h = 16 * w_tm / 8 * h_tm / 8;
        int src_tm_w = 4;
        
        int src_tm_size = src_tm_h * src_tm_w;
        float *src_tm  = new float[src_tm_channel * src_tm_h * src_tm_w];

注意上面src_tm的形狀,這是考慮到了卷積核變換矩陣已經執行了Pack策略,所以這裡為了方便後續的卷積計算和進行指令集加速,同樣將src_tm進行Pack,這個Pack是直接規定計算完之後4個4個岔開儲存的方式來實現的。另外,輸入Blob的一個Channel對應了輸出Blob的一個Channel。

然後我們再通過WinCNN工具可以獲得B矩陣和B的轉置矩陣,並確定V矩陣更好的計算策略(指的是可以複用一些中間變數)。

// BT = 
        // ⎡1   0    -21/4    0    21/4     0    -1  0⎤
        // ⎢                                          ⎥
        // ⎢0   1      1    -17/4  -17/4    1    1   0⎥
        // ⎢                                          ⎥
        // ⎢0   -1     1    17/4   -17/4   -1    1   0⎥
        // ⎢                                          ⎥
        // ⎢0  1/2    1/4   -5/2   -5/4     2    1   0⎥
        // ⎢                                          ⎥
        // ⎢0  -1/2   1/4    5/2   -5/4    -2    1   0⎥
        // ⎢                                          ⎥
        // ⎢0   2      4    -5/2    -5     1/2   1   0⎥
        // ⎢                                          ⎥
        // ⎢0   -2     4     5/2    -5    -1/2   1   0⎥
        // ⎢                                          ⎥
        // ⎣0   -1     0    21/4     0    -21/4  0   1⎦

        //B = 
        // ⎡1     0     0    0    0   0     0   0    ⎤
     // ⎢0     1     -1    1/2    -1/2   2    -2   -1   ⎥
     // ⎢-21/4 1     1    1/4    1/4   4     4   0    ⎥
     // ⎢0     -17/4 17/4   -5/2    5/2   -5/2 5/2   21/4 ⎥
     // ⎢21/4 -17/4 -17/4  -5/4   -5/4   -5 -5   0    ⎥   
     // ⎢0     1     -1    2    2   1/2 -1/2  -21/4⎥
     // ⎢-1     1     1    1    1   1     1   0    ⎥
     // ⎢0     0     0    0    0   0     0   1    ⎥


        // 0 = r00 - r06 + (r04 - r02) * 5.25
        // 7 = r07 - r01 + (r03 - r05) * 5.25

        // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05)
        // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05)

        // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2)
        // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2)

        // reuse r04 * 1.25
        // reuse r03 * 2.5
        // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5)
        // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5)

接下來我們就可以開始計算V矩陣了,程式碼如下:

#if USE_OMP
    #pragma omp parallel for num_threads(OMP_THREAD)
#endif
        for(int q = 0; q < inChannel; q++){
            const float *padptr = srcPadding + q * PadSize;
            float *srcptr = src_tm + q * src_tm_size;

            float tmpV[8][8];

            //tile
            for(int i = 0; i < h_tm / 8; i++){
                for(int j = 0; j < w_tm / 8; j++){
                    float *r0 = padptr + i * 6 * PadWidth + j * 6;
                    
                    // Bd_{c,b}
                    for(int m = 0; m < 8; m++){

                        tmpV[0][m] = r0[0] - r0[6] + (r0[4] - r0[2]) * 5.25f;
                        tmpV[7][m] = r0[7] - r0[1] + (r0[3] - r0[5]) * 5.25f;

                        float t1 = (r0[2] + r0[6] - r0[4] * 4.25f);
                        float t2 = (r0[1] + r0[5] - r0[3] * 4.25f);

                        tmpV[1][m] = t1 + t2;
                        tmpV[2][m] = t1 - t2;

                        float t3 = (r0[6] + r0[2] * 0.25f - r0[4] * 1.25f);
                        float t4 = (r0[1] * 0.5f - r0[3] * 2.5f + r0[5] * 2.f);
                        tmpV[3][m] = t3 + t4;
                        tmpV[4][m] = t3 - t4;

                        float t5 = (r0[6] + (r0[2] - r0[4] * 1.25f) * 4.f);
                        float t6 = (r0[1] * 2.f - r0[3] * 2.5f + r0[5] * 0.5f);

                        tmpV[5][m] = t5 + t6;
                        tmpV[6][m] = t5 - t6;

                        r0 += PadWidth;
                    }

                    //Bd_{c,b}B^T
                    float *r00 = srcptr + (i * w_tm / 8 + j) * src_tm_w;
                    float *r04 = srcptr + (i * w_tm /8 + j + tiles) * src_tm_w;

                    for(int m = 0; m < 8; m++){
                        float* tmpVPtr = tmpV[m];
                        r00[0] = tmpVPtr[0] - tmpVPtr[6] + (tmpVPtr[4] - tmpVPtr[2]) * 5.25f;
                        r04[3] = tmpVPtr[7] - tmpVPtr[1] + (tmpVPtr[3] - tmpVPtr[5]) * 5.25f;
                        
                        float t1 =  (tmpVPtr[2] + tmpVPtr[6] - tmpVPtr[4] * 4.25f);
                        float t2 =  (tmpVPtr[1] - tmpVPtr[3] * 4.25f + tmpVPtr[5]);
                        r00[1] = t1 + t2;
                        r00[2] = t1 - t2;

                        float t3 = (tmpVPtr[6] + tmpVPtr[2] * 0.25f - tmpVPtr[4] * 1.25);
                        float t4 = (tmpVPtr[1] * 0.5f - tmpVPtr[3] * 2.5f + tmpVPtr[5] * 2.f);
                        r00[3] = t3 + t4;
                        r04[0] = t3 - t4;

                        float t5 = (tmpVPtr[6] + (tmpVPtr[2] - tmpVPtr[4] * 1.25f) * 4.f);
                        float t6 = (tmpVPtr[1] * 2.f - tmpVPtr[3] * 2.5f + tmpVPtr[5] * 0.5f);

                        r04[1] = t5 + t6;
                        r04[2] = t5 - t6;

                        r00 += 2 * tiles * src_tm_w;
                        r04 += 2 * tiles * src_tm_w;

                    }

                }
            }
        }

        delete [] srcPadding;

可以看到這個地方不僅計算了V矩陣,並在儲存時就對V矩陣進行了重新排列,以適應卷積核變化矩陣的Pack結果,方便後面進行卷積計算的加速同時獲得更好的訪存,這個過程如下圖所示:

對輸入矩陣進行變換的過程

4.3 計算M矩陣

M矩陣的計算公式為:

其中,k代表輸出通道數,b表示tile序號。

由於上面輸入影象塊已經執行了Pack策略,這裡只需要將對應小塊進行乘加操作即完成了M矩陣的計算,這部分的程式碼實現如下:

#if USE_OMP
    #pragma omp parallel for num_threads(OMP_THREAD)
#endif
        for(int cc = 0; cc < nnOutChannel; cc++){
            int c = cc * 4;
            float *dest0 = dest_tm + c * dst_tm_size;
            float *dest1 = dest_tm + (c + 1) * dst_tm_size;
            float *dest2 = dest_tm + (c + 2) * dst_tm_size;
            float *dest3 = dest_tm + (c + 3) * dst_tm_size;

            const float *ktm = kernel + cc * kernelSize;
            int q = 0;
            
            for(; q + 1 < inChannel; q += 2){
                const float* r0 = src_tm + q * src_tm_size;
                const float* r1 = src_tm + (q + 1) * src_tm_size;
                
                float* destptr0 = dest0;
                float *destptr1 = dest1;
                float *destptr2 = dest2;
                float *destptr3 = dest3;

                for(int r = 0; r < 16; r++){
                    for(int t = 0; t < tiles; t++){
                        for(int m = 0; m < 4; m++){
                            destptr0[m] += r0[m] * ktm[m];
                            destptr0[m] += r1[m] * ktm[m + 4];
                            destptr1[m] += r0[m] * ktm[m + 8];
                            destptr1[m] += r1[m] * ktm[m + 12];
                            destptr2[m] += r0[m] * ktm[m + 16];
                            destptr2[m] += r1[m] * ktm[m + 20];
                            destptr3[m] += r0[m] * ktm[m + 24];
                            destptr3[m] += r1[m] * ktm[m + 28];  
                        }

                        r0 += 4;
                        r1 += 4;
                        destptr0 += 4;
                        destptr1 += 4;
                        destptr2 += 4;
                        destptr3 += 4;
                    }

                    ktm += 32;
                }
            }

            for(; q < inChannel; q++){
                const float *r0 = src_tm + q * src_tm_size;
                float* destptr0 = dest0;
                float *destptr1 = dest1;
                float *destptr2 = dest2;
                float *destptr3 = dest3;

                for(int r = 0; r < 16; r++){
                    for(int t = 0; t < tiles; t++){
                        for(int m = 0; m < 4; m++){
                            destptr0[m] += r0[m] * ktm[m];
                            destptr1[m] += r0[m] * ktm[m + 4];
                            destptr2[m] += r0[m] * ktm[m + 8];
                            destptr3[m] += r0[m] * ktm[m + 12];
                        }

                        r0 += 4;
                        destptr0 += 4;
                        destptr1 += 4;
                        destptr2 += 4;
                        destptr3 += 4;
                    }

                    ktm += 16;
                }
            }
        }
        
#if USE_OMP
    #pragma omp parallel for num_threads(OMP_THREAD)
#endif
        for(int cc = remainOutChannel; cc < outChannel; cc++){
            int c = cc;
            float *dest0 = dest_tm + c * dst_tm_size;
            const float *ktm = kernel + nnOutChannel * kernelSize + 8 * 8 * inChannel * (c - remainOutChannel);

            int q = 0;
            for(; q < inChannel; q++){
                const float* r0 = src_tm + q * src_tm_size;
                float* destptr0 = dest0;
                for(int r = 0; r < 16; r++){
                    for(int i = 0; i < tiles; i++){
                        for(int m = 0; m < 4; m++){
                            destptr0[m] += r0[m] * ktm[m];
                        }

                        r0 += 4;
                        destptr0 += 4;
                    }

                    ktm += 4;
                }
            }
        }

至此,我們獲得了M矩陣,矩陣大概長下面這樣子,它仍然是交錯排列的:

M矩陣長得和V矩陣有點像,主要是通道維度變了

4.4 計算結果Y矩陣

現在就到了最後一步了,我們需要計算結果矩陣Y,公式為:

其中表示輸出通道數,b表示tile標號,這部分和上面卷積核的計算類似,程式碼如下:

// Yk,b=A^TMk,bA
// AT=
// ⎡1  1  1   1    1    1      1    0⎤
// ⎢                                 ⎥
// ⎢0  1  -1  2   -2   1/2   -1/2   0⎥
// ⎢                                 ⎥
// ⎢0  1  1   4    4   1/4    1/4   0⎥
// ⎢                                 ⎥
// ⎢0  1  -1  8   -8   1/8   -1/8   0⎥
// ⎢                                 ⎥
// ⎢0  1  1   16  16   1/16  1/16   0⎥
// ⎢                                 ⎥
// ⎣0  1  -1  32  -32  1/32  -1/32  1⎦

        // 0 = r0 + (r1 + r2) + (r3 + r4)     + (r5 + r6) * 32
        // 1 =      (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16
        // 2 =      (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8
        // 3 =      (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4
        // 4 =      (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2
        // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6)


        float *dest_tm2 = new float[outW * outH * outChannel];
        const int dst_tm_size2 = outW * outH;
        

        const int outSize = outHeight * outWidth;

#if USE_OMP
    #pragma omp parallel for num_threads(OMP_THREAD)
#endif
        for(int cc = 0; cc < outChannel; cc++){
            float *destptr = dest_tm + cc * dst_tm_size;
            float *outptr = dest_tm2 + cc * dst_tm_size2;

            float tmpA[6][8];

            for(int i = 0; i < outH / 6; i++){
                for(int j = 0; j < outW / 6; j++){
                    float *destptr0 = destptr + (i * w_tm / 8 + j) * dst_tm_w;
                    float *destptr4 = destptr + (i * w_tm / 8 + j + tiles) * dst_tm_w;

                    for(int m = 0; m < 8; m++){

                        float t1 = destptr0[1] + destptr0[2];
                        float t2 = destptr0[1] - destptr0[2];

                        float t3 = destptr0[3] + destptr4[0];
                        float t4 = destptr0[3] - destptr4[0];

                        float t5 = destptr4[1] + destptr4[2];
                        float t6 = destptr4[1] - destptr4[2];

                        tmpA[0][m] = destptr0[0] + t1 + t3 + t5 * 32;
                        tmpA[2][m] = t1 + t3 * 4 + t5 * 8;
                        tmpA[4][m] = t1 + t3 * 16 + t5 + t5;

                        tmpA[1][m] = t2 + t4 + t4 + t6 * 16;
                        tmpA[3][m] = t2 + t4 * 8 + t6 * 4;
                        tmpA[5][m] = destptr4[3] + t2 + t4 * 32 + t6;

                        destptr0 += dst_tm_w * 2 * tiles;
                        destptr4 += dst_tm_w * 2 * tiles;
                    }

                    float *outptr0 = outptr + (i * 6) * outW + j * 6;

                    for(int m = 0; m < 6; m++){

                        const float* tmp0 = tmpA[m];

                        float t1 = tmp0[1] + tmp0[2];
                        float t2 = tmp0[1] - tmp0[2];

                        float t3 = tmp0[3] + tmp0[4];
                        float t4 = tmp0[3] - tmp0[4];

                        float t5 = tmp0[5] + tmp0[6];
                        float t6 = tmp0[5] - tmp0[6];

                        outptr0[0] = tmp0[0] + t1 + t3 + t5 * 32;
                        outptr0[2] = t1 + t3 * 4 + t5 * 8;
                        outptr0[4] = t1 + t3 * 16 + t5 + t5;

                        outptr0[1] = t2 + t4 + t4 + t6 * 16;
                        outptr0[3] = t2 + t4 * 8 + t6 * 4;
                        outptr0[5] = tmp0[7] + t2 + t4 * 32 + t6;

                        outptr0 += outW;
                    }
                }
            }
        } 

這部分程式碼就實現了M矩陣匯聚並利用A矩陣獲得了最終的結果Y。這個過程上一節圖中已經畫了,這裡主要實現的是圖中的右半部分:

Y矩陣匯聚存放獲得輸出Blob

但是需要注意的是這裡獲得的Y有可能是多了幾行或者幾列,也就是拖尾為0的部分,所以需要把這一部分Crop掉,才能獲得我們最終的結果特徵圖。Crop部分的程式碼如下:

//crop
        for(int cc = 0; cc < outChannel; cc++){
            float *outptr = dest_tm2 + cc * dst_tm_size2;
            float *outptr2 = dest + cc * outHeight * outWidth;
            for(int i = 0; i < outHeight; i++){
                for(int j = 0; j < outWidth; j++){
                    outptr2[0] = outptr[0];
                    outptr2++;
                    outptr++;
                }
                outptr += (outW - outWidth);
            }
        }

至此,WinoGrad的演算法流程結束,我們獲得了最後的卷積計算結果。

5. WinoGrad演算法進一步加速

上面無論是針對U,V,M還是Y矩陣的計算我們使用的都是暴力計算,所以接下來可以使用Neon Instrics和Neon Assembly技術進行優化。介於篇幅原因,這裡就不貼程式碼了,有需要學習的可以關注後續MsnhNet的WinoGrad程式碼部分https://github.com/msnh2012/Msnhnet/blob/master/src/layers/arm/MsnhConvolution3x3s1Winograd.cpp。這個程式碼實現的思路取自開源框架NCNN,在此表示感謝NCNN這一優秀工作(github:https://github.com/Tencent/ncnn)。

6. 何時開啟WinoGrad卷積

和Sgemm用於卷積一樣,我們也需要思考WinoGrad在何種情況下是適用的,或者說是有明顯加速的。這篇文章介紹的WinoGrad卷積是針對NCHW這種記憶體排布的,然後我們來看一下NCNN在基於NCHW這種記憶體排佈下,是在何種情況下啟用WinoGrad()?

通過檢視NCNN的原始碼(https://github.com/Tencent/ncnn/blob/master/src/layer/arm/convolution_arm.cpp)可以發現,只有在輸入輸出通道均>=16,並且特徵圖長寬均小於等於120的條件下才會啟用WinoGrad卷積。

那麼這個條件是如何得出的,除了和手工優化的conv3x3s1(https://github.com/msnh2012/Msnhnet/blob/master/src/layers/arm/MsnhConvolution3x3s1.cpp)在不同條件下做速度對比測試之外,我們也可以感性的分析一下。

第一,WinoGrad演算法設計到幾個矩陣變換,如果計算量不大,這幾個矩陣變換的成本佔計算總成本的比例就越大,所以WinoGrad應當是在計算量比較大時才能有效,如VGG16。

第二,當計算量比較大的時候,又要考慮到Cache命中率的問題,這個時候WinoGrad訪存可能會比直接手動優化更差,導致速度上不去。

7. 速度測試

由於筆者還未實現完整Neon Instrics和Assembly部分,所以暫時無法給出速度對比。嘗試從NCNN的BenchMark中找到WinoGrad的加速效果大概是什麼樣的,但只能找到各個網路在各種板子上的整體推理速度,沒有WinoGrad F(6,3)單獨的速度對比,等國慶爭取補上來吧。

8. 結語

關於WinoGrad的原理介紹還有工程實現(基於NCNN)暫時就講到這裡了,有問題歡迎在評論區討論哦。我剛入門移動端優化幾個月還有非常多知識需要學習,nihui,蟲叔,白牛,大老師他們都是高人,這幾個月從他們的文章受益良多,非常感謝!

9. 致謝

  • https://zhuanlan.zhihu.com/p/72149270
  • https://www.cnblogs.com/shine-lee/p/10906535.html
  • https://zhuanlan.zhihu.com/p/81201840