1. 程式人生 > >dp方法論——由矩陣相乘問題學習dp解題思路

dp方法論——由矩陣相乘問題學習dp解題思路

導語

刷過一些演算法題,就會十分珍惜“方法論”這種東西。Leetcode上只有題目、討論和答案,沒有方法論。往往答案看起來十分切中要害,但是從看題目到得到思路的那一段,就是繞不過去。樓主有段時間曾把這個過程歸結於智商和靈感的結合,直到有天為了搞懂Leetcode上一位老兄的題型總結,花兩天時間學習了回溯法,突然有種驚為天人的感覺——原來真正掌握一個演算法是應該觸類旁通的,而不是將題中一個細節換掉就又成了新題……

掌握方法論絕對是一種很爽的感覺。看起來好像很花費時間,其實是一種“因為慢,所以快”的方法。以前可能你學習一個dp題目要大半天;當你花了半個周時間,學會了dp的套路,你會發現,有些medium的dp題甚至不需要半個小時就能做完,而且從頭到尾不需提示,全靠自己!

方法論

那麼,怎麼從一個看起來毫無頭緒的問題出發,找到解題的思路並用dp將問題解出來呢?本文以矩陣相乘問題為例,給出dp問題的一般解題思路。

當然,按照思路解題的前提是你已經知道這道題要用dp去解,如何確定一個問題可以用dp去解,則是下一篇要討論的話題。

下面就是動態規劃的一般解題思路:

  1. 分析最優解的特徵。
  2. 遞迴地定義最優解的值。
  3. 計算最優解的值。
  4. 根據計算好的資訊構造最優解。

看起來非常抽象是吧?在這裡不需要完全理解。等你看完全文再回來,保你會有不一樣的感受。

矩陣相乘問題

問題

這是一個看起來可能有點抽象的數學問題,但請你耐心往下看。當你看完解法時,你會驚異於動態規劃的魔力。

題目:給出一個由n個矩陣組成的矩陣鏈<A1,A2,...,An>,矩陣Ai的秩為pi-1×pi。將A1A2...An這個乘積全括號化,使得計算這個乘積所需要的的標量乘法最少。

全括號化是以一種遞迴的形式定義的:

一個全括號化的乘積只有兩種可能:一是一個單個矩陣;二是兩個全括號化的乘積的乘積。

天啦也太繞了,舉個例子吧。對於矩陣鏈<A1,A2,A3,A4>的乘積,共有五種全括號化的方法:

(A1(A2(A3A4))),

(A1((A2A3)A4)),

((A1A2)(A3A4)),

(((A1A2)A3)A4),

((A1(A2A3))A4)

我們知道矩陣乘法是滿足結合律的,所以以上五個式子的乘積相等,但是它們的運算時間是否相等呢?

矩陣乘法的運算時間

我們知道,矩陣乘法的定義是:

兩個互相相容的矩陣A,B可以相乘。互相相容是指A的列數與B的行數相等。假如A是一個p×q的矩陣,而B是一個q×r的矩陣,則乘積C是一個p×r的矩陣且有

cij = ∑ aik·bkj, k = 1,...,q.

由於要對C中的每一個元素進行計算(共q·r個元素),而每次運算要做q次乘法,所以總的運算時間為pqr。

來看看讓乘積中的不同因子結合對運算時間有什麼影響。假設我們有 <A1,A2,A3>這個矩陣鏈,三個矩陣的秩分別為10×100, 100×5和5×50。則

  • ((A1A2)A3)的運算時間為10×100×5+10×5×50=7500;
  • (A1(A2A3))的運算時間為100×5×50+10×100×50=75000。

按照不同的順序做矩陣乘法,所需要的乘法次數竟相差10倍。

初步分析

按照慣例,我們來感受一下窮舉的演算法複雜度。

假設有一個長度為n的矩陣鏈,我們通過遍歷所有的全括號化的可能性來解題。設全括號化的可能性數目為P(n)。當n為1時,矩陣鏈只有一個矩陣,符合全括號化的定義;當n>=2時,全括號化後為兩個矩陣的乘積,即((...)(...))的形式。用遞迴的思路去分析,則中間兩個括號的分界位置有n-1種可能,如下面豎線所示

A1|A2|A3|...|An

當分界線將矩陣鏈分為長度為k和n-k的兩個子矩陣鏈時,全括號化可能性為P(k)P(n-k)。我們對所有的k值求和,就得出給整個矩陣鏈全括號化的數目:

P(n) = ∑ P(k)P(n-k), k=1...n-1   (n>=2)

這是一個,它的增長速率為Ω(4n/n3/2),它的漸進值為Ω(2n)

對漸進值還不太熟,如果有小夥伴明白“增長速率”和“漸進值”之間的關係,歡迎指教。

總的來說,如果對這個題目使用窮舉法,演算法複雜度是指數的。後面我們分析了dp的演算法複雜度,再來比較。

用dp方法論解題

演算法的學習永遠沒有“手把手”這一說。如果你在認真學習這篇文章,希望你能做到比你看到的小節思路提前一點。比如,在看第一步前,先對這個題目有一點大致思路,明白讓自己迷茫的點在哪裡;看第x步前,對第x步的內容在心中有一個猜測。這樣做比起完全放棄思考,只是跟著文章的思路走,收穫會大很多。

第一步:分析最優解的特徵

這一步的精髓是分析最優子解如何構成最優解

在上一節中已經提到,對於n>=2的情況,全括號化後為((chain_1)(chain_2))的形式。這樣,問題自然而然地分成了兩個子問題:求前後兩個子括號中的最優解。

假設對於某種特定的分割(即chain_1chain_2之間的分界線位置固定),chain_1的秩為m×p,其內部的標量乘法數目為x;chain_2的秩為p×n,其內部的標量乘法數目為y。則整個矩陣鏈的乘法次數為x+y+mpn。由於m,p,n是固定的,我們需要讓x和y為最小值從而使整個矩陣鏈的乘法次數最小。即,對於某種特定的分割,兩個子括號中的最優解構成整個問題的最優解的一個選項

總結來說,我們將矩陣乘積簡略地看成兩個子矩陣的乘積,這兩個子矩陣的分界有n-1種可能。對每一種可能,問題被分割成兩個子問題,即求左右兩個子矩陣鏈的最優解。如果遍歷這n-1種可能並選出最好的一個,那就是整個問題的最優解。

第二步:遞迴地定義最優解的值

第二步非常關鍵,是我們將前後思路打通的一步。

第一步中提出了一個比較簡單的思路,即把矩陣鏈分割成左右兩個子矩陣鏈。既然有了這個初步思路,我們就來塗鴉一番,看看這個思路是否可行。

對於遞迴性的問題,一個很好的方法是畫遞迴樹,這樣會使得問題看起來比較具象,而且也會暴露一些演算法上的問題,比如重疊子樹等。畫遞迴樹的時候,最好舉一個實際的例子。這裡我們假設有一個長度為4的矩陣鏈<A1,A2,A3,A4>,簡單地畫一下它的子問題分割:

 

上圖中的數字表示子矩陣鏈的長度,根為4,即初始矩陣鏈;它可以分為1+3,2+2,3+1三種情況,這三種情況又可以各自細分。

這裡暴露了一個問題,請看圖中的兩個塗色的子樹。兩個子樹的節點數字是一樣的。但是左邊這個子樹的根節點3代表的是A2A3A4這個乘積;而右邊這個代表的是A1A2A3這個乘積。由於A1,A2,A3,A4四個矩陣的秩是未知的,它們很可能不相同,則A1A2A3A2A3A4的最優解也很有可能不同。換言之,它們並不是同一個子問題,它們的子子樹也並不相同。

這個問題意味著我們對子問題的定義不夠嚴謹——子問題不能只用長度這個變數來確定。也就是說,如果在bottom-up的dp中用一個數組記錄子問題的值,那麼這個陣列應該是一個二維陣列。子問題不僅應該由子矩陣鏈的長度確定,還要加上起始index這樣的資訊。

為了更通用一些,我們不用起始index+長度,而選用起始index+結束index的定義方法,這是二維dp的慣用套路,在許多字串和陣列有關的問題中都有用到。

設用一個二位矩陣dp[][]存取子問題的解。定義dp[i][j](1<=i<=j<=n)的值為Ai...Aj的最小乘法次數。則按照以上的思路,可以把Ai...Aj再遞迴細分為子問題Ai...AkAk+1...Aj(i<=k<j),則Ai...Aj的最優解值為兩個子問題最優解的和+兩個子矩陣鏈相乘的乘法次數。即有

i==j時,dp[i][j] = 0;

i <j時,dp[i][j] = min{dp[i][k] + dp[k+1][j] + pi-1pkpj}, k = i...j-1 (p為各個矩陣的秩,見題目一節)

到此為止,最關鍵的一步順利完成啦(樓主寫得好累,擊掌╭(○`∀´○)╯╰(○'◡'○)╮)。在這一步中,我們遞迴地定義了子問題最優解的值,完成了演算法最核心的設計部分。在後面兩步中,我們只要把上面這兩個式子翻譯成程式碼,再注意一些實現細節就可以了。

第三步:計算最優解的值

細節一

從第二步順理成章,我們會在一個二維數組裡記錄子問題的解。但是按照什麼順序去填這個二維陣列是個問題。

還是舉例子,在<A1,A2,A3,A4>這個矩陣鏈中,我們會有一個5×5的二維陣列,隨便挑選dp[1][4]這個元素舉例。根據第二步中的狀態轉移方程,有

dp[1][4] = min{(dp[1][1]+dp[2][4]+...),(dp[1][2]+dp[3][4]+...),(dp[1][3]+dp[4][4]+...)}

省略號表示我們此處不需關注pi-1pkpj這一項,只需要看這個格子對其它格子的依賴是什麼樣子。

由上圖可以看出,要計算某一個元素(粉色邊框),我們需要其左邊下面的元素(同樣深度的藍色表示一組資料)。

所以,我們的遍歷方向是從下到上,從左到右

細節二

細心的讀者可能注意到還有一個問題,就是我們一直在求“最優解的值”,也就是“最小的乘法次數”,可是題目中要求的是“最優解”,也就是“加括號的方式”。

這兩者並不矛盾,專注於求解前者可以讓我們先思考相對簡單的問題,通常在求解前者的過程中,我們也找出了後者,只是沒有將它記錄下來。

在此題中,我們可以選擇用一個同樣的二維矩陣s[][]來記錄後者,其中s[i][j]中記錄Ai...Aj的分割分界線k。

程式碼

 1     int matrixChain(int[] p){
 2         int n = p.length - 1; //number of matrices
 3         int[][] dp = new int[n + 1][n + 1]; //we need dp[1][n]
 4         int[][] s = new int[n + 1][n + 1];    //for storing of k
 5         for(int[] row : dp)
 6             Arrays.fill(row, Integer.MAX_VALUE);
 7 
 8         for(int i = 1; i <= n; i++)
 9             dp[i][i] = 0;    //dp[i][j] = 0 when i == j
10         
11         for(int i = n; i >= 1; i--)
12             for(int j = i; j <= n; j++){
13                 if(i == j){
14                     dp[i][j] = 0;
15                 }else{
16                     for(int k = i; k < j; k++){
17                         int count = dp[i][k] + dp[k+1][j] + p[i-1]*p[k]*p[j];
18                         if(count < dp[i][j]){
19                             dp[i][j] = count; //record optimal solution value
20                             s[i][j] = k;      //record splitting point k
21                         }
22                     }
23                 }
24             }
25         return dp[1][n];
26     }

執行一個例子:

即輸入的陣列p為{30,35,15,5,10,20,25}。

如果在return之前打印出dp[][]和s[][]的值,結果為:

      

從左圖可看出最優解為dp[1][6] = 15,125,即最少可以進行一萬五千多次乘法。右圖記錄了對於每一個[i,j]決定的子矩陣鏈如何進行括號分割。

順便分享一個ArrayPrinter的util,可以直接用,能打印出上圖那樣的二維int陣列。

 1 public class ArrayPrinter {
 2     public static void print(int[] arr){
 3         printReplacing(false, arr, 0,"");
 4     }
 5     
 6     public static void print(int[][] matrix){
 7         printReplacing(false, matrix, 0,"");
 8     }
 9     
10     public static void printReplacing(int[] arr, int before, String after){
11         printReplacing(true, arr, before, after);
12     }
13     
14     public static void printReplacing(int[][] matrix, int before, String after){
15         printReplacing(true, matrix, before, after);
16     }
17     
18     /*--------------------------private utils-------------------------------*/
19     
20     private static void printReplacing(boolean replace, int[] arr, int before, String after){
21         int maxLen = maxLength(arr);
22         if(replace){
23             for(int i : arr)
24                 print(((i==before)?after:number(i)), maxLen);
25         }else{
26             for(int i : arr)
27                 print(number(i), maxLen);
28         }
29         print("\n", maxLen);
30     }
31     
32     public static void printReplacing(boolean replace, int[][] matrix, int before, String after){
33         int maxLen = maxLength(matrix);
34         if(replace){
35             for(int[] row : matrix){
36                 for(int i : row)
37                     print(((i==before)?after:number(i)), maxLen);
38                 print("\n", maxLen);
39             }
40         }else{
41             for(int[] row : matrix){
42                 for(int i : row)
43                     print(number(i), maxLen);
44                 print("\n", maxLen);
45             }
46         }
47     }
48 
49     private static int maxLength(int[] arr){
50         int maxLen = 0;
51         for(int aint : arr)
52             maxLen = Math.max(Integer.toString(aint).length(), maxLen);
53         return maxLen;
54     }
55     
56     private static int maxLength(int[][] matrix){
57         int maxLen = 0;
58         for(int row[] : matrix)
59             maxLen = Math.max(maxLength(row), maxLen);
60         return maxLen;
61     }
62     
63     //actual printing 
64     private static void print(String s, int length){
65         System.out.print(String.format("%1$"+(length+1)+"s", s));
66     }
67     
68     //formatting of number
69     private static String number(int i){
70         return NumberFormat.getNumberInstance(Locale.US).format(i);
71     } 
72 }
ArrayPrinter

使用方法:

1 ArrayPrinter.printReplacing(dp, Integer.MAX_VALUE, "/");
2 ArrayPrinter.print(s);

第四步:根據計算好的資訊構造最優解

還差一步就大功告成。這一步我們要拿著上一步計算出的矩陣s把最終的全括號矩陣乘積打印出來。遞迴列印即可。

 1     private void printParenthesis(int[][] s, int i, int j) {
 2         if(i == j)
 3             print("A"+i);
 4         else{
 5             print("(");
 6             printParenthesis(s, i, s[i][j]);
 7             printParenthesis(s, s[i][j]+1, j);
 8             print(")");
 9         }
10     }

列印結果:

複雜度

前面說過,窮舉法的複雜度大概是O(2n)。在以上的dp演算法中,主演算法需要填滿一個(n+1)×(n+1)的二維陣列的上半部分,每填一個元素需要一個長度為j-i的迴圈,可通過這個思路對j-i進行求和(i=0...n, j=i...n),也可以通過大概估算得到時間複雜度為O(n3),遠好於窮舉法。

空間複雜度主要由二維陣列決定,為O(n2)。

總結

本文主要介紹瞭解一個dp問題的思路。

dp問題一般有兩個顯著特點,這一點下一篇會詳細講述:

  • 問題的最優解由子問題的最優解構成
  • 子問題互相重疊

也再複習一下解題的四個步驟,看你現在有沒有更深刻的理解:

  1. 分析最優解的特徵。               (分析最優子解如何構成最優解)
  2. 遞迴地定義最優解的值。               (畫遞迴樹,定義子問題,寫狀態轉移方程)
  3. 計算最優解的值。                        (寫程式碼求出最優解,如果有要求的話,記錄額外資訊,為第4步作準備)
  4. 根據計算好的資訊構造最優解。       (從第3步記錄的資訊中構建最優解,在本題中就是括號的寫法)

參考資料