1. 程式人生 > >字尾陣列+單調棧--luogu[HAOI2016]找相同字元

字尾陣列+單調棧--luogu[HAOI2016]找相同字元

傳送門
可以把兩個字串通過一個特殊字元連起來然後字尾陣列求出 h h

想到一個 n 2 n^2 做法,這個其實就是要求對於 s

2 s2 中的每個字尾,求其和 s 1 s1 的每個字尾的 l c
p lcp
然後加起來,用 h h 陣列和 s t st
表就能辦到

複雜度浪費在了列舉 s 2 s2 的字尾上,但根據 l c p ( i , j ) = m i n i + 1 k j ( h [ k ] ) lcp(i,j)=min_{i+1\le k\le j}(h[k]) 這個性質,就可以 O ( n ) O(n) 地維護一個單調遞增單調棧求出所有的和,只需要分成 s 1 s1 在前和 s 2 s2 在前做兩次就好了。

注意單調棧要倒序列舉

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#define maxn 400005
#define LL long long
using namespace std;
int n,m,len1,len2,sa[maxn],rk[maxn],tax[maxn],tp[maxn],h[maxn],stk[maxn],top;
LL ans,f[maxn];
char s1[maxn],s2[maxn],s[maxn];

inline void rsort(){
	for(int i=0;i<=m;i++) tax[i]=0;
	for(int i=1;i<=n;i++) tax[rk[i]]++;
	for(int i=1;i<=m;i++) tax[i]+=tax[i-1];
	for(int i=n;i;i--) sa[tax[rk[tp[i]]]--]=tp[i];
}

inline void ssort(){
	for(int i=1;i<=n;i++) rk[i]=s[i],tp[i]=i;
	rsort();
	for(int w=1,p=0;p<n && w<=n;w<<=1,m=p){
		p=0;
		for(int i=n-w+1;i<=n;i++) tp[++p]=i;
		for(int i=1;i<=n;i++)
			if(sa[i]>w) tp[++p]=sa[i]-w;
		rsort(); swap(rk,tp);
		rk[sa[1]]=p=1;
		for(int i=2;i<=n;i++)
			if(tp[sa[i]]==tp[sa[i-1]]&&tp[min(n+1,sa[i]+w)]==tp[min(n+1,sa[i-1]+w)])
				rk[sa[i]]=p;
			else rk[sa[i]]=++p;
	}
}

inline void geth(){
	int j,k=0;
	for(int i=1;i<=n;i++){
		if(k) --k;
		j=sa[rk[i]-1];
		while(s[i+k]==s[j+k]) ++k;
		h[rk[i]]=k;
	}
}

inline void solve(){
	LL now=0,tmp=0;
	for(int i=n-1;i;i--){
		tmp=0;
		if(sa[i]>len1+1) ans+=now;
		while(stk[top]>h[i] && top)
			now-=1LL*stk[top]*f[top],tmp+=f[top--];
		stk[++top]=h[i],f[top]=tmp;
		if(sa[i]<=len1) f[top]++;
		now+=1LL*f[top]*stk[top];
	}
	while(top) f[top]=stk[top]=0,top--; now=0;
	for(int i=n-1;i;i--){
		tmp=0;
		if(sa[i]<=len1) ans+=now;
		while(stk[top]>h[i] && top)
			now-=1LL*stk[top]*f[top],tmp+=f[top--];
		stk[++top]=h[i],f[top]=tmp;
		if(sa[i]>len1+1) f[top]++;
		now+=1LL*f[top]*stk[top];
	} return ;
}

int main(){
	scanf("%s%s",s1+1,s2+1); len1=strlen(s1+1),len2=strlen(s2+1);
	for(int i=1;i<=len1;i++) s[++n]=s1[i]; s[++n]='z'+1;
	for(int i=1;i<=len2;i++) s[++n]=s2[i]; m=127;
	ssort(); geth(); solve();
	printf("%lld\n",ans);
	return 0;
}