1. 程式人生 > >c++的矩陣乘法加速trick

c++的矩陣乘法加速trick

最近讀RNNLM的原始碼,發現其實現矩陣乘法時使用了一個trick,這裡描述一下這個trick。

首先是正常版的矩陣乘法(其實是矩陣乘向量)

void matrixXvector(float* destvect, float* srcmatrix, int srcmatrix_rownum, int srcmatrix_colnum, float* srcvect, int srcvect_size){
    for(int row=0;row<srcmatrix_rownum;++row){
        destvect[row]=0;
        for(int col=0;col<srcmatrix_colnum;++col){
            destvect[row]+=srcmatrix[row*srcmatrix_colnum+col]*srcvect[col];
        }
    }
}

就是最簡單的for迴圈,逐行逐列遍歷。

接下來是RNNLM中實現的trick版本

void matrixXvector2(float* destvect, float* srcmatrix, int srcmatrix_rownum, int srcmatrix_colnum, float* srcvect, int srcvect_size){
    int row, col;
    float val1, val2, val3, val4;
    float val5, val6, val7, val8;
    
    for(row=0;row<srcmatrix_rownum/8
;++row){ val1 = 0; val2 = 0; val3 = 0; val4 = 0; val5 = 0; val6 = 0; val7 = 0; val8 = 0; for(col=0;col<srcmatrix_colnum;++col){ val1+=srcmatrix[(row*8+0)*srcmatrix_colnum+col]*srcvect[col]; val2+=srcmatrix[(row*8
+1
)*srcmatrix_colnum+col]*srcvect[col]; val3+=srcmatrix[(row*8+2)*srcmatrix_colnum+col]*srcvect[col]; val4+=srcmatrix[(row*8+3)*srcmatrix_colnum+col]*srcvect[col]; val5+=srcmatrix[(row*8+4)*srcmatrix_colnum+col]*srcvect[col]; val6+=srcmatrix[(row*8+5)*srcmatrix_colnum+col]*srcvect[col]; val7+=srcmatrix[(row*8+6)*srcmatrix_colnum+col]*srcvect[col]; val8+=srcmatrix[(row*8+7)*srcmatrix_colnum+col]*srcvect[col]; } destvect[row*8+0]+=val1; destvect[row*8+1]+=val2; destvect[row*8+2]+=val3; destvect[row*8+3]+=val4; destvect[row*8+4]+=val5; destvect[row*8+5]+=val6; destvect[row*8+6]+=val7; destvect[row*8+7]+=val8; } for(row=row*8;row<srcmatrix_rownum;++row){ for(col=0;col<srcmatrix_colnum;++col){ destvect[row]+=srcmatrix[row*srcmatrix_colnum+col]*srcvect[col]; } } }

對比普通版,trick版把遍歷行的for迴圈分成了8份,同時進行列遍歷。

實際測試中,這個trick版比普通版快了接近2倍~這是編譯器優化造成的麼……?

參考:http://www.cnblogs.com/plwang1990/p/4139357.html