洛谷 P5279 - [ZJOI2019]麻將(dp 套 dp)
一道 dp 套 dp 的 immortal tea
首先考慮如何判斷一套牌是否已經胡牌了,考慮 \(dp\)。我們考慮將所有牌按權值大小從大到小排成一列,那我們設 \(dp_{i,j,k,0/1}\) 表示目前考慮了權值 \(\le i\) 的牌,我們之前預留了 \(j\) 張形如 \((i-1,i)\) 的牌與 \(i+1\) 形成刻子,又留了 \(k\) 張 \(i\) 與 \(i+1,i+2\) 形成刻子,\(0/1\) 表示當前是否留過對子,所能夠形成的最大刻子數。設 \(c_i\) 表示有多少張權值為 \(i\)
接下來考慮解決原問題,不難發現如果我們將所有權值從小到大讀入,並將當前的 \(dp\) 值看作一個狀態(即,把當前 \((dp_{*,0,0,0},dp_{*,0,0,1},dp_{*,1,0,0},dp_{*,1,1,1}\cdots,dp_{*,4,4,0},dp_{*,4,4,1},cnt)\) 這樣的有序對視作一個個節點建一張圖,其中由於七刻子胡牌這一規則的存在,我們需額外記錄一個 \(cnt\) 表示當前有多少個權值 \(i\) 滿足權值等於 \(i\) 的牌的張數 \(\ge 2\))那麼如果我們新添上 \(x\) 張權值為 \(i+1\)
最後考慮怎樣求解答案,注意到雖然每個 \(dp\) 值可能的情況可能很多,差不多是指數級別的,可如果我們真正寫個程式 BFS 一下就可以發現能夠從 \((0,0,\cdots,0)\) 到達的狀態並不多——準確來說,是 \(2092\) 個,再聯絡此題 \(n\) 很小這個資料範圍我們可以想到 DP,具體來說,對於一個摸了 \(x\) 張牌後才胡牌的情形,我們可以將它的貢獻拆成,摸了 \(1,2,\cdots,x-1\) 張牌後還未胡牌的概率。因此我們即需求出摸了 \(i\) 張牌後未胡牌的概率,將它們加起來再加 \(1\) 就是答案。這樣我們可以設 \(dp_{i,j,k}\) 表示有多少個摸牌的集合,滿足其中所有牌的權值 \(\le i\),將它們全部讀入自動機後當前位於自動機上 \(j\) 節點的位置,並且 \(k\) 張牌。轉移就列舉新摸了多少張權值為 \(i+1\) 的牌——設摸了 \(x\) 張權值為 \(i+1\) 的牌,最初的牌堆中有 \(a_{i+1}\) 張權值為 \(i+1\) 的牌,那麼需有 \(x\ge a_{i+1}\)。沿著自動機對應的邊轉移即可。那麼最終對於一個 \(dp_{n,j,k}\),其中 \(j\) 不是終止節點,其對答案的貢獻就是 \(dp_{n,j,k}·\dfrac{k!(4n-13-k)!}{(4n-13)!}\)。
時間複雜度 \(2092·n^2\),可以通過此題。
const int MAXN=400;
const int MAXM=2222;
const int MOD=998244353;
struct mat{
int a[3][3];
void clear(){memset(a,-1,sizeof(a));}
int* operator [](int x){return a[x];}
mat(){clear();}
bool operator ==(const mat &rhs) const{
for(int i=0;i<3;i++) for(int j=0;j<3;j++)
if(a[i][j]^rhs.a[i][j]) return 0;
return 1;
}
bool operator !=(const mat &rhs) const{return !((*this)==rhs);}
bool operator <(const mat &rhs) const{
for(int i=0;i<3;i++) for(int j=0;j<3;j++){
if(a[i][j]<rhs.a[i][j]) return 0;
if(a[i][j]>rhs.a[i][j]) return 1;
} return 0;
}
};
void get_trs(mat &to,mat from,int cnt){
for(int i=0;i<3;i++) for(int j=0;j<3;j++) if(~from[i][j])
for(int k=0;k<3&&i+j+k<=cnt;k++)
chkmax(to[j][k],min(4,from[i][j]+i+(cnt-i-j-k)/3));
}
struct node{
int prs_cnt;mat dp[2];
void clear(){prs_cnt=0;dp[0].clear();dp[1].clear();}
node(){clear();}
void make_zero(){clear();dp[0][0][0]=0;}
void make_win(){prs_cnt=-1;dp[0].clear();dp[1].clear();}
bool check_win(){
if(prs_cnt==7) return 1;
for(int i=0;i<3;i++) for(int j=0;j<3;j++) if(dp[1][i][j]==4)
return 1;
return 0;
}
void check(){if(check_win()) make_win();}
bool operator ==(const node &rhs) const{
return (prs_cnt==rhs.prs_cnt&&dp[0]==rhs.dp[0]&&dp[1]==rhs.dp[1]);
}
bool operator !=(const node &rhs) const{return !((*this)==rhs);}
bool operator <(const node &rhs) const{
if(prs_cnt^rhs.prs_cnt) return prs_cnt<rhs.prs_cnt;
if(dp[0]!=rhs.dp[0]) return dp[0]<rhs.dp[0];
if(dp[1]!=rhs.dp[1]) return dp[1]<rhs.dp[1];
return 0;
}
};
node getnxt(node a,int cnt){
if(!~a.prs_cnt) return a;
node res;
if(cnt>=2) res.prs_cnt=a.prs_cnt+1;
else res.prs_cnt=a.prs_cnt;
if(cnt>=2) get_trs(res.dp[1],a.dp[0],cnt-2);
get_trs(res.dp[0],a.dp[0],cnt);
get_trs(res.dp[1],a.dp[1],cnt);
res.check();return res;
}
map<node,int> id;
int ncnt=0,ch[MAXM+5][5],ed=0;
void find_state(){
queue<node> q;node st;st.make_zero();
id[st]=++ncnt;q.push(st);
while(!q.empty()){
node x=q.front();q.pop();
for(int i=0;i<=4;i++){
node y=getnxt(x,i);
if(!id[y]) id[y]=++ncnt,q.push(y);
ch[id[x]][i]=id[y];
}
} st.make_win();ed=id[st];
}
int fac[MAXN+5],ifac[MAXN+5];
void init_fac(int n){
for(int i=(fac[0]=ifac[0]=ifac[1]=1)+1;i<=n;i++) ifac[i]=1ll*ifac[MOD%i]*(MOD-MOD/i)%MOD;
for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%MOD,ifac[i]=1ll*ifac[i]*ifac[i-1]%MOD;
}
int binom(int x,int y){return 1ll*fac[x]*ifac[y]%MOD*ifac[x-y]%MOD;}
int n,cnt[MAXN+5],dp[MAXN/4+5][MAXM+5][MAXN+5];
int main(){
init_fac(MAXN);find_state();//printf("%d\n",ncnt);
scanf("%d",&n);
for(int i=1,x;i<=13;i++) scanf("%d%*d",&x),cnt[x]++;
dp[0][1][0]=1;
for(int i=0;i<n;i++){
for(int j=1;j<=ncnt;j++) for(int k=0;k<=n<<2;k++){
if(dp[i][j][k]){
for(int l=cnt[i+1];l<=4;l++) if(ch[j][l]!=ed){
dp[i+1][ch[j][l]][k+l]=(dp[i+1][ch[j][l]][k+l]+1ll*binom(4-cnt[i+1],l-cnt[i+1])*dp[i][j][k])%MOD;
}
}
}
} int tot=(n<<2)-13,res=0;
for(int i=1;i<=ncnt;i++) for(int j=0;j<=tot;j++) if(dp[n][i][j+13])
res=(res+1ll*dp[n][i][j+13]*fac[j]%MOD*fac[tot-j]%MOD*ifac[tot])%MOD;
printf("%d\n",res);
return 0;
}