JSOI2013 快樂的JYY(manacher,迴文串)
阿新 • • 發佈:2020-08-12
萌新不會PAM
於是用manacher+hash
O(nlogn)過了這題
首先我們有一個結論
對於一個長度n的串,它的本質不同的迴文子串數量是O(n)的
然後我們先用manacher算出半徑r[i]
然後開一個數組記錄一下最長迴文子串在原字串中哪個位置
我們考慮任意一個>=3的迴文子串不論是奇是偶,去掉頭尾它依然是迴文串
又因為只有O(n)個本質不同的迴文子串(以下簡稱子串),所以我們可以考慮用字串雜湊map給每個本質不同的子串標號,然後每個子串的節點把去掉頭尾的子串的節點當作父親,就類似於AC自動機的fail樹,意即你訪問子節點的串同時一定遍歷了父節點的串,然後只有長度為1或2的子串向根連邊(把根當作父親).
於是我們每遍歷一個以某個位置為對稱軸(可以是字母也可以是兩個字母間的空隙)的最長迴文子串就給它打個標記,一個迴文子串遍歷了多少次顯然是它樹上的結點的子樹和,於是我們可以算出每個迴文子串在A中出現了幾次,同理對B建樹,然後把相同的子串在兩樹中的方案乘起來再累加到ans裡,就做完了
/*快樂的JYY*/ #include<cstdio> #include<iostream> #include<cstring> #include<algorithm> #include<cstring> #include<map> using namespace std; #define ll long long const int maxn = 2e5 + 10; #define ul unsigned long long ul mi[maxn]; ll ans = 0; int head[maxn]; int idcnt = 0; struct Edge{ int nxt,point; }edge[maxn*2]; int tot; void add(int x,int y){ edge[++tot].nxt = head[x]; edge[tot].point = y; head[x] = tot; } struct Str{ char str[maxn]; char s[maxn*2]; int r[maxn] ;int bit[maxn];ul hash[maxn];map<ul,int>id; ul v[maxn]; ll dp[maxn * 2]; map<ul,ll>f; int len; int cnt; int rt; #define p 133331 void init(){ for(int i = 1; i <= len; ++i) hash[i] = hash[i-1] * p + str[i] - 'A' + 1; } ul calc(int l,int r){ int k = r - l + 1; return hash[r] - mi[k] * hash[l-1]; } void manacher(){ s[++cnt] = '~',s[++cnt] = '#'; for(int i = 1; i <= len; ++i){ s[++cnt] = str[i]; s[++cnt] = '#'; bit[cnt] = i; } s[++cnt] = '!'; s[cnt+1] = '\0'; int mr = 0,mid = 0; for(int i = 2; i <= cnt - 1; ++i){ if(i <= mr) r[i] = min(r[(mid<<1)-i],mr - i + 1); else r[i] = 1; while(s[i+r[i]] == s[i-r[i]]) r[i]++; if(i + r[i] > mr){ mr = i + r[i] - 1; mid = i; } } } void build(){ rt = ++idcnt; for(int i = 2; i <= cnt - 1; ++i){ int L = i - r[i] + 1,R = i + r[i] - 1; if(L == R && s[L] == '#') continue; L = bit[L] + 1,R = bit[R]; int lst = 0; while(L <= R){ ul val = calc(L,R); if(id.find(val) == id.end()){ id[val] = ++idcnt; v[idcnt] = val; if(lst) add(id[val],lst); } else{ if(lst) add(id[val],lst); break; } lst = id[val]; if(L == R){ add(rt,id[val]); } if(R == L + 1){ add(rt,id[val]); } L++,R--; } L = i - r[i] + 1,R = i + r[i] - 1; L = bit[L] + 1,R = bit[R]; ul val = calc(L,R); dp[id[val]]++; } } void dfs(int x){ for(int i = head[x]; i ; i = edge[i].nxt){ int y = edge[i].point; dfs(y); dp[x] += dp[y]; } f[v[x]] += dp[x]; } }A,B; void init(){ mi[0] = 1; for(int i = 1; i <= 1e5; ++i) mi[i] = mi[i-1] * p; scanf("%s",A.str+1); A.len = strlen(A.str + 1); A.init(); A.manacher(); A.build(); A.dfs(A.rt); scanf("%s",B.str+1); B.len = strlen(B.str + 1); B.init(); B.manacher(); B.build(); B.dfs(B.rt); } void solve(int x){ if(x != A.rt){ ans += A.f[A.v[x]] * B.f[A.v[x]]; } for(int i = head[x]; i ; i = edge[i].nxt){ int y = edge[i].point; solve(y); } } int main() { init(); solve(A.rt); printf("%lld\n",ans); return 0; }