1. 程式人生 > >bzoj3864 Hero meet devil(dp套dp)

bzoj3864 Hero meet devil(dp套dp)

題面

bzoj

題目大意:
給出一個模式串\(S(|S|≤15)\) 問存在多少個長為\(m(m≤1000)\) 的字串T滿足\(LCS(S,T)=x(0≤x≤|S|)\) 輸出\(|S|+1\)個結果\((mod 1e9+7)\) (\(|S|\)表示字串S的長度,字符集為\(A,T,C,G\)四個字母)

題解

樸素\(lcs\)\(dp\)

for (int i = 1; i <= n; i++)
    for (int j = 1; j <= m; j++)
        if (a[i] == b[j]) f[i][j] = f[i-1][j-1]+1;
        else f[i][j] = max(f[i-1][j], f[i][j-1], f[i-1][j-1]);

我們能發現

  • \(f[i][j]\)\(f[i][j-1]\),\(f[i][j+1]\)最多相差\(1\)

  • \(|S| ≤ 15\)

我們可以把\(j\)那一維的差分陣列狀壓一下

然後呢?

\(f[i][S]\)表示在第\(i\)個位置,此時\(lcs\)的狀態為\(S\)的方案數

預處理出 \(nxt[S][A/C/G/T]\)\(S\)狀態下,新增\(A/C/G/T\)後分別的狀態

然後就有
\(f[i+1][nxt[s][k]] += f[i][s]\)

至於預處理,我們把狀壓還原出來
模擬樸素\(dp\)一遍,再壓回去

Code

#include<bits/stdc++.h>

#define LL long long
#define RG register

using namespace std;
template<class T> inline void read(T &x) {
    x = 0; RG char c = getchar(); bool f = 0;
    while (c != '-' && (c < '0' || c > '9')) c = getchar(); if (c == '-') c = getchar(), f = 1;
    while (c >= '0' && c <= '9') x = x*10+c-48, c = getchar();
    x = f ? -x : x;
    return ;
}
template<class T> inline void write(T x) {
    if (!x) {putchar(48);return ;}
    if (x < 0) x = -x, putchar('-');
    int len = -1, z[20]; while (x > 0) z[++len] = x%10, x /= 10;
    for (RG int i = len; i >= 0; i--) putchar(z[i]+48);return ;
}
const int N = 1001, Mod = 1e9+7;
char S[16], SS[5] = {"ACGT"};
int a[16], f[N][(1<<15)+1], nxt[(1<<15)+1][5], n, len, limit, ans[16];

int tmp[2][16];
int solve(int s, int ch) {
    int ret = 0;
    memset(tmp, 0, sizeof(tmp));
    for (int i = 0; i < n; i++) tmp[0][i+1] = tmp[0][i]+((s>>i)&1);
    for (int i = 1; i <= n; i++) {
        int mx = 0;
        if (a[i] == ch) mx = tmp[0][i-1]+1;
        mx = max(max(mx, tmp[0][i]), tmp[1][i-1]);
        tmp[1][i] = mx;
    }
    for (int i = 0; i < n; i++) ret += (1<<i)*(tmp[1][i+1]-tmp[1][i]);
    return ret;
}

int main() {
    //freopen(".in", "r", stdin);
    //freopen(".out", "w", stdout);
    int q; read(q);
    while (q--) {
        memset(f, 0, sizeof(f)); memset(ans, 0, sizeof(ans));
        scanf("%s", S+1);
        n = strlen(S+1); limit = 1<<n;
        for (int i = 1; i <= n; i++)
            for (int j = 0; j < 4; j++)
                if (S[i] == SS[j]) {a[i] = j+1; break;}
        read(len);
        for (int s = 0; s < limit; s++)
            for (int j = 1; j <= 4; j++)
                nxt[s][j] = solve(s, j);        
        f[0][0] = 1;
        for (int i = 0; i < len; i++)
            for (int s = 0; s < limit; s++)
                for (int k = 1; k <= 4; k++)
                    (f[i+1][nxt[s][k]] += f[i][s]) %= Mod;
        for (int s = 0; s < limit; s++) {
            int cnt = __builtin_popcount(s);
            (ans[cnt] += f[len][s]) %= Mod;
        }
        for (int i = 0; i <= n; i++)
            printf("%d\n", ans[i]);
    }
    return 0;
}