luogu P2408 不同子串個數
考慮反向操作,去計算有多少組相同的子串,對於一組大小為k的極大相同子串的集合,ans-=k-1。
為了避免重複計算,需要一種有效的,有順序的記錄方案。
比如說,對於每一個相同組,按其起始點所在的位置排序,對於除了第一個串以外的串,均記-1的貢獻。
但這種東西是非常難以快速統計的。
但是,可以對於每一個相同組,按其所在的字尾字典序排序,對於除了第一個串以外的串,均記-1的貢獻。
下面引用別人的一段話,主要是利用lcp來快速統計了不用長度相同組。
========================================================================
每個子串一定是某個字尾的字首,那麼原問題等價於求所有後綴之間的不相同的字首的個數。
如果所有的字尾按照 suffix(sa[1]), suffix(sa[2]),suffix(sa[3]), …… ,suffix(sa[n])的順序計算。
不難發現,對於每一次新加進來的字尾 suffix(sa[k]),它將產生 n-sa[k]+1 個新的字首。
但是其中有height[k]個是和前面的字串的字首是相同的。所以 suffix(sa[k])將“貢獻”出 n-sa[k]+1- height[k]個不同的子串。
累加後便是原問題的答案。這個做法的時間複雜度為 O(n)。
========================================================================
最後再強調一下為什麼只需要統計height[k],而不需要和之前所有的字尾均計算lcp。
因為,按照剛才我們的分析。把每一個相同組看成一條鏈,計數只能發生在邊上。
如果去和前面的再統計一遍的話,顯然是一種錯誤的越級的行為,造成重複統計。
此外,由於按照字典序排序後,再前面的所有串中,與它相鄰的串顯然是與它lcp最大的串。
一定可以穩穩地不重不漏的對每一個之前每一個出現過的過的字首進行統計。
即:按照字典序排序後,如果某個 當前字尾的一個字首 與前面的某個字尾的一個字首相同。
那麼一定是下圖這種情況。
紅色代表可能的位置,因為字典序的緣故,與它靠的越緊,相似度越高。
所以 要麼貢獻已經在之前算過了,要麼就會體現在它和與它相鄰串的lcp中。
#include<iostream> #include<cctype> #include<cstdio> #include<cstring> #include<string> #include<cmath> #include<ctime> #include<cstdlib> #include<algorithm> #define N 1100000 #define L 1000000 #define eps 1e-7 #define inf 1e9+7 #define ll long long using namespace std; inline int read() { char ch=0; int x=0,flag=1; while(!isdigit(ch)){ch=getchar();if(ch=='-')flag=-1;} while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();} return x*flag; } char s[N]; int n,m,c[N],x[N],y[N],sa[N],rank[N],height[N]; int main() { n=read();m=122;scanf("%s",s+1); for(int i=1;i<=n;i++)c[x[i]=s[i]]++; for(int i=1;i<=m;i++)c[i]+=c[i-1]; for(int i=n;i>=1;i--)sa[c[x[i]]--]=i; for(int k=1;k<=n;k<<=1) { int num=0; for(int i=n-k+1;i<=n;i++)y[++num]=i; for(int i=1;i<=n;i++)if(sa[i]>k)y[++num]=sa[i]-k; for(int i=1;i<=m;i++)c[i]=0; for(int i=1;i<=n;i++)c[x[i]]++; for(int i=1;i<=m;i++)c[i]+=c[i-1]; for(int i=n;i>=1;i--)sa[c[x[y[i]]]--]=y[i],y[i]=0; swap(x,y); x[sa[1]]=num=1; for(int i=2;i<=n;i++) x[sa[i]]=(y[sa[i-1]]==y[sa[i]]&&y[sa[i-1]+k]==y[sa[i]+k])?num:++num; if(num==n)break; m=num; } ll ans=(ll)n*((ll)n+(ll)1)/(ll)2; for(int i=1;i<=n;i++)rank[sa[i]]=i; for(int i=1,k=0;i<=n;i++) { if(k)k--; int j=sa[rank[i]-1]; while(s[i+k]==s[j+k])k++; height[rank[i]]=k; ans-=height[rank[i]]; } printf("%lld",ans); return 0; }