1. 程式人生 > >BZOJ 3238 [Ahoi2013]差異 (字尾陣列+單調棧)

BZOJ 3238 [Ahoi2013]差異 (字尾陣列+單調棧)

題目大意:求$\sum_{1\leq i<j \leq N} suf_{i}+suf_{j}-2\cdot lcp(suf_{i},suf_{j})$

先是字尾陣列打錯了,又是把+=打成了=,我是zz

轉化式子,原式=$\sum_{i=1}^{n-1}(i+1)\cdot i-\sum_{1\leq i<j \leq N}2\cdot lcp(suf_{i},suf_{j})$

這樣計算後面的部分就行了

首先用$sa$預處理出$height$陣列

對字尾進行排序後,對於某個一個字尾$suf_{i}$,如果另一個字尾$suf_{j}$和它的$lcp$長度是$x$,必須要保證$\forall \;k\in[i+1,j-1],h_{k}\geq x$

用一個單調棧維護$height$,設$num_{tp}$表示棧中$lcp$長度為$L_{tp}$的字尾數量總和

用一個動態的數$sum$記錄當前棧中的$h_{k}*num_{k}$總和

每遍歷到一個位置$i$,先刪去棧中大於$h_{i}$的元素,更新$sum$,在這之後,所有$h$大於$h_{i}$的字尾長度都要修改成$h_{i}$,要去掉多出來的部分

即$sum-=(L_{tp}-h_{i})\cdot num_{tp}$

再把$suf_{i}$推入棧中

最終答案就是每次統計完成後的sum總和

 1 #include <bitset>
 2 #include <cstdio>
 3
#include <cstring> 4 #include <algorithm> 5 #define N1 505000 6 #define ll long long 7 #define inf 0x3f3f3f3f 8 #define rint register int 9 using namespace std; 10 11 12 int len; 13 int gch(char *str) 14 { 15 char c=getchar(); 16 while(c<'a'||c>'z'){c=getchar();}
17 while(c>='a'&&c<='z'){str[++len]=c;c=getchar();} 18 } 19 int gint() 20 { 21 int ret=0,fh=1;char c=getchar(); 22 while(c<'0'||c>'9'){if(c=='-')fh=-1;c=getchar();} 23 while(c>='0'&&c<='9'){ret=ret*10+c-'0';c=getchar();} 24 return ret*fh; 25 } 26 char str[N1]; 27 int rk[N1],tr[N1],sa[N1],hs[N1],h[N1]; 28 int check(int i,int j,int k){ 29 if(i+k>len||j+k>len) return 0; 30 return (rk[i]==rk[j]&&rk[i+k]==rk[j+k])?1:0;} 31 void get_sa() 32 { 33 rint i,cnt=0; 34 for(i=1;i<=len;i++) hs[str[i]]++; 35 for(i=1;i<=128;i++) if(hs[i]) tr[i]=++cnt; 36 for(i=1;i<=128;i++) hs[i]+=hs[i-1]; 37 for(i=1;i<=len;i++) rk[i]=tr[str[i]],sa[hs[str[i]]--]=i; 38 for(int k=1;cnt<len;k<<=1) 39 { 40 for(i=1;i<=cnt;i++) hs[i]=0; 41 for(i=1;i<=len;i++) hs[rk[i]]++; 42 for(i=1;i<=cnt;i++) hs[i]+=hs[i-1]; 43 for(i=len;i>=1;i--) if(sa[i]>k) tr[sa[i]-k]=hs[rk[sa[i]-k]]--; 44 for(i=1;i<=k;i++) tr[len-i+1]=hs[rk[len-i+1]]--; 45 for(i=1;i<=len;i++) sa[tr[i]]=i; 46 for(i=1,cnt=0;i<=len;i++) tr[sa[i]]=check(sa[i],sa[i-1],k)?cnt:++cnt; 47 for(i=1;i<=len;i++) rk[i]=tr[i]; 48 } 49 for(i=1;i<=len;i++){ 50 if(rk[i]==1) continue; 51 for(int j=max(1,h[rk[i-1]]-1);;j++) 52 if(str[i+j-1]==str[sa[rk[i]-1]+j-1]) h[rk[i]]=j; 53 else break; 54 } 55 } 56 int stk[N1],num[N1],L[N1],tp; 57 ll sum; 58 ll solve() 59 { 60 ll ans=0,tmp;tp=0; 61 for(int i=2;i<=len;i++) 62 { 63 tmp=0; 64 while(tp>0&&L[tp]>h[i]){ 65 tmp+=num[tp]; 66 sum-=1ll*(L[tp]-h[i])*num[tp]; 67 L[tp]=0,num[tp]=0,tp--; 68 } 69 if(h[i]>L[tp]) 70 tp++,L[tp]=h[i]; 71 num[tp]+=tmp+1; 72 sum+=h[i],ans+=sum; 73 } 74 return ans; 75 } 76 77 int main() 78 { 79 gch(str); 80 get_sa(); 81 ll ans=0; 82 for(int i=1;i<=len-1;i++) 83 ans+=1ll*(i+1)*i; 84 ans=ans/2*3; 85 printf("%lld\n",ans-2ll*solve()); 86 return 0; 87 }