[NOI2016][洛谷P1117]優秀的拆分(SA)
阿新 • • 發佈:2020-10-07
題面
https://www.luogu.com.cn/problem/P1117
題解
前置知識:字尾陣列(SA)、height陣列、ST表。
本題要求一個字串中所有AABB形式的字串(可重)的個數。
首先考慮簡化要求:設f[x]表示以第x位為結尾,有多少個AA形式的字串;g[x]表示以第x位為開頭有多少個AA形式的字串。答案顯然是\(\sum_{i=1}^{n-1} f[i]g[i+1]\)。
列舉AA型字串的半長len,然後設定第1位,第len+1位,第2len+1位…為特殊點。一個長度為2len的AA型字串一定通過恰好兩個相鄰的特殊點。不妨設這兩個點是i,j。
設第一個A在特殊點i左邊的部分(包括特殊點i本身)長為l,那麼顯然有\(1{\leq}l{\leq}len\)。另外,i,j還必須滿足\(lcs(pre_i,pre_j){\geq}l\)以及\(lcp(suf_i,suf_j){\geq}len-l+1\),其中\(pre_x\)表示以第x位結尾的字首,\(suf_x\)表示以第x位開頭的字尾,lcs為最長公共字尾,lcp為最長公共字首。
所以通過兩個相鄰特殊點i、j,並且特殊點左邊的部分長為l的、半長為len的AA型字串存在的必要條件是:
\[\begin{cases} l{\geq}\max(1,len+1-lcp(suf_i,suf_j)) \\ l{\leq}\min(len,lcs(pre_i,pre_j)) \end{cases} \]不難發現這也是充分條件。
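所以枚舉了len,i,j之後,設\(high=\min(len,lcs(pre_i,pre_j)),low=\max(1,len+1-lcp(suf_i,suf_j))\)。若\(low{\leq}high\),則每個\(l\in[low,high]\)都對應一個起點為\(i-l+1\)、終點為\(j+len-l\)的AA型字串,因此對\(g[i-high+1..i-low+1]\)與\(f[j+len-high..j+len-low]\)各做一次區間加1;用差分陣列實現,最後求一次字首和即可得到f、g。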
字首的最長公共字尾、字尾的最長公共字首都可以通過預處理前(後)綴陣列+height陣列上ST表做到O(1)。
所以總時間複雜度是調和級數\(O(\sum_{i=1}^{n}{\frac{n}{i}})=O(n \log n)\)。
程式碼
// NOI2016 "Excellent Split" (Luogu P1117).
//
// For every test string, counts the ways to pick a substring and split it as
// AABB (A, B non-empty).  f[x] = number of squares "AA" ending at position x,
// g[x] = number of squares starting at position x; the answer is
// sum over 1 <= i < n of f[i] * g[i + 1].
//
// Squares are enumerated by half-length `len`: positions 1, len+1, 2len+1, ...
// are "special points"; a square of half-length len crosses exactly two
// adjacent ones (i, j = i + len).  Longest-common-suffix of prefixes and
// longest-common-prefix of suffixes (suffix array + height array + sparse
// table) give, per (len, i), a whole interval of valid squares that is added
// to f and g with difference arrays.
//
// All arrays are 1-indexed.  The input is assumed to consist of lowercase
// letters 'a'..'z' only (the initial ranks are computed as s[i] - 'a' + 1).
#include <algorithm>
#include <cstdio>
#include <cstring>

using ll = long long;

const int N = 30000;  // maximum supported string length

// Reads one (optionally negative) decimal integer from stdin.
// Skips every non-digit character first; a '-' seen while skipping makes the
// result negative.  Returns 0 if EOF is reached before any digit.
inline int read() {
    int value = 0, sign = 1;
    int ch = std::getchar();  // int, not char, so EOF is representable
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        value = 10 * value + (ch - '0');
        ch = std::getchar();
    }
    return value * sign;
}

int n;               // length of the current string
char s[N + 5];       // current string in s[1..n]; s[0] and s[n+1] are '\0'
ll f[N + 5], g[N + 5];
int lg[N + 5];       // lg[x] = floor(log2(x)) for x >= 1, filled in main()

// Sparse table answering range-minimum queries over a[1..n].
struct ST {
    int minn[N + 5][16];  // minn[i][j] = min of a[i .. i + 2^j - 1]; 2^15 > N

    // Builds the table from a[1..n] (n is the global string length).
    void prepro(int a[]) {
        for (int i = 1; i <= n; i++) minn[i][0] = a[i];
        for (int j = 1; j <= 15; j++)
            for (int i = 1; i + (1 << j) - 1 <= n; i++)
                minn[i][j] = std::min(minn[i][j - 1], minn[i + (1 << (j - 1))][j - 1]);
    }

    // Minimum of a[l..r]; requires 1 <= l <= r <= n.
    int query(int l, int r) {
        int d = lg[r - l + 1];
        return std::min(minn[l][d], minn[r + 1 - (1 << d)][d]);
    }
};

// Suffix array of s[1..n] (prefix doubling + counting sort) with a height
// array and a sparse table for longest-common-prefix queries between suffixes.
struct SA {
    // sa[r]  = start of the suffix with rank r
    // rk[p]  = rank of the suffix starting at p
    // temp   = scratch: positions ordered by second key / previous ranks
    // num    = counting-sort buckets
    // h[r]   = LCP of the suffixes ranked r-1 and r (h[1] = 0)
    int sa[N + 5], rk[N + 5], temp[N + 5], num[N + 5], h[N + 5];
    int m;  // current number of distinct ranks (bucket count)

    // Zeroes indices 0..n+1 so that rk/temp read as 0 just outside [1, n].
    void clear() {
        std::memset(sa, 0, sizeof(int) * (n + 2));
        std::memset(rk, 0, sizeof(int) * (n + 2));
        std::memset(temp, 0, sizeof(int) * (n + 2));
    }

    // Stable counting sort of the positions listed in temp[1..n] by rk[],
    // writing the result to sa[1..n].
    void qsort() {
        std::memset(num, 0, sizeof(int) * (m + 1));
        for (int i = 1; i <= n; i++) num[rk[i]]++;
        for (int i = 2; i <= m; i++) num[i] += num[i - 1];
        for (int i = n; i >= 1; i--) sa[num[rk[temp[i]]]--] = temp[i];
    }

    ST H;  // range-minimum structure over h[]

    // Fills h[] by walking positions left to right and reusing the previous
    // match length minus one.
    void calch() {
        int k = 0;
        for (int i = 1; i <= n; i++) {
            if (rk[i] == 1) {
                h[1] = k = 0;
            } else {
                if (k) k--;
                int j = sa[rk[i] - 1];
                // Stops at the latest when the longer index hits s[n+1] == '\0'.
                while (s[i + k] == s[j + k]) k++;
                h[rk[i]] = k;
            }
        }
    }

    // Builds sa, rk, h and the sparse table for the current global s / n.
    void init() {
        clear();
        m = 26;
        for (int i = 1; i <= n; i++) temp[i] = i;
        for (int i = 1; i <= n; i++) rk[i] = s[i] - 'a' + 1;
        qsort();
        for (int d = 1; d <= n; d <<= 1) {
            int cnt = 0;
            // Suffixes shorter than d+1 have an empty second key: they go first.
            for (int i = n - d + 1; i <= n; i++) temp[++cnt] = i;
            for (int i = 1; i <= n; i++)
                if (sa[i] > d) temp[++cnt] = sa[i] - d;
            qsort();
            std::memcpy(temp, rk, sizeof(int) * (n + 1));  // temp = old ranks
            cnt = 1;
            rk[sa[1]] = 1;
            for (int i = 2; i <= n; i++) {
                // The second comparison runs only when the first keys are
                // equal, which implies both suffixes have length >= d, so the
                // index sa[..] + d is at most n + 1 (temp[n+1] is 0).
                if (temp[sa[i]] != temp[sa[i - 1]] || temp[sa[i] + d] != temp[sa[i - 1] + d]) cnt++;
                rk[sa[i]] = cnt;
            }
            if (cnt == n) break;  // all ranks distinct: done
            m = cnt;
        }
        calch();
        H.prepro(h);
    }

    // Length of the longest common prefix of the suffixes starting at i and j.
    int lcp(int i, int j) {
        if (i == j) return n - i + 1;  // a suffix compared with itself
        int x = rk[i], y = rk[j];
        if (x > y) std::swap(x, y);
        return H.query(x + 1, y);
    }
} S;

// Mirror image of SA: orders the prefixes of s[1..n] by their reversed text,
// giving longest-common-suffix queries between prefixes.
struct PA {
    // pa[r]  = end position of the prefix with rank r
    // rk[p]  = rank of the prefix ending at p
    // temp, num, h: as in SA, with h[r] = common-suffix length of the prefixes
    // ranked r-1 and r.
    int pa[N + 5], rk[N + 5], temp[N + 5], num[N + 5], h[N + 5];
    int m;  // current number of distinct ranks (bucket count)

    // Zeroes indices 0..n+1 so that rk/temp read as 0 just outside [1, n].
    void clear() {
        std::memset(pa, 0, sizeof(int) * (n + 2));
        std::memset(rk, 0, sizeof(int) * (n + 2));
        std::memset(temp, 0, sizeof(int) * (n + 2));
    }

    // Stable counting sort of the positions listed in temp[1..n] by rk[],
    // writing the result to pa[1..n].
    void qsort() {
        std::memset(num, 0, sizeof(int) * (m + 1));
        for (int i = 1; i <= n; i++) num[rk[i]]++;
        for (int i = 2; i <= m; i++) num[i] += num[i - 1];
        for (int i = n; i >= 1; i--) pa[num[rk[temp[i]]]--] = temp[i];
    }

    ST H;  // range-minimum structure over h[]

    // Fills h[] by walking positions right to left and reusing the previous
    // match length minus one.
    void calch() {
        int k = 0;
        for (int i = n; i >= 1; i--) {
            if (rk[i] == 1) {
                h[1] = k = 0;
            } else {
                if (k) k--;
                int j = pa[rk[i] - 1];
                // Stops at the latest when the smaller index hits s[0] == '\0'.
                while (s[i - k] == s[j - k]) k++;
                h[rk[i]] = k;
            }
        }
    }

    // Builds pa, rk, h and the sparse table for the current global s / n.
    void init() {
        clear();
        m = 26;
        for (int i = 1; i <= n; i++) temp[i] = i;
        for (int i = 1; i <= n; i++) rk[i] = s[i] - 'a' + 1;
        qsort();
        for (int d = 1; d <= n; d <<= 1) {
            int cnt = 0;
            // Prefixes of length <= d have an empty second key: they go first.
            for (int i = 1; i <= d; i++) temp[++cnt] = i;
            for (int i = 1; i <= n; i++)
                if (pa[i] + d <= n) temp[++cnt] = pa[i] + d;
            qsort();
            std::memcpy(temp, rk, sizeof(int) * (n + 1));  // temp = old ranks
            cnt = 1;
            rk[pa[1]] = 1;
            for (int i = 2; i <= n; i++) {
                // The second comparison runs only when the first keys are
                // equal, which implies both prefixes have length >= d, so the
                // index pa[..] - d is at least 0 (temp[0] is 0).
                if (temp[pa[i]] != temp[pa[i - 1]] || temp[pa[i] - d] != temp[pa[i - 1] - d]) cnt++;
                rk[pa[i]] = cnt;
            }
            if (cnt == n) break;  // all ranks distinct: done
            m = cnt;
        }
        calch();
        H.prepro(h);
    }

    // Length of the longest common suffix of the prefixes ending at i and j.
    int lcs(int i, int j) {
        if (i == j) return i;  // a prefix compared with itself
        int x = rk[i], y = rk[j];
        if (x > y) std::swap(x, y);
        return H.query(x + 1, y);
    }
} P;

// Fills f[] and g[] for the current string.  Expects f and g to be zero on
// entry and S / P to be initialised for the current s.
void calcfg() {
    for (int len = 1; (len << 1) <= n; len++) {
        for (int i = 1; i + len <= n; i += len) {
            int j = i + len;
            // l = number of characters of the first A at or left of i.
            // Valid l satisfy low <= l <= high.
            int high = std::min(P.lcs(i, j), len);
            int low = std::max(len + 1 - S.lcp(i, j), 1);
            if (low <= high) {
                // Squares start at i - l + 1 and end at j + len - l; add 1 to
                // both ranges via difference arrays.
                g[i - high + 1]++;
                g[i - low + 2]--;
                f[j + len - high]++;
                f[j + len - low + 1]--;
            }
        }
    }
    // Turn the difference arrays into the actual counts.
    for (int i = 1; i <= n; i++) {
        f[i] += f[i - 1];
        g[i] += g[i - 1];
    }
}

// Reads T, then T strings; prints one answer per line.
int main() {
    for (int i = 2; i <= N; i++) lg[i] = lg[i >> 1] + 1;
    int T = read();
    while (T-- > 0) {
        // The width keeps an over-long token from overrunning s[]; it must
        // match N (s[1..N] plus the terminator fit in s[N + 5]).
        static_assert(N == 30000, "keep the scanf field width in sync with N");
        if (std::scanf("%30000s", s + 1) != 1) break;  // truncated input
        n = static_cast<int>(std::strlen(s + 1));
        S.init();
        P.init();
        calcfg();
        ll ans = 0;
        for (int i = 1; i < n; i++) ans += f[i] * g[i + 1];
        std::printf("%lld\n", ans);
        // Reset indices 0..n+1, the only ones calcfg can have touched.
        std::memset(f, 0, sizeof(ll) * (n + 2));
        std::memset(g, 0, sizeof(ll) * (n + 2));
    }
    return 0;
}