1. 程式人生 > 實用技巧 >2018ICPC南京/gym101981 M-Mediocre String Problem 迴文自動機+擴充套件kmp

2018ICPC南京/gym101981 M-Mediocre String Problem 迴文自動機+擴充套件kmp

2018ICPC南京/gym101981 M-Mediocre String Problem

題意

給定兩個字串\(s\)\(t\),計算有多少個三元組\((i,j,k)\)滿足如下條件:

  1. \(1\le i \le j \le |s|\)

  2. \(1\le k \le |t|\)

  3. \(j-i+1>k\)

  4. \(s[i;j]+t[1;k]\)是一個迴文串

分析

因為\(j-i+1>k\),所以第四個條件可以看作滿足\(s[i;i+k-1]+t[1,k]\)是一個迴文串且\(s[i+k;j]\)是一個迴文串,將\(s\)反轉一下就是\(s[i;j-k]\)是一個迴文串且\(s[j-k+1;j]=t[1,k]\)

,那麼思路就很顯然了,令\(z[i]\)為字尾\(s[i;n]\)\(t\)的最長公共字首長度,利用擴充套件\(kmp\)可以在\(O(n)\)時間複雜度求出\(z\),對\(s\)建迴文自動機,可以求出陣列\(cnt\)\(cnt[i]\)表示\(s\)中以\(i\)結尾的迴文子串個數,答案即為\(\sum_{i=1}^{n-1}cnt[i]\cdot z[i+1]\)

Code

#include<bits/stdc++.h>
#define rep(i,x,n) for(int i=x;i<=n;i++)
#define per(i,n,x) for(int i=n;i>=x;i--)
#define sz(a) int(a.size())
#define rson mid+1,r,p<<1|1
#define pii pair<int,int>
#define lson l,mid,p<<1
#define ll long long
#define pb push_back
#define mp make_pair
#define se second
#define fi first
using namespace std;
const double eps=1e-8;
const int mod=1e9+7;
const int N=2e6+10;
const int inf=1e9;
int z[N],n,m,len;
char s[N],t[N],ts[N];
int now,dp[N],g[N];
struct PAM{
	int ch[N][26],fail[N],len[N],cnt[N],tot,last;
	int newnode(int x){
		++tot;
		memset(ch[tot],0,sizeof ch[tot]);
		fail[tot]=0;len[tot]=x;
		return tot;
	}
	void init(){
		tot=-1;newnode(0);newnode(-1);
		fail[0]=1;
		last=0;
	}
	int gao(int x){
		while(s[now-len[x]-1]!=s[now]) x=fail[x];
		return x;
	}
	void insert(){
		int p=gao(last);
		int c=s[now]-'a';
		if(!ch[p][c]){
			int tmp=ch[gao(fail[p])][c];
			ch[p][c]=newnode(len[p]+2);
			fail[tot]=tmp;
			cnt[tot]=cnt[tmp]+1;
		}
		last=ch[p][c];
	}
	ll qy(){
		ll ans=0;
		for(now=1;now<=n;now++){
			insert();
			ans+=1ll*cnt[last]*z[m+2+now];
		}
		return ans;
	}
}P;
void Z(){
	for(int i=2,l=1,r=1;i<=n+m+1;i++){
		if(i<=r) z[i]=min(r-i+1,z[i-l+1]);
		while(i+z[i]<=n+m+1&&ts[z[i]+1]==ts[i+z[i]]) ++z[i];
		if(i+z[i]-1>r) l=i,r=i+z[i]-1;
	}
}
int main(){
	cin>>s+1>>t+1;
	n=strlen(s+1);
	m=strlen(t+1);
	rep(i,1,m) ts[++len]=t[i];
	ts[++len]='#';
	reverse(s+1,s+n+1);
	rep(i,1,n) ts[++len]=s[i];
	Z();
	P.init();
	printf("%lld\n",P.qy());
	return 0;
}