洛谷2408不同字串個數/SPOJ 694/705 (字尾陣列SA)
阿新 • • 發佈:2018-12-22
真是一個三倍經驗好題啊。
我們來觀察這個題目,首先如果直接整體計算,怕是不太好計算。
首先,我們可以將每個子串都看成一個字尾的的字首。那我們就可以考慮一個一個字尾來計算了。
為了方便起見,我們選擇按照字典序來一次插入每個字尾,然後每次考慮當前字尾會產生的新串和與之前插入的串重複的串(這裡之所以可以這麼考慮,是因為如果他會對後面的串產生重複的話,那麼會在後面那個串加入的時候計算的)
那麼我們考慮,一個排名為\(i\)的字尾,插入之後不考慮重複的話,會新增多少個子串呢?
不難發現是\(n-sa[i]+1\)個(注意字尾的位置編號是從前開始,而後綴的貢獻是後面的子串個數。
那麼重複的該怎麼計算呢?
我們發現重複的部分實際是當前這個字尾和之前的字尾的\(lcp\)部分會重複,而且應該是最大的\(lcp\) (如果取小的會算少,直接求sum會算多)。
而有一個比較經典的性質就是,在字典序\(1到i\)中與\(i\)的\(lcp\)長度最長的,一定是\(i-1\),這裡有兩種理解方式,一個是越遠差距越大,另一種是越靠前,取\(min\)的範圍越大,\(min\)就會可能越小
那麼列舉+計算,記得開\(long \ long\)就三倍經驗辣
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<cmath> #include<queue> #include<map> #include<set> #define mk makr_pair #define ll long long using namespace std; inline int read() { int x=0,f=1;char ch=getchar(); while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();} while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();} return x*f; } const int maxn = 2e5+1e2; int rk[maxn],sa[maxn]; int wb[maxn]; int tmp[maxn]; char a[maxn]; int n; int h[maxn],height[maxn]; void getsa() { int *x=rk,*y=tmp; int s=128; int p=0; for (int i=1;i<=n;i++) x[i]=a[i],y[i]=i; for (int i=1;i<=s;i++) wb[i]=0; for (int i=1;i<=n;i++) wb[x[y[i]]]++; for (int i=1;i<=s;i++) wb[i]+=wb[i-1]; for (int i=n;i>=1;i--) sa[wb[x[y[i]]]--]=y[i]; for (int j=1;p<n;j<<=1) { p=0; for (int i=n-j+1;i<=n;i++) y[++p]=i; for (int i=1;i<=n;i++) if (sa[i]>j) y[++p]=sa[i]-j; for (int i=1;i<=s;i++) wb[i]=0; for (int i=1;i<=n;i++) wb[x[y[i]]]++; for (int i=1;i<=s;i++) wb[i]+=wb[i-1]; for (int i=n;i>=1;i--) sa[wb[x[y[i]]]--]=y[i]; swap(x,y); p=1; x[sa[1]]=1; for (int i=2;i<=n;i++) x[sa[i]]=(y[sa[i]]==y[sa[i-1]] && y[sa[i]+j]==y[sa[i-1]+j]) ? p : ++p; s=p; } for (int i=1;i<=n;i++) rk[sa[i]]=i; h[0]=0; for (int i=1;i<=n;i++) { h[i]=max(h[i-1]-1,0); while(i+h[i]<=n && sa[rk[i]-1]+h[i]<=n && a[i+h[i]]==a[sa[rk[i]-1]+h[i]]) h[i]++; } for (int i=1;i<=n;i++) height[i]=h[sa[i]]; } int t; void init() { memset(wb,0,sizeof(wb)); memset(rk,0,sizeof(rk)); memset(sa,0,sizeof(sa)); memset(tmp,0,sizeof(tmp)); memset(h,0,sizeof(h)); memset(height,0,sizeof(height)); } int main() { //cin>>t; //while (t--) //{ n=read(); init(); scanf("%s",a+1); getsa(); long long ans=0; for (int i=1;i<=n;i++) { ans=ans+(long long)(n-sa[i]+1)-(long long)h[i];//這裡可以理解成我們順著字典序的順序,加入每個字尾,將子串看成字尾的字首 // 而每次加入會產生新的n-sa[i]+1個字串,其中重複的就是和之前的子串的某些lcp,而字典序上,在這個串前面,與某個串lcp最長的應該是i-1那個串(這裡可以理解成越往前差距越大) } cout<<ans<<"\n"; // } return 0; }