矩陣乘法優化演算法
阿新 • • 發佈:2019-01-09
本篇文章大部分思路與程式碼都來自於微信公眾號“CPP開發者”中2016年4月11日的文章《矩陣相乘優化演算法實現講解》,基本相當於這篇文章的重點重述。
矩陣是什麼以及矩陣乘法是怎麼操作的,我想點開這篇文章的人都應該知道了,這裡就不再贅述了。
首先回顧一下我們最樸素的演算法:
//計算矩陣a乘矩陣b,將結果存入c;p是第一個矩陣的行數,q是第二個矩陣的行數,r是第二個矩陣的列數 void mult(int a[MAXN][MAXN],int b[MAXN][MAXN],int c[MAXN][MAXN],int p,int q,int r) { int i,j,k; //先對c進行初始化 for(i=0;i<p;i++) { for(j=0;j<r;j++) { c[i][j] = 0; } } //計算矩陣乘法 for(i=0;i<p;i++) { for(j=0;j<r;j++) { for(k=0;k<q;k++) { c[i][j] += a[i][k] * b[k][j]; } } } }
這個演算法就是直接模擬矩陣乘法的定義,時間複雜度是O(n^3),同時也是Ω(n^3)。
接下來介紹優化演算法:
這個優化演算法的最差時間複雜度也是O(n^3),但是對於矩陣中零比較多的情況會有所改善。
基本思路是遍歷其中一個矩陣的所有元素,計算所有結果中用到這個元素的部分。如果這個元素是零,那麼就沒有必要計算了,略過去。這麼說可能不清楚,所以還是還是那個程式碼吧。
int mult(int a[MAXN][MAXN],int b[MAXN][MAXN],int c[MAXN][MAXN],int p,int q,int r) { int i,j,k; for(i=0;i<p;i++) { for(j=0;j<r;j++) { c[i][j] = 0; } } for(i=0;i<p;i++) { for(k=0;k<q;k++) { if(a[i][k]!=0) //如果該元素是零,就省去以下計算 { for(j=0;j<r;j++) { c[i][j] += a[i][k] * b[k][j]; } } } } }
比起其他最差時間複雜度有有效降低的演算法,這一優化演算法更便於實現,而且對於零比較多的矩陣會有很好的效果。