1. 程式人生 > 實用技巧 >題解 LOJ2083 「NOI2016」優秀的拆分

題解 LOJ2083 「NOI2016」優秀的拆分

題目連結

約定:\(\text{suf}(i)\)表示以\(i\)開頭的字尾(\(s[i\dots n]\)),\(\text{pre}(i)\)表示以\(i\)結尾的字首(\(s[1\dots i]\)),\(\text{lcp}(s_1,s_2)\)表示兩個串的最長公共字首,\(\text{lcs}(s_1,s_2)\)表示兩個串的最長公共字尾。

\(f[i]\)表示以\(i\)為結尾的,\(\text{AA}\)式的子串數量。\(g[i]\)表示以\(i\)為開頭的,\(\text{BB}\)式的子串數量。我們可以列舉\(\text{AA}\)\(\text{BB}\)的分界點,然後求出答案,也就是說,答案等於:

\[\sum_{i=1}^{n-1}f[i]\cdot g[i+1] \]

對每個\(i\),都求一遍\(f\)\(g\)。利用雜湊或字尾陣列判斷子串相等。時間複雜度\(O(n^2)\),期望得\(95\)分。

繼續優化。\(f\), \(g\)是類似的,以求\(f\)為例。考慮列舉\(\text{A}\)的長度\(\text{len}\)。我們把\(1\dots n\)中所有是\(\text{len}\)的倍數的點,標為“關鍵點”。發現一個\(\text{AA}\)式的子串,必定跨過恰好\(2\)個關鍵點!

考慮一組相鄰的關鍵點\(i,j\) (顯然,\(j=i+\text{len}\)),計算跨過\(i,j\)

\(\text{AA}\)的數量。設\(x=\text{lcp}(\text{suf}(i),\text{suf}(j))\)\(y=\text{lcs}(\text{pre}(i-1),\text{pre}(j-1))\)。如果\(x+y<\text{len}\),顯然跨過\(i,j\)\(\text{AA}\)數量為\(0\)。否則,中間的兩段\(x\), \(y\)會有一個交,第一個\(\text{A}\)的終點,只要落在這段交上都是可行的,所以會有\(x+y-\text{len}+1\)個跨過\(i,j\)\(\text{AA}\),並且它們的結尾位置,是一段連續的區間,所以對這段區間的\(f\)
\(+1\)即可,可以用差分實現。

預處理字尾陣列後,求\(\text{lcp}\), \(\text{lcs}\)可以\(O(1)\)實現。因為關鍵點的總數是調和級數的,所以總時間複雜度\(O(n\log n)\)

參考程式碼:

//problem:LOJ2083
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

template<typename T>inline void ckmax(T& x,T y){x=(y>x?y:x);}
template<typename T>inline void ckmin(T& x,T y){x=(y<x?y:x);}

const int MAXN=3e4;
const int LOG_MAXN=15;
int n;
char s[MAXN+5];
struct SuffixArray{
	int n,m,x[MAXN+5],y[MAXN+5],c[MAXN+5],sa[MAXN+5];
	int rk[MAXN+5],ht[MAXN+5],st[MAXN+5][LOG_MAXN+5];
	int _log2[MAXN+5];
	void clear(){
		for(int i=1;i<=max(n,(int)'z');++i)c[i]=0;
		for(int i=1;i<=n;++i)y[i]=0;
	}
	void init(){
		_log2[0]=-1;
		for(int i=1;i<=MAXN;++i)_log2[i]=_log2[i>>1]+1;
	}
	void build(char* s,int _n){
		n=_n;
		m='z';
		for(int i=1;i<=n;++i)c[x[i]=s[i]]++;
		for(int i=1;i<=m;++i)c[i]+=c[i-1];
		for(int i=n;i>=1;--i)sa[c[x[i]]--]=i;
		for(int k=1;k<=n;k<<=1){
			int num=0;
			for(int i=n-k+1;i<=n;++i)y[++num]=i;
			for(int i=1;i<=n;++i)if(sa[i]>k)y[++num]=sa[i]-k;
			for(int i=1;i<=m;++i)c[i]=0;
			for(int i=1;i<=n;++i)c[x[i]]++;
			for(int i=1;i<=m;++i)c[i]+=c[i-1];
			for(int i=n;i>=1;--i)sa[c[x[y[i]]]--]=y[i];
			
			for(int i=1;i<=n;++i)swap(x[i],y[i]);
			x[sa[1]]=1;num=1;
			for(int i=2;i<=n;++i)
				x[sa[i]]=((y[sa[i]]==y[sa[i-1]] && y[sa[i]+k]==y[sa[i-1]+k])?num:++num);
			if(num==n)break;
			m=num;
		}
		for(int i=1;i<=n;++i)rk[sa[i]]=i;
		//for(int i=1;i<=n;++i)cout<<rk[i]<<" ";cout<<endl;
		int k=0;
		for(int i=1;i<=n;++i){
			if(rk[i]==1)continue;
			if(k)--k;
			int j=sa[rk[i]-1];
			while(i+k<=n && j+k<=n && s[i+k]==s[j+k])
				++k;
			ht[rk[i]]=k;
		}
		for(int i=1;i<=n;++i)st[i][0]=ht[i];
		for(int j=1;j<=LOG_MAXN;++j){
			for(int i=1;i+(1<<(j-1))<=n;++i){
				st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
			}
		}
	}
	int rmq(int l,int r){
		int k=_log2[r-l+1];
		return min(st[l][k],st[r-(1<<k)+1][k]);
	}
	int get_lcp(int i,int j){
		if(i==j)
			return n-i+1;
		i=rk[i]; j=rk[j];
		if(i>j)swap(i,j);
		assert(i<j);
		return rmq(i+1,j);
	}
	SuffixArray(){}
};

SuffixArray SA,SA_rev;

int f[MAXN+5],g[MAXN+5];

void solve_case(){
	SA.clear();
	SA_rev.clear();
	cin>>(s+1); n=strlen(s+1);
	SA.build(s,n);
	reverse(s+1,s+n+1);
	SA_rev.build(s,n);
	
	for(int i=1;i<=n;++i)f[i]=g[i]=0;
	
	int maxlen=(n-2)/2;
	for(int len=1;len<=maxlen;++len){
		for(int k=1;(k+1)*len<=n;++k){
			int i=k*len;
			int j=(k+1)*len;
			
			int lcp=SA.get_lcp(i,j);
			int lcs=(i==1?0:SA_rev.get_lcp(n-(i-1)+1,n-(j-1)+1));
			
			if(!lcp)continue;
			
			int l=j-min(lcs,len-1)+len-1;
			int r=min(n,j+min(lcp,len)-1);
			if(l<=r){
				f[l]++;f[r+1]--;
			}
			l=max(1,i-min(lcs,len-1));
			r=i+min(lcp,len)-1-len+1;
			if(l<=r){
				g[l]++;g[r+1]--;
			}
		}
	}
	for(int i=1;i<=n;++i)f[i]+=f[i-1],g[i]+=g[i-1];
	f[n+1]=g[n+1]=0;
	ll ans=0;
	for(int i=1;i<=n-2;++i)
		ans+=(ll)f[i]*g[i+1];
	cout<<ans<<endl;
}
int main() {
	SA.init();
	SA_rev.init();
	int T;cin>>T;while(T--){
		solve_case();
	}
	return 0;
}