1. 程式人生 > 實用技巧 >JSOI2013 快樂的JYY(manacher,迴文串)

JSOI2013 快樂的JYY(manacher,迴文串)

萌新不會PAM
於是用manacher+hash
O(nlogn)過了這題
首先我們有一個結論
對於一個長度n的串,它的本質不同的迴文子串數量是O(n)的

然後我們先用manacher算出半徑r[i]
然後開一個數組記錄一下最長迴文子串在原字串中哪個位置

我們考慮任意一個>=3的迴文子串不論是奇是偶,去掉頭尾它依然是迴文串
又因為只有O(n)個本質不同的迴文子串(以下簡稱子串),所以我們可以考慮用字串雜湊map給每個本質不同的子串標號,然後每個子串的節點把去掉頭尾的子串的節點當作父親,就類似於AC自動機的fail樹,意即你訪問子節點的串同時一定遍歷了父節點的串,然後只有長度為1或2的子串向根連邊(把根當作父親).

於是這就形成了一顆樹

於是我們每遍歷一個以某個位置為對稱軸(可以是字母也可以是兩個字母間的空隙)的最長迴文子串就給它打個標記,一個迴文子串遍歷了多少次顯然是它樹上的結點的子樹和,於是我們可以算出每個迴文子串在A中出現了幾次,同理對B建樹,然後把相同的子串在兩樹中的方案乘起來再累加到ans裡,就做完了

/*快樂的JYY*/
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cstring>
#include<map>
using namespace std;
#define ll long long
const int maxn = 2e5 + 10; 
#define ul unsigned long long
ul mi[maxn];
ll ans = 0;
int head[maxn];
int idcnt = 0;
struct Edge{
	int nxt,point;
}edge[maxn*2];
int tot;
void add(int x,int y){
	edge[++tot].nxt = head[x];
	edge[tot].point = y;
	head[x] = tot;
}
struct Str{
	char str[maxn];	char s[maxn*2];
	int r[maxn] ;int bit[maxn];ul hash[maxn];map<ul,int>id;
	ul v[maxn];
	ll dp[maxn * 2];
	map<ul,ll>f;
	int len;
	int cnt;
	int rt;
	#define p 133331
	void init(){
		for(int i = 1; i <= len; ++i)
			hash[i] = hash[i-1] * p + str[i] - 'A' + 1;
	}
	ul calc(int l,int r){
		int k = r - l + 1;
		return hash[r] - mi[k] * hash[l-1];
	}
	void manacher(){
		s[++cnt] = '~',s[++cnt] = '#';
		for(int i = 1; i <= len; ++i){
			s[++cnt] = str[i];	s[++cnt] = '#';
			bit[cnt] = i;
		}
		s[++cnt] = '!';
		s[cnt+1] = '\0';
		int mr = 0,mid = 0;
		for(int i = 2; i <= cnt - 1; ++i){
			if(i <= mr)		r[i] = min(r[(mid<<1)-i],mr - i + 1);
			else	r[i] = 1;
			while(s[i+r[i]] == s[i-r[i]])		r[i]++;
			if(i + r[i] > mr){
				mr = i + r[i] - 1;
				mid = i;
			}
		} 
	}
	void build(){
		rt = ++idcnt;
		for(int i = 2; i <= cnt - 1; ++i){
			int L = i - r[i] + 1,R = i + r[i] - 1;
			if(L == R && s[L] == '#')	continue;
			L = bit[L] + 1,R = bit[R];
			int lst = 0;
			while(L <= R){	
				ul val = calc(L,R);
				if(id.find(val) == id.end()){
					id[val] = ++idcnt;
					v[idcnt] = val;
					if(lst)		add(id[val],lst);
				}
				else{
					if(lst)		add(id[val],lst);
					break;
				}
				lst = id[val];
				if(L == R){
					add(rt,id[val]);
				} 
				if(R == L + 1){
					add(rt,id[val]);
				}
				L++,R--;
			}
			L = i - r[i] + 1,R = i + r[i] - 1;
			L = bit[L] + 1,R = bit[R];
			ul val = calc(L,R);
			dp[id[val]]++;
		}
	}
	void dfs(int x){
		for(int i = head[x]; i ; i = edge[i].nxt){
			int y = edge[i].point;
			dfs(y);
			dp[x] += dp[y];
		}
		f[v[x]] += dp[x];
	}
}A,B;
void init(){
	mi[0] = 1;
	for(int i = 1; i <= 1e5; ++i)
		mi[i] = mi[i-1] * p;
	scanf("%s",A.str+1);
	A.len = strlen(A.str + 1);
	A.init();
	A.manacher();
	A.build();
	A.dfs(A.rt);
	scanf("%s",B.str+1);
	B.len = strlen(B.str + 1);
	B.init();
	B.manacher();
	B.build();
	B.dfs(B.rt);
}
void solve(int x){
	if(x != A.rt){
		ans += A.f[A.v[x]] * B.f[A.v[x]];
	}
	for(int i = head[x]; i ; i = edge[i].nxt){
		int y = edge[i].point;
		solve(y);
	}
}
int main()
{
	init();
	solve(A.rt);
	printf("%lld\n",ans);
	return 0;
}