bzoj3864 Hero meet devil(dp套dp)
阿新 • • 發佈:2019-01-08
題面
題目大意:
給出一個模式串\(S(|S|≤15)\) 問存在多少個長為\(m(m≤1000)\) 的字串T滿足\(LCS(S,T)=x(0≤x≤|S|)\) 輸出\(|S|+1\)個結果\((mod 1e9+7)\) (\(|S|\)表示字串S的長度,字符集為\(A,T,C,G\)四個字母)
題解
樸素\(lcs\)的\(dp\)
for (int i = 1; i <= n; i++) for (int j = 1; j <= m; j++) if (a[i] == b[j]) f[i][j] = f[i-1][j-1]+1; else f[i][j] = max(f[i-1][j], f[i][j-1], f[i-1][j-1]);
我們能發現
\(f[i][j]\)和\(f[i][j-1]\),\(f[i][j+1]\)最多相差\(1\)
\(|S| ≤ 15\)
我們可以把\(j\)那一維的差分陣列狀壓一下
然後呢?
設\(f[i][S]\)表示在第\(i\)個位置,此時\(lcs\)的狀態為\(S\)的方案數
預處理出 \(nxt[S][A/C/G/T]\) 為\(S\)狀態下,新增\(A/C/G/T\)後分別的狀態
然後就有
\(f[i+1][nxt[s][k]] += f[i][s]\)
至於預處理,我們把狀壓還原出來
模擬樸素\(dp\)一遍,再壓回去
Code
#include<bits/stdc++.h> #define LL long long #define RG register using namespace std; template<class T> inline void read(T &x) { x = 0; RG char c = getchar(); bool f = 0; while (c != '-' && (c < '0' || c > '9')) c = getchar(); if (c == '-') c = getchar(), f = 1; while (c >= '0' && c <= '9') x = x*10+c-48, c = getchar(); x = f ? -x : x; return ; } template<class T> inline void write(T x) { if (!x) {putchar(48);return ;} if (x < 0) x = -x, putchar('-'); int len = -1, z[20]; while (x > 0) z[++len] = x%10, x /= 10; for (RG int i = len; i >= 0; i--) putchar(z[i]+48);return ; } const int N = 1001, Mod = 1e9+7; char S[16], SS[5] = {"ACGT"}; int a[16], f[N][(1<<15)+1], nxt[(1<<15)+1][5], n, len, limit, ans[16]; int tmp[2][16]; int solve(int s, int ch) { int ret = 0; memset(tmp, 0, sizeof(tmp)); for (int i = 0; i < n; i++) tmp[0][i+1] = tmp[0][i]+((s>>i)&1); for (int i = 1; i <= n; i++) { int mx = 0; if (a[i] == ch) mx = tmp[0][i-1]+1; mx = max(max(mx, tmp[0][i]), tmp[1][i-1]); tmp[1][i] = mx; } for (int i = 0; i < n; i++) ret += (1<<i)*(tmp[1][i+1]-tmp[1][i]); return ret; } int main() { //freopen(".in", "r", stdin); //freopen(".out", "w", stdout); int q; read(q); while (q--) { memset(f, 0, sizeof(f)); memset(ans, 0, sizeof(ans)); scanf("%s", S+1); n = strlen(S+1); limit = 1<<n; for (int i = 1; i <= n; i++) for (int j = 0; j < 4; j++) if (S[i] == SS[j]) {a[i] = j+1; break;} read(len); for (int s = 0; s < limit; s++) for (int j = 1; j <= 4; j++) nxt[s][j] = solve(s, j); f[0][0] = 1; for (int i = 0; i < len; i++) for (int s = 0; s < limit; s++) for (int k = 1; k <= 4; k++) (f[i+1][nxt[s][k]] += f[i][s]) %= Mod; for (int s = 0; s < limit; s++) { int cnt = __builtin_popcount(s); (ans[cnt] += f[len][s]) %= Mod; } for (int i = 0; i <= n; i++) printf("%d\n", ans[i]); } return 0; }