【迴文自動機 動態規劃】JZOJ_4752 字串合成
阿新 • • 發佈:2021-08-11
迴文自動機上的dp
題面
思路
可以發現答案即對於每個迴文子串,求出它的合成代價(通過翻轉啥的),再暴力一個一個加上其它的字元。
用迴文自動機跑出每個迴文子串(即列舉到\(i\)時以\(i\)結尾的迴文串)。
設\(dp_x\)為在迴文自動機上點\(x\)表示的迴文串的合成代價,\(fail_x\)為\(x\)的最長迴文字尾,\(trans_x\)為\(x\)的最長迴文字尾(不超過長度一半)。
\(1、x\)的長度為奇數,則
\(dp_x=min\{dp_{fa_x}+2,dp_{fail_x}+len_x–len_{fail_x}\}\)
由於奇數長迴文串不能翻轉加倍,故只可能由首尾加字元得來,要麼首尾一起加一個字元,要麼只在尾部加(與只在首部加效果相同)。
\(2、x\)的長度為偶數,則最後一步必為翻轉加倍,若翻轉加倍前在首部加了字元,則\(dp_x=dp_{fa_x}+1\),表示在\(fa_x\)進行最後一步翻轉加倍之前先在首部補上一個字元再加倍。
若翻轉加倍前未在首部加字元,則\(dp_x=dp_{trans_x}+len_x/2–len_{trans_x}+1\),
表示先合成\(trans_x\),再在其後新增字元至\(x\)的左邊一半,再翻轉加倍。
對於每個迴文串的統計答案都是\(f_x+n-len_x\),可以邊建\(PAM\)邊dp。
程式碼
#include <cstdio> #include <cstring> #include <algorithm> int t, n, last, ans; int tmp[100001], f[100001]; char s[100001]; struct PAM { int cnt; int len[100001], num[100001], fail[100001], tree[100001][27], fa[100001], trans[100001]; void clear() { memset(len, 0, sizeof(len)); memset(num, 0, sizeof(num)); memset(fail, 0, sizeof(fail)); memset(tree, 0, sizeof(tree)); memset(fa, 0, sizeof(fa)); memset(trans, 0, sizeof(trans)); cnt = 1; fail[0] = 1; len[1] = -1; } int getFail(int p, int i) { while (tmp[i - len[p] - 1] != tmp[i] || i - len[p] - 1 < 0) p = fail[p]; return p; } int getTrans(int p, int i) { while (tmp[i - len[p] - 1] != tmp[i] || (len[p] + 2 << 1) > len[cnt]) p = fail[p]; return p; } void insert(int u, int i) { int Fail = getFail(last, i); if (!tree[Fail][u]) { len[++cnt] = len[Fail] + 2; fail[cnt] = tree[getFail(fail[Fail], i)][u]; tree[Fail][u] = cnt; num[cnt] = num[fail[cnt]] + 1; fa[cnt] = Fail; if (len[cnt] <= 2) trans[cnt] = fail[cnt]; else trans[cnt] = tree[getTrans(trans[Fail], i)][u]; } last = tree[Fail][u]; } } a; int main() { scanf("%d", &t); while (t--) { scanf("%s", s + 1); n = strlen(s + 1); ans = n; tmp[0] = -1;//坑:一開始為0會與s[i]匹配 f[0] = 1;//別漏加偶迴文串最初加字元的代價 a.clear(); for (int i = 1; i <= n; i++) { tmp[i] = s[i] - 97; a.insert(tmp[i], i); if (a.len[last] & 1) f[last] = std::min(f[a.fa[last]] + 2, f[a.fail[last]] + a.len[last] - a.len[a.fail[last]]); else f[last] = std::min(f[a.fa[last]] + 1, f[a.trans[last]] + a.len[last] / 2 - a.len[a.trans[last]] + 1); ans = std::min(ans, f[last] + n - a.len[last]); } printf("%d\n", ans); } }