1. 程式人生 > 其它 >【迴文自動機 動態規劃】JZOJ_4752 字串合成

【迴文自動機 動態規劃】JZOJ_4752 字串合成

迴文自動機上的dp

題面

思路

可以發現答案即對於每個迴文子串,求出它的合成代價(通過翻轉啥的),再暴力一個一個加上其它的字元。

用迴文自動機跑出每個迴文子串(即列舉到\(i\)時以\(i\)結尾的迴文串)。
\(dp_x\)為在迴文自動機上點\(x\)表示的迴文串的合成代價,\(fail_x\)\(x\)的最長迴文字尾,\(trans_x\)\(x\)的最長迴文字尾(不超過長度一半)。
\(1、x\)的長度為奇數,則
\(dp_x=min\{dp_{fa_x}+2,dp_{fail_x}+len_x–len_{fail_x}\}\)
由於奇數長迴文串不能翻轉加倍,故只可能由首尾加字元得來,要麼首尾一起加一個字元,要麼只在尾部加(與只在首部加效果相同)。

\(2、x\)的長度為偶數,則最後一步必為翻轉加倍,若翻轉加倍前在首部加了字元,則\(dp_x=dp_{fa_x}+1\),表示在\(fa_x\)進行最後一步翻轉加倍之前先在首部補上一個字元再加倍。
若翻轉加倍前未在首部加字元,則\(dp_x=dp_{trans_x}+len_x/2–len_{trans_x}+1\)
表示先合成\(trans_x\),再在其後新增字元至\(x\)的左邊一半,再翻轉加倍。

對於每個迴文串的統計答案都是\(f_x+n-len_x\),可以邊建\(PAM\)邊dp。

程式碼

#include <cstdio>
#include <cstring>
#include <algorithm>

int t, n, last, ans;
int tmp[100001], f[100001];
char s[100001];

struct PAM {
	int cnt;
	int len[100001], num[100001], fail[100001], tree[100001][27], fa[100001], trans[100001];
	void clear() {
		memset(len, 0, sizeof(len));
		memset(num, 0, sizeof(num));
		memset(fail, 0, sizeof(fail));
		memset(tree, 0, sizeof(tree));
		memset(fa, 0, sizeof(fa));
		memset(trans, 0, sizeof(trans));
		cnt = 1;
		fail[0] = 1;
		len[1] = -1;
	}
	int getFail(int p, int i) {
		while (tmp[i - len[p] - 1] != tmp[i] || i - len[p] - 1 < 0)
			p = fail[p];
		return p;
	}
	int getTrans(int p, int i) {
		while (tmp[i - len[p] - 1] != tmp[i] || (len[p] + 2 << 1) > len[cnt])
			p = fail[p];
		return p;		
	}
	void insert(int u, int i) {
		int Fail = getFail(last, i);
		if (!tree[Fail][u]) {
			len[++cnt] = len[Fail] + 2;
			fail[cnt] = tree[getFail(fail[Fail], i)][u];
			tree[Fail][u] = cnt;
			num[cnt] = num[fail[cnt]] + 1;
			fa[cnt] = Fail;
			if (len[cnt] <= 2)
				trans[cnt] = fail[cnt];
			else
				trans[cnt] = tree[getTrans(trans[Fail], i)][u];
		}
		last = tree[Fail][u];
	}
} a;

int main() {
	scanf("%d", &t);
	while (t--) {
		scanf("%s", s + 1);
		n = strlen(s + 1);
		ans = n;
		tmp[0] = -1;//坑:一開始為0會與s[i]匹配
		f[0] = 1;//別漏加偶迴文串最初加字元的代價
		a.clear();
		for (int i = 1; i <= n; i++) {
			tmp[i] = s[i] - 97;
			a.insert(tmp[i], i);
			if (a.len[last] & 1)
				f[last] = std::min(f[a.fa[last]] + 2, f[a.fail[last]] + a.len[last] - a.len[a.fail[last]]);
			else
				f[last] = std::min(f[a.fa[last]] + 1, f[a.trans[last]] + a.len[last] / 2 - a.len[a.trans[last]] + 1);
			ans = std::min(ans, f[last] + n - a.len[last]);
		}
		printf("%d\n", ans);
	}
}