1. 程式人生 > 其它 >dfs解矩陣問題+記憶化搜尋

dfs解矩陣問題+記憶化搜尋

地宮取寶

https://www.acwing.com/problem/content/1214/

1.純dfs,超時得一半分

#include<bits/stdc++.h>
using namespace std;
const int MOD=1e9+7;

int q[50][50];
int n,m,k;
long long int ans;
//param: x,y express position
//max express maxvalue in mybag
//u express numbers i took
void dfs(int x,int y,int max,int u){
    //base case;
    if(x==n || y==m)    return ; //fail
    int cur=q[x][y];
    if(x==n-1 && y==m-1){
        if(u==k)    ans++;     //success and take q[n-1][m-1]
        if(u==k-1 && cur>max) ans++; //success and dont take q[n-1][m-1]
    }
    //take
    if(cur>max){
        dfs(x+1,y,cur,u+1);
        dfs(x,y+1,cur,u+1);
    }
    //no take
    dfs(x+1,y,max,u);
    dfs(x,y+1,max,u);
}
int main(){
    cin>>n>>m>>k;
    for(int i=0;i<n;i++){
        for(int j=0;j<m;j++){
            scanf("%d",&q[i][j]);
        }
    }
    dfs(0,0,-1,0);
    cout<<ans%MOD;
    return 0;
}

2.記憶化搜尋

我們發現每個格子都會存在重複引數的遍歷情況,於是我們採用記憶化搜尋來降低時間複雜度。

只需開一個多維陣列cache(維度由dfs的引數數量決定),裡面儲存以此相同引數的dfs的結果(通常是題目所求),只需在原dfs的程式碼上修改開頭結尾,以及返回值根據題設進行修改;

修改開頭:通常增設一個base case,先查詢此引數的dfs是否存在cache中,存在則直接return cache

修改結尾:每次dfs結尾必須給cache賦值以表示存入此狀態,通常是題設所求的返回值引數

#include<bits/stdc++.h>
using namespace std;
const int MOD=1e9+7;

int q[51][51];
int n,m,k;
long long cache[51][51][14][13];
//param: x,y express position
//max express maxvalue in mybag
//u express numbers i took
long long dfs(int x,int y,int max,int u){
    if(cache[x][y][max+1][u]!=-1) 
        return cache[x][y][max+1][u]; //memory search
    long long  ans=0;
    //base case;
    if(x==n || y==m ||u>k )    return 0; //fail
    int cur=q[x][y];
    if(x==n-1 && y==m-1){
        if(u==k)    ans++;     //success and take q[n-1][m-1]
        if(u==k-1 && cur>max) ans++; //success and dont take q[n-1][m-1]
        ans%=MOD;
        return ans;
    }
    //take
    if(cur>max){
       ans+= dfs(x+1,y,cur,u+1);
       ans+= dfs(x,y+1,cur,u+1);
    }
    //no take
    ans+= dfs(x+1,y,max,u);
    ans+= dfs(x,y+1,max,u);
    cache[x][y][max+1][u]=ans%MOD;
    return cache[x][y][max+1][u];
}


int main(){
    cin>>n>>m>>k;
    for(int i=0;i<n;i++){
        for(int j=0;j<m;j++){
            scanf("%d",&q[i][j]);
        }
    }
    memset(cache,-1,sizeof(cache));
    printf("%lld",dfs(0,0,-1,0));
    

    return 0;
}