dp方法論——由矩陣相乘問題學習dp解題思路
導語
刷過一些演算法題,就會十分珍惜“方法論”這種東西。Leetcode上只有題目、討論和答案,沒有方法論。往往答案看起來十分切中要害,但是從看題目到得到思路的那一段,就是繞不過去。樓主有段時間曾把這個過程歸結於智商和靈感的結合,直到有天為了搞懂Leetcode上一位老兄的題型總結,花兩天時間學習了回溯法,突然有種驚為天人的感覺——原來真正掌握一個演算法是應該觸類旁通的,而不是將題中一個細節換掉就又成了新題……
掌握方法論絕對是一種很爽的感覺。看起來好像很花費時間,其實是一種“因為慢,所以快”的方法。以前可能你學習一個dp題目要大半天;當你花了半個周時間,學會了dp的套路,你會發現,有些medium的dp題甚至不需要半個小時就能做完,而且從頭到尾不需提示,全靠自己!
方法論
那麼,怎麼從一個看起來毫無頭緒的問題出發,找到解題的思路並用dp將問題解出來呢?本文以矩陣相乘問題為例,給出dp問題的一般解題思路。
當然,按照思路解題的前提是你已經知道這道題要用dp去解,如何確定一個問題可以用dp去解,則是下一篇要討論的話題。
下面就是動態規劃的一般解題思路:
- 分析最優解的特徵。
- 遞迴地定義最優解的值。
- 計算最優解的值。
- 根據計算好的資訊構造最優解。
看起來非常抽象是吧?在這裡不需要完全理解。等你看完全文再回來,保你會有不一樣的感受。
矩陣相乘問題
問題
這是一個看起來可能有點抽象的數學問題,但請你耐心往下看。當你看完解法時,你會驚異於動態規劃的魔力。
題目:給出一個由n個矩陣組成的矩陣鏈<A1,A2,...,An>,矩陣Ai的秩為pi-1×pi。將A1A2...An這個乘積全括號化,使得計算這個乘積所需要的的標量乘法最少。
全括號化是以一種遞迴的形式定義的:
一個全括號化的乘積只有兩種可能:一是一個單個矩陣;二是兩個全括號化的乘積的乘積。
天啦也太繞了,舉個例子吧。對於矩陣鏈<A1,A2,A3,A4>的乘積,共有五種全括號化的方法:
(A1(A2(A3A4))),
(A1((A2A3)A4)),
((A1A2)(A3A4)),
(((A1A2)A3)A4),
((A1(A2A3))A4)
我們知道矩陣乘法是滿足結合律的,所以以上五個式子的乘積相等,但是它們的運算時間是否相等呢?
矩陣乘法的運算時間
我們知道,矩陣乘法的定義是:
兩個互相相容的矩陣A,B可以相乘。互相相容是指A的列數與B的行數相等。假如A是一個p×q的矩陣,而B是一個q×r的矩陣,則乘積C是一個p×r的矩陣且有
cij = ∑ aik·bkj, k = 1,...,q.
由於要對C中的每一個元素進行計算(共q·r個元素),而每次運算要做q次乘法,所以總的運算時間為pqr。
來看看讓乘積中的不同因子結合對運算時間有什麼影響。假設我們有 <A1,A2,A3>這個矩陣鏈,三個矩陣的秩分別為10×100, 100×5和5×50。則
- ((A1A2)A3)的運算時間為10×100×5+10×5×50=7500;
- (A1(A2A3))的運算時間為100×5×50+10×100×50=75000。
按照不同的順序做矩陣乘法,所需要的乘法次數竟相差10倍。
初步分析
按照慣例,我們來感受一下窮舉的演算法複雜度。
假設有一個長度為n的矩陣鏈,我們通過遍歷所有的全括號化的可能性來解題。設全括號化的可能性數目為P(n)。當n為1時,矩陣鏈只有一個矩陣,符合全括號化的定義;當n>=2時,全括號化後為兩個矩陣的乘積,即((...)(...))的形式。用遞迴的思路去分析,則中間兩個括號的分界位置有n-1種可能,如下面豎線所示
A1|A2|A3|...|An
當分界線將矩陣鏈分為長度為k和n-k的兩個子矩陣鏈時,全括號化可能性為P(k)P(n-k)。我們對所有的k值求和,就得出給整個矩陣鏈全括號化的數目:
P(n) = ∑ P(k)P(n-k), k=1...n-1 (n>=2)
這是一個,它的增長速率為Ω(4n/n3/2),它的漸進值為Ω(2n)。
(對漸進值還不太熟,如果有小夥伴明白“增長速率”和“漸進值”之間的關係,歡迎指教。)
總的來說,如果對這個題目使用窮舉法,演算法複雜度是指數的。後面我們分析了dp的演算法複雜度,再來比較。
用dp方法論解題
演算法的學習永遠沒有“手把手”這一說。如果你在認真學習這篇文章,希望你能做到比你看到的小節思路提前一點。比如,在看第一步前,先對這個題目有一點大致思路,明白讓自己迷茫的點在哪裡;看第x步前,對第x步的內容在心中有一個猜測。這樣做比起完全放棄思考,只是跟著文章的思路走,收穫會大很多。
第一步:分析最優解的特徵
這一步的精髓是分析最優子解如何構成最優解。
在上一節中已經提到,對於n>=2的情況,全括號化後為((chain_1)(chain_2))的形式。這樣,問題自然而然地分成了兩個子問題:求前後兩個子括號中的最優解。
假設對於某種特定的分割(即chain_1和chain_2之間的分界線位置固定),chain_1的秩為m×p,其內部的標量乘法數目為x;chain_2的秩為p×n,其內部的標量乘法數目為y。則整個矩陣鏈的乘法次數為x+y+mpn。由於m,p,n是固定的,我們需要讓x和y為最小值從而使整個矩陣鏈的乘法次數最小。即,對於某種特定的分割,兩個子括號中的最優解構成整個問題的最優解的一個選項。
總結來說,我們將矩陣乘積簡略地看成兩個子矩陣鏈的乘積,這兩個子矩陣鏈的分界有n-1種可能。對每一種可能,問題被分割成兩個子問題,即求左右兩個子矩陣鏈的最優解。如果遍歷這n-1種可能並選出最好的一個,那就是整個問題的最優解。
第二步:遞迴地定義最優解的值
第二步非常關鍵,是我們將前後思路打通的一步。
第一步中提出了一個比較簡單的思路,即把矩陣鏈分割成左右兩個子矩陣鏈。既然有了這個初步思路,我們就來塗鴉一番,看看這個思路是否可行。
對於遞迴性的問題,一個很好的方法是畫遞迴樹,這樣會使得問題看起來比較具象,而且也會暴露一些演算法上的問題,比如重疊子樹等。畫遞迴樹的時候,最好舉一個實際的例子。這裡我們假設有一個長度為4的矩陣鏈<A1,A2,A3,A4>,簡單地畫一下它的子問題分割:
上圖中的數字表示子矩陣鏈的長度,根為4,即初始矩陣鏈;它可以分為1+3,2+2,3+1三種情況,這三種情況又可以各自細分。
這裡暴露了一個問題,請看圖中的兩個塗色的子樹。兩個子樹的節點數字是一樣的。但是左邊這個子樹的根節點3代表的是A2A3A4這個乘積;而右邊這個代表的是A1A2A3這個乘積。由於A1,A2,A3,A4四個矩陣的秩是未知的,它們很可能不相同,則A1A2A3和A2A3A4的最優解也很有可能不同。換言之,它們並不是同一個子問題,它們的子子樹也並不相同。
這個問題意味著我們對子問題的定義不夠嚴謹——子問題不能只用長度這個變數來確定。也就是說,如果在bottom-up的dp中用一個數組記錄子問題的值,那麼這個陣列應該是一個二維陣列。子問題不僅應該由子矩陣鏈的長度確定,還要加上起始index這樣的資訊。
為了更通用一些,我們不用起始index+長度,而選用起始index+結束index的定義方法,這是二維dp的慣用套路,在許多字串和陣列有關的問題中都有用到。
設用一個二位矩陣dp[][]存取子問題的解。定義dp[i][j](1<=i<=j<=n)的值為Ai...Aj的最小乘法次數。則按照以上的思路,可以把Ai...Aj再遞迴細分為子問題Ai...Ak和Ak+1...Aj(i<=k<j),則Ai...Aj的最優解值為兩個子問題最優解的和+兩個子矩陣鏈相乘的乘法次數。即有
i==j時,dp[i][j] = 0;
i <j時,dp[i][j] = min{dp[i][k] + dp[k+1][j] + pi-1pkpj}, k = i...j-1 (p為各個矩陣的秩,見題目一節)
到此為止,最關鍵的一步順利完成啦(樓主寫得好累,擊掌╭(○`∀´○)╯╰(○'◡'○)╮)。在這一步中,我們遞迴地定義了子問題最優解的值,完成了演算法最核心的設計部分。在後面兩步中,我們只要把上面這兩個式子翻譯成程式碼,再注意一些實現細節就可以了。
第三步:計算最優解的值
細節一
從第二步順理成章,我們會在一個二維數組裡記錄子問題的解。但是按照什麼順序去填這個二維陣列是個問題。
還是舉例子,在<A1,A2,A3,A4>這個矩陣鏈中,我們會有一個5×5的二維陣列,隨便挑選dp[1][4]這個元素舉例。根據第二步中的狀態轉移方程,有
dp[1][4] = min{(dp[1][1]+dp[2][4]+...),(dp[1][2]+dp[3][4]+...),(dp[1][3]+dp[4][4]+...)}
省略號表示我們此處不需關注pi-1pkpj這一項,只需要看這個格子對其它格子的依賴是什麼樣子。
由上圖可以看出,要計算某一個元素(粉色邊框),我們需要其左邊和下面的元素(同樣深度的藍色表示一組資料)。
所以,我們的遍歷方向是從下到上,從左到右。
細節二
細心的讀者可能注意到還有一個問題,就是我們一直在求“最優解的值”,也就是“最小的乘法次數”,可是題目中要求的是“最優解”,也就是“加括號的方式”。
這兩者並不矛盾,專注於求解前者可以讓我們先思考相對簡單的問題,通常在求解前者的過程中,我們也找出了後者,只是沒有將它記錄下來。
在此題中,我們可以選擇用一個同樣的二維矩陣s[][]來記錄後者,其中s[i][j]中記錄Ai...Aj的分割分界線k。
程式碼
1 int matrixChain(int[] p){ 2 int n = p.length - 1; //number of matrices 3 int[][] dp = new int[n + 1][n + 1]; //we need dp[1][n] 4 int[][] s = new int[n + 1][n + 1]; //for storing of k 5 for(int[] row : dp) 6 Arrays.fill(row, Integer.MAX_VALUE); 7 8 for(int i = 1; i <= n; i++) 9 dp[i][i] = 0; //dp[i][j] = 0 when i == j 10 11 for(int i = n; i >= 1; i--) 12 for(int j = i; j <= n; j++){ 13 if(i == j){ 14 dp[i][j] = 0; 15 }else{ 16 for(int k = i; k < j; k++){ 17 int count = dp[i][k] + dp[k+1][j] + p[i-1]*p[k]*p[j]; 18 if(count < dp[i][j]){ 19 dp[i][j] = count; //record optimal solution value 20 s[i][j] = k; //record splitting point k 21 } 22 } 23 } 24 } 25 return dp[1][n]; 26 }
執行一個例子:
即輸入的陣列p為{30,35,15,5,10,20,25}。
如果在return之前打印出dp[][]和s[][]的值,結果為:
從左圖可看出最優解為dp[1][6] = 15,125,即最少可以進行一萬五千多次乘法。右圖記錄了對於每一個[i,j]決定的子矩陣鏈如何進行括號分割。
順便分享一個ArrayPrinter的util,可以直接用,能打印出上圖那樣的二維int陣列。
1 public class ArrayPrinter { 2 public static void print(int[] arr){ 3 printReplacing(false, arr, 0,""); 4 } 5 6 public static void print(int[][] matrix){ 7 printReplacing(false, matrix, 0,""); 8 } 9 10 public static void printReplacing(int[] arr, int before, String after){ 11 printReplacing(true, arr, before, after); 12 } 13 14 public static void printReplacing(int[][] matrix, int before, String after){ 15 printReplacing(true, matrix, before, after); 16 } 17 18 /*--------------------------private utils-------------------------------*/ 19 20 private static void printReplacing(boolean replace, int[] arr, int before, String after){ 21 int maxLen = maxLength(arr); 22 if(replace){ 23 for(int i : arr) 24 print(((i==before)?after:number(i)), maxLen); 25 }else{ 26 for(int i : arr) 27 print(number(i), maxLen); 28 } 29 print("\n", maxLen); 30 } 31 32 public static void printReplacing(boolean replace, int[][] matrix, int before, String after){ 33 int maxLen = maxLength(matrix); 34 if(replace){ 35 for(int[] row : matrix){ 36 for(int i : row) 37 print(((i==before)?after:number(i)), maxLen); 38 print("\n", maxLen); 39 } 40 }else{ 41 for(int[] row : matrix){ 42 for(int i : row) 43 print(number(i), maxLen); 44 print("\n", maxLen); 45 } 46 } 47 } 48 49 private static int maxLength(int[] arr){ 50 int maxLen = 0; 51 for(int aint : arr) 52 maxLen = Math.max(Integer.toString(aint).length(), maxLen); 53 return maxLen; 54 } 55 56 private static int maxLength(int[][] matrix){ 57 int maxLen = 0; 58 for(int row[] : matrix) 59 maxLen = Math.max(maxLength(row), maxLen); 60 return maxLen; 61 } 62 63 //actual printing 64 private static void print(String s, int length){ 65 System.out.print(String.format("%1$"+(length+1)+"s", s)); 66 } 67 68 //formatting of number 69 private static String number(int i){ 70 return NumberFormat.getNumberInstance(Locale.US).format(i); 71 } 72 }ArrayPrinter
使用方法:
1 ArrayPrinter.printReplacing(dp, Integer.MAX_VALUE, "/"); 2 ArrayPrinter.print(s);
第四步:根據計算好的資訊構造最優解
還差一步就大功告成。這一步我們要拿著上一步計算出的矩陣s把最終的全括號矩陣乘積打印出來。遞迴列印即可。
1 private void printParenthesis(int[][] s, int i, int j) { 2 if(i == j) 3 print("A"+i); 4 else{ 5 print("("); 6 printParenthesis(s, i, s[i][j]); 7 printParenthesis(s, s[i][j]+1, j); 8 print(")"); 9 } 10 }
列印結果:
複雜度
前面說過,窮舉法的複雜度大概是O(2n)。在以上的dp演算法中,主演算法需要填滿一個(n+1)×(n+1)的二維陣列的上半部分,每填一個元素需要一個長度為j-i的迴圈,可通過這個思路對j-i進行求和(i=0...n, j=i...n),也可以通過大概估算得到時間複雜度為O(n3),遠好於窮舉法。
空間複雜度主要由二維陣列決定,為O(n2)。
總結
本文主要介紹瞭解一個dp問題的思路。
dp問題一般有兩個顯著特點,這一點下一篇會詳細講述:
- 問題的最優解由子問題的最優解構成
- 子問題互相重疊
也再複習一下解題的四個步驟,看你現在有沒有更深刻的理解:
- 分析最優解的特徵。 (分析最優子解如何構成最優解)
- 遞迴地定義最優解的值。 (畫遞迴樹,定義子問題,寫狀態轉移方程)
- 計算最優解的值。 (寫程式碼求出最優解,如果有要求的話,記錄額外資訊,為第4步作準備)
- 根據計算好的資訊構造最優解。 (從第3步記錄的資訊中構建最優解,在本題中就是括號的寫法)