1. 程式人生 > >FFT:BZOJ4503 兩個串

FFT:BZOJ4503 兩個串

題目描述:戳這裡

題解:

如果沒有"?",那麼我們可以用kmp。
我們可以把這道題目抽象成一個和式:
假設兩串S,T分別是0~n,0~m,翻轉T串(變成m~0)。
假設T串中"?"的位置都設為0。
假設S串從第x個位置開始匹配可以匹配完T串,那麼等價於要滿足:
0 m (

S x + i T m
i
) 2 T m
i
= 0 \sum_0^m(S_{x+i}-T_{m-i})^2T_{m-i}=0
化簡一下:
0 m S x + i 2 T m i 2 S x + i T m i + T m i 3 \sum_0^mS_{x+i}^2T_{m-i}-2S_{x+i}T_{m-i}+T_{m-i}^3
對於最後一項,可以字首和,前面兩項,都是卷積的形式,可以通過FFT來快速解決。
那麼複雜度就是 O ( n l o g n ) O(nlogn)

程式碼如下:

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=(1<<18)+5;
const double Pi=acos(-1.0);
int n,m,limn,R[maxn];
char S[maxn],T[maxn];
struct comx{
	double x,y;
	comx(double xx=0,double yy=0){x=xx,y=yy;}
	comx operator +(const comx b){return comx(x+b.x,y+b.y);}
	comx operator -(const comx b){return comx(x-b.x,y-b.y);}
	comx operator *(const comx b){return comx(x*b.x-y*b.y,x*b.y+y*b.x);} 
}a[maxn],b[maxn],c[maxn],d[maxn],w[maxn];
double s;
void pre(){
	int L=0; limn=1; while (limn<=n+m) limn<<=1,L++;
	for (int i=0;i<limn;i++){
		R[i]=((R[i>>1]>>1)|((i&1)<<(L-1)));
		w[i]=comx(cos(2*Pi/limn*i),sin(2*Pi/limn*i));
	}
}
void FFT(comx *a,int lim){
	for (int i=0;i<lim;i++) if (R[i]>i) swap(a[R[i]],a[i]);
	for (int t=lim>>1,d=1;d<lim;d<<=1,t>>=1)
	for (int i=0;i<lim;i+=(d<<1))
	for (int j=0;j<d;j++){
		comx p=w[t*j]*a[i+j+d];
		a[i+j+d]=a[i+j]-p,a[i+j]=a[i+j]+p;
	}
}
void doit(comx *p,comx *q){
	FFT(p,limn); FFT(q,limn);
	for (int i=0;i<limn;i++) p[i]=p[i]*q[i],w[i].y=-w[i].y;
	FFT(p,limn);
	for (int i=0;i<limn;i++) w[i].y=-w[i].y,p[i].x/=limn;
}
ll cal(double x){return (ll)(x+0.5);}
int main(){
	scanf("%s",S); scanf("%s",T);
	n=strlen(S)-1; m=strlen(T)-1;
	for (int i=0;i<=n;i++) c[i].x=S[i]-'a'+1;
	for (int i=0;i<=m;i++)
		if (T[i]!='?') b[m-i].x=T[i]-'a'+1; else b[m-i].x=0;
	for (int i=0;i<=n;i++) a[i].x=c[i].x*c[i].x;
	for (int i=0;i<=m;i++) d[i].x=b[i].x*b[i].x,s+=b[i].x*b[i].x*b[i].x;
	pre(); doit(a,b); doit(c,d);
	int ans=0;
	for (int i=m;i<=n;i++)
		if (cal(a[i].x)-2*cal(c[i].x)+cal(s)==0) ans++;
	printf("%d\n",ans);
	for (int i=m;i<=n;i++)
		if (cal(a[i].x)-2*cal(c[i].x)+cal(s)==0) printf("%d\n",i-m);
	return 0;
}