1. 程式人生 > 實用技巧 >[NOI2016][洛谷P1117]優秀的拆分(SA)

[NOI2016][洛谷P1117]優秀的拆分(SA)

題面

https://www.luogu.com.cn/problem/P1117

題解

前置知識:後綴陣列(SA)、height陣列與ST表。

本題要求一個字串中所有AABB形式的字串(可重)的個數。

首先考慮簡化要求:設f[x]表示以第x位為結尾,有多少個AA形式的字串;g[x]表示以第x位為開頭,有多少個AA形式的字串。一個AABB形式的字串由一個在第i位結尾的AA和一個在第i+1位開頭的BB拼成,所以答案顯然是\(\sum_{i=1}^{n-1} f[i]g[i+1]\)。

列舉AA型字串的半長len,然後設定第1位,第len+1位,第2len+1位…為特殊點。一個長度為2len的AA型字串一定通過恰好兩個相鄰的特殊點。不妨設這兩個點是i,j。

A在特殊點左邊的部分長l(包括特殊點本身),那麼顯然有\(1{\leq}l{\leq}len\)。另外,i,j還必須滿足\(lcs(pre_i,pre_j){\geq}l\)以及\(lcp(suf_i,suf_j){\geq}len-l+1\)

所以通過兩個相鄰特殊點i、j,並且特殊點左邊的部分長為l的、半長為len的AA型字串存在的必要條件是:

\[\begin{cases} l{\geq}\max(1,len+1-lcp(suf_i,suf_j)) \\ l{\leq}\min(len,lcs(pre_i,pre_j)) \end{cases} \]

不難發現這也是充分條件。

所以枚舉了len,i,j之後,設\(high=\min(len,lcs(pre_i,pre_j)),low=\max(1,len+1-lcp(suf_i,suf_j))\)

,如果\(low{\leq}high\),就把i-high+1到i-low+1的g值全部++,把j+len-high到j+len-low的f值全部++。這個可以維護差分而做到\(O(1)\)的更新。

字首的最長公共字尾、字尾的最長公共字首都可以通過預處理前(後)綴陣列+height陣列上ST表做到O(1)。

所以總時間複雜度是調和級數\(O(\sum_{i=1}^{n}{\frac{n}{i}})=O(n \log n)\)

程式碼

#include<bits/stdc++.h>

using namespace std;

// `rg` historically expanded to `register`. That storage-class specifier was
// deprecated in C++11 and removed in C++17, where some compilers reject it
// outright, and it never affected program behaviour. It now expands to nothing
// so every existing `rg int ...` declaration keeps compiling unchanged.
#define rg
// Shorthand for `inline`.
#define In inline
// Shorthand for the 64-bit integer type used for the counters and the answer.
#define ll long long

// Largest string length the fixed-size arrays in this file are dimensioned for
// (each is declared with N+5 elements).
const int N = 30000;

/// Reads the next decimal integer from stdin.
///
/// Every non-digit character before the number is skipped; if any of the
/// skipped characters is '-', the result is negated. The first character after
/// the digits is consumed as well.
///
/// @return the parsed value, or 0 if end of input is reached before any digit.
inline int read(){
	int value = 0,sign = 1;
	// Kept as int (not char) so EOF stays distinguishable from real characters.
	int ch = getchar();
	// Without the EOF test this loop would never terminate on exhausted input.
	while(ch != EOF && (ch < '0' || ch > '9')){
		if(ch == '-')sign = -1;
		ch = getchar();
	}
	while('0' <= ch && ch <= '9'){
		value = 10 * value + (ch - '0');
		ch = getchar();
	}
	return value * sign;
}

// Length of the string currently being processed; main sets it with strlen(s + 1).
int n;
// Current input string, stored 1-based: main reads it into s + 1, so s[0] stays
// '\0' and s[n+1] is the terminator.
char s[N+5];
// calcfg() first records range updates here as differences, then prefix-sums them:
// afterwards f[x] counts the AA-shaped substrings ending at x and g[x] those
// starting at x. main zeroes entries 0..n+1 after each test case.
ll f[N+5],g[N+5];
// lg[x] = floor(log2(x)) for 1 <= x <= N, filled once at the start of main and
// used by ST::query to pick the block size.
int lg[N+5];

/// Sparse table for range-minimum queries over a 1-based int array of length n
/// (n being the file-level current string length).
struct ST{
	// minn[i][j] holds the minimum of the source array over [i, i + 2^j - 1].
	int minn[N+5][16];

	/// Fills the table from a[1..n].
	void prepro(int a[]){
		for(int pos = 1;pos <= n;++pos){
			minn[pos][0] = a[pos];
		}
		for(int level = 1;level <= 15;++level){
			const int half = 1 << (level - 1);
			// Last start position whose block of length 2^level still fits in [1, n].
			const int lastStart = n - (1 << level) + 1;
			for(int pos = 1;pos <= lastStart;++pos){
				const int leftMin = minn[pos][level-1];
				const int rightMin = minn[pos+half][level-1];
				minn[pos][level] = rightMin < leftMin ? rightMin : leftMin;
			}
		}
	}

	/// Returns the minimum over [l, r] by combining the two (possibly
	/// overlapping) power-of-two blocks anchored at l and at r. Uses the
	/// file-level lg[] table to choose the block size.
	int query(int l,int r){
		const int level = lg[r-l+1];
		const int fromLeft = minn[l][level];
		const int fromRight = minn[r+1-(1<<level)][level];
		return fromRight < fromLeft ? fromRight : fromLeft;
	}
};

struct SA{
	int sa[N+5],rk[N+5],temp[N+5],num[N+5],h[N+5];
	int m;	
	void clear(){
		memset(sa,0,sizeof(int)*(n+2));
		memset(rk,0,sizeof(int)*(n+2));
		memset(temp,0,sizeof(int)*(n+2));
	}
	void qsort(){
		memset(num,0,sizeof(int) * (m+1));
		for(rg int i = 1;i <= n;i++)num[rk[i]]++;
		for(rg int i = 2;i <= m;i++)num[i] += num[i-1];
		for(rg int i = n;i >= 1;i--)sa[num[rk[temp[i]]]--] = temp[i];
	}
	ST H;
	void calch(){
		int k = 0;
		for(rg int i = 1;i <= n;i++){
			if(rk[i] == 1)h[1] = k = 0;
			else{
				if(k)k--;
				int j = sa[rk[i]-1];
				while(s[i+k] == s[j+k])k++;
				h[rk[i]] = k;
			}
		}
	}
	void init(){
		clear();
		m = 26;
		for(rg int i = 1;i <= n;i++)temp[i] = i;
		for(rg int i = 1;i <= n;i++)rk[i] = s[i] - 'a' + 1;
		qsort();
		for(rg int d = 1;d <= n;d <<= 1){
			int cnt = 0;
			for(rg int i = n - d + 1;i <= n;i++)temp[++cnt] = i;
			for(rg int i = 1;i <= n;i++)if(sa[i] > d)temp[++cnt] = sa[i] - d;
			qsort();
			memcpy(temp,rk,sizeof(int) * (n+1));
			cnt = 1;
			rk[sa[1]] = 1;
			for(rg int i = 2;i <= n;i++){
				if(temp[sa[i]] != temp[sa[i-1]] || temp[sa[i]+d] != temp[sa[i-1]+d])cnt++;
				rk[sa[i]] = cnt;
			}
			if(cnt == n)break;
			m = cnt;
		}
		calch();
		H.prepro(h);
	}
	int lcp(int i,int j){
		int x = rk[i],y = rk[j];
		if(x > y)swap(x,y);
		return H.query(x + 1,y);
	}
}S;

struct PA{
	int pa[N+5],rk[N+5],temp[N+5],num[N+5],h[N+5];
	int m;
	void clear(){
		memset(pa,0,sizeof(int)*(n+2));
		memset(rk,0,sizeof(int)*(n+2));
		memset(temp,0,sizeof(int)*(n+2));
	}
	void qsort(){
		memset(num,0,sizeof(int) * (m+1));
		for(rg int i = 1;i <= n;i++)num[rk[i]]++;
		for(rg int i = 2;i <= m;i++)num[i] += num[i-1];
		for(rg int i = n;i >= 1;i--)pa[num[rk[temp[i]]]--] = temp[i];
	}
	ST H;
	void calch(){
		int k = 0;
		for(rg int i = n;i >= 1;i--){
			if(rk[i] == 1)h[1] = k = 0;
			else{
				if(k)k--;
				int j = pa[rk[i]-1];
				while(s[i-k] == s[j-k])k++;
				h[rk[i]] = k;
			}
		}
	}
	void init(){
		clear();
		m = 26;
		for(rg int i = 1;i <= n;i++)temp[i] = i;
		for(rg int i = 1;i <= n;i++)rk[i] = s[i] - 'a' + 1;
		qsort();
		for(rg int d = 1;d <= n;d <<= 1){
			int cnt = 0;
			for(rg int i = 1;i <= d;i++)temp[++cnt] = i;
			for(rg int i = 1;i <= n;i++)if(pa[i] + d <= n)temp[++cnt] = pa[i] + d;
			qsort();
			memcpy(temp,rk,sizeof(int) * (n+1));
			cnt = 1;
			rk[pa[1]] = 1;
			for(rg int i = 2;i <= n;i++){
				if(temp[pa[i]] != temp[pa[i-1]] || temp[pa[i]-d] != temp[pa[i-1]-d])cnt++;
				rk[pa[i]] = cnt;
			}
			if(cnt == n)break;
			m = cnt;
		}
		calch();
		H.prepro(h);
	}
	int lcs(int i,int j){
		int x = rk[i],y = rk[j];
		if(x > y)swap(x,y);
		return H.query(x + 1,y);
	}
}P;

/// Fills f[] and g[] for the current string: f[x] ends up as the number of
/// AA-shaped substrings ending at x, g[x] as the number starting at x.
///
/// For every half-length, anchor points are placed at 1, 1+half, 1+2*half, ...
/// Each adjacent pair of anchors yields a contiguous range of valid start (and
/// end) positions, recorded as a difference and prefix-summed at the end.
void calcfg(){
	for(int half = 1;half * 2 <= n;++half){
		for(int left = 1;left + half <= n;left += half){
			const int right = left + half;
			// How far the two copies agree going backwards from the anchors (capped at half).
			const int upper = min(P.lcs(left,right),half);
			// Smallest backward extent for which the forward match still covers the rest of A.
			const int lower = max(half + 1 - S.lcp(left,right),1);
			if(lower > upper)continue;

			// Start positions left-upper+1 .. left-lower+1 each gain one AA.
			++g[left-upper+1];
			--g[left-lower+2];
			// End positions right+half-upper .. right+half-lower each gain one AA.
			++f[right+half-upper];
			--f[right+half-lower+1];
		}
	}
	// Turn the recorded differences into actual counts.
	for(int pos = 1;pos <= n;++pos){
		f[pos] += f[pos-1];
		g[pos] += g[pos-1];
	}
}

/// Reads T test cases, each a single string, and prints for each the number of
/// ways its substrings can be split as AABB: the sum over i of f[i] * g[i+1].
int main(){
	// lg[x] = floor(log2(x)); lg[0] and lg[1] stay 0 from static initialisation.
	for(int i = 2;i <= N;i++)lg[i] = lg[i>>1] + 1;

	// Build "%<N>s" so an over-long token cannot overrun s or the N-sized tables.
	char fmt[16];
	snprintf(fmt,sizeof(fmt),"%%%ds",N);

	int T = read();
	// `> 0` keeps a malformed negative count from looping until wrap-around.
	while(T-- > 0){
		// Stop on truncated input rather than reprocessing the previous string.
		if(scanf(fmt,s + 1) != 1)break;
		n = strlen(s + 1);
		S.init();
		P.init();
		calcfg();
		ll ans = 0;
		for(int i = 1;i < n;i++)ans += f[i] * g[i+1];
		// '\n' instead of endl: no need to flush after every test case.
		cout << ans << '\n';
		// calcfg() touches indices up to n+1, so clear 0..n+1 for the next case.
		memset(f,0,sizeof(ll) * (n+2));
		memset(g,0,sizeof(ll) * (n+2));
	}
	return 0;
}