BZOJ.4650.[NOI2016]優秀的拆分(字尾陣列 思路)
阿新 • • 發佈:2018-12-30
令\(st[i]\)表示以\(i\)為開頭有多少個\(AA\)這樣的子串,\(ed[i]\)表示以\(i\)結尾有多少個\(AA\)這樣的子串。那麼\(Ans=\sum_{i=1}^{n-1}ed[i]*st[i+1]\)。
考慮如何求\(st[i],ed[i]\)。暴力的話可以列舉\(i\),然後雜湊判一下。這樣\(O(n^2)\)就有\(95\)分了。。
正解是,列舉長度\(len\),判斷每個位置是否存在長為\(2*len\)的\(AA\)這樣的子串。
每隔\(len\)的距離放一個關鍵點,這樣一個長度為\(2*len\)的串一定會經過兩個相鄰的關鍵點。
考慮列舉兩個相鄰的關鍵點,即令\(i=k*len,\ j=i+len\)
(不想畫圖了,注意別看錯,可以拿個串比如
aabaabab
試一下)
當\(x+y-1<len\)時,因為中間沒有相同的部分所以找不到一個經過\(i,j\)長為\(2*len\)的\(AA\)串。
當\(x+y-1\geq len\)時,我們發現因為\(i,j\)是兩個相距為\(len\)的點,我們取\(i-x+len,\ j-x+len\),這兩個點之間能形成長\(2*len\)的\(AA\)
也就是當\(p\)取\([j-x+len,\ j+y-1]\)中的某個位置時,都能得到以\(p\)為結尾的長為\(2*len\)的\(AA\)串。同理當\(p\)在\([i-x+1,\ i+y-len]\)中時,也都能得到以\(p\)開頭的長為\(2*len\)的\(AA\)串。
所以就是區間加一,差分一下就可以了。
只是列舉\(len\),然後每隔\(len\)放一個點,統計相鄰兩點間的貢獻。所以複雜度還是\(O(n\log n)\)。
//5892kb 784ms #include <cstdio> #include <cstring> #include <algorithm> typedef long long LL; const int N=3e4+5; int Log[N]; struct Suffix_Array { int tm[N],sa[N],sa2[N],rk[N],ht[N],st[N][15]; inline void Init_ST(const int n) { for(int i=1; i<=n; ++i) st[i][0]=ht[i]; for(int j=1; j<=Log[n]; ++j) for(int t=1<<j-1,i=n-t; i; --i) st[i][j]=std::min(st[i][j-1],st[i+t][j-1]); } inline int LCP(int l,int r) { l=rk[l], r=rk[r]; if(l>r) std::swap(l,r); ++l; int k=Log[r-l+1]; return std::min(st[l][k],st[r-(1<<k)+1][k]); } void Build(char *s,const int n) { memset(rk,0,sizeof rk); memset(sa2,0,sizeof sa2);//要清空...! 因為下面比較懶得加<=n了。 int m=26,*x=rk,*y=sa2; for(int i=0; i<=m; ++i) tm[i]=0; for(int i=1; i<=n; ++i) ++tm[x[i]=s[i]-'a'+1]; for(int i=1; i<=m; ++i) tm[i]+=tm[i-1]; for(int i=n; i; --i) sa[tm[x[i]]--]=i; for(int k=1,p=0; k<n; k<<=1,m=p,p=0) { for(int i=n-k+1; i<=n; ++i) y[++p]=i; for(int i=1; i<=n; ++i) if(sa[i]>k) y[++p]=sa[i]-k; for(int i=0; i<=m; ++i) tm[i]=0; for(int i=1; i<=n; ++i) ++tm[x[i]]; for(int i=1; i<=m; ++i) tm[i]+=tm[i-1]; for(int i=n; i; --i) sa[tm[x[y[i]]]--]=y[i]; std::swap(x,y), x[sa[1]]=p=1; for(int i=2; i<=n; ++i) x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k])?p:++p;//because of this if(p>=n) break; } for(int i=1; i<=n; ++i) rk[sa[i]]=i; ht[1]=0; for(int i=1,k=0,p; i<=n; ++i) { if(rk[i]==1) continue; if(k) --k; p=sa[rk[i]-1]; while(i+k<=n && p+k<=n && s[i+k]==s[p+k]) ++k; ht[rk[i]]=k; } Init_ST(n); } }sa1,sa2; inline void Init_Log(const int n) { for(int i=2; i<=n; ++i) Log[i]=Log[i>>1]+1; } void Solve() { static int st[N],ed[N]; static char s[N]; scanf("%s",s+1); const int n=strlen(s+1); sa1.Build(s,n), std::reverse(s+1,s+1+n), sa2.Build(s,n); memset(st,0,n+1<<2), memset(ed,0,n+1<<2); for(int len=1,lim=n>>1; len<=lim; ++len) for(int i=len,j=len<<1; j<=n; i=j,j+=len) { int x=std::min(len,sa2.LCP(n-i+1,n-j+1)),y=std::min(len,sa1.LCP(i,j)); if(x+y-1>=len) ++st[i-x+1], --st[i+y-len+1], ++ed[j-x+len], --ed[j+y]; } LL ans=0; for(int i=1; i<n; ++i) st[i+1]+=st[i], ed[i+1]+=ed[i], ans+=1ll*ed[i]*st[i+1]; printf("%lld\n",ans); } int main() { Init_Log(30000); int T; scanf("%d",&T); while(T--) Solve(); return 0; }