【LUOGU???】WD與數列 sam 啟發式合併
阿新 • • 發佈:2018-12-30
題目大意
給你一個字串,求有多少對不相交且相同的子串。
位置不同算多對。
\(n\leq 300000\)
題解
先把字尾樹建出來。
DFS 整棵樹,維護當前子樹的 right 集合。
合併兩個集合的時候暴力列舉小的那個集合,然後在另一個集合的線段樹中查詢相應的資訊計算貢獻。
怎麼計算呢?
如果兩個位置之差 \(>\) 這兩個位置的 \(lcp\)(即當前點的深度),那麼貢獻就是 \(lcp\),否則是位置之差。
線段樹記錄區間點數和位置之和即可。
時間複雜度:\(O(n\log^2n)\),好像能做到 \(O(n\log n)\)。
程式碼
#include<cstdio> #include<cstring> #include<algorithm> #include<cstdlib> #include<ctime> #include<functional> #include<cmath> #include<vector> #include<assert.h> #include<map> using namespace std; using std::min; using std::max; using std::swap; using std::sort; using std::reverse; using std::random_shuffle; using std::lower_bound; using std::upper_bound; using std::unique; using std::vector; typedef long long ll; typedef unsigned long long ull; typedef double db; typedef std::pair<int,int> pii; typedef std::pair<ll,ll> pll; void open(const char *s){ #ifndef ONLINE_JUDGE char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout); #endif } void open2(const char *s){ #ifdef DEBUG char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout); #endif } int rd(){int s=0,c,b=0;while(((c=getchar())<'0'||c>'9')&&c!='-');if(c=='-'){c=getchar();b=1;}do{s=s*10+c-'0';}while((c=getchar())>='0'&&c<='9');return b?-s:s;} void put(int x){if(!x){putchar('0');return;}static int c[20];int t=0;while(x){c[++t]=x%10;x/=10;}while(t)putchar(c[t--]+'0');} int upmin(int &a,int b){if(b<a){a=b;return 1;}return 0;} int upmax(int &a,int b){if(b>a){a=b;return 1;}return 0;} const int N=300010; int n; int a[N]; int d[N]; ll ans; int min(int a,int b) { return a<b?a:b; } int max(int a,int b) { return a>b?a:b; } namespace seg { int s1[20000000]; ll s2[20000000]; int lc[20000000]; int rc[20000000]; int cnt; #define mid ((L+R)>>1) void mt(int p) { s1[p]=s1[lc[p]]+s1[rc[p]]; s2[p]=s2[lc[p]]+s2[rc[p]]; } int insert(int p,int x,int L,int R) { if(!p) p=++cnt; if(L==R) { s1[p]=1; s2[p]=x; return p; } if(x<=mid) lc[p]=insert(lc[p],x,L,mid); else rc[p]=insert(rc[p],x,mid+1,R); mt(p); return p; } int query1(int p,int l,int r,int L,int R) { if(!p||(l<=L&&r>=R)) return s1[p]; int res=0; if(l<=mid) res+=query1(lc[p],l,r,L,mid); if(r>mid) res+=query1(rc[p],l,r,mid+1,R); return res; } ll query2(int p,int l,int r,int L,int R) { if(!p||(l<=L&&r>=R)) return s2[p]; ll res=0; if(l<=mid) res+=query2(lc[p],l,r,L,mid); if(r>mid) res+=query2(rc[p],l,r,mid+1,R); return res; } int merge(int p1,int p2,int L,int R) { if(!p1||!p2) return p1+p2; if(L==R) { s1[p1]+=s1[p2]; s2[p1]+=s2[p2]; return p1; } lc[p1]=merge(lc[p1],lc[p2],L,mid); rc[p1]=merge(rc[p1],rc[p2],mid+1,R); mt(p1); return p1; } } namespace sam { map<int,int> next[2*N]; int fail[2*N]; int len[2*N]; int c[2*N]; int last,cnt; void init() { last=cnt=1; } void append(int x,int v) { int np=++cnt; int p=last; c[np]=x; len[np]=len[p]+1; for(;p&&!next[p][v];p=fail[p]) next[p][v]=np; if(!p) fail[np]=1; else { int q=next[p][v]; if(len[q]==len[p]+1) fail[np]=q; else { int nq=++cnt; len[nq]=len[p]+1; next[nq]=next[q]; fail[nq]=fail[q]; fail[q]=fail[np]=nq; for(;p&&next[p][v]==q;p=fail[p]) next[p][v]=nq; } } last=np; } vector<int> g[2*N],*e[2*N]; int sz[2*N]; int rt[2*N]; void merge(int x,int y,int l) { if(sz[x]<sz[y]) { for(auto v:*e[x]) { e[y]->push_back(v); int s1=0,_; // if(v-l-1>=1) // ans+=(ll)l*seg::query1(rt[y],1,v-l-1,1,n); if(v-1>=1) { ans+=(ll)(v-1)*(_=seg::query1(rt[y],max(1,v-l),v-1,1,n))-seg::query2(rt[y],max(1,v-l),v-1,1,n); s1+=_; } if(v+1<=n) { ans+=seg::query2(rt[y],v+1,min(n,v+l),1,n)-(ll)(v+1)*(_=seg::query1(rt[y],v+1,min(n,v+l),1,n)); s1+=_; } // if(v+l+1<=n) // ans+=(ll)l*seg::query1(rt[y],v+l+1,n,1,n); ans+=(ll)l*(sz[y]-s1); } e[x]=e[y]; } else { for(auto v:*e[y]) { e[x]->push_back(v); int s1=0,_; // if(v-l-1>=1) // ans+=(ll)l*seg::query1(rt[x],1,v-l-1,1,n); if(v-1>=1) { ans+=(ll)(v-1)*(_=seg::query1(rt[x],max(1,v-l),v-1,1,n))-seg::query2(rt[x],max(1,v-l),v-1,1,n); s1+=_; } if(v+1<=n) { ans+=seg::query2(rt[x],v+1,min(n,v+l),1,n)-(ll)(v+1)*(_=seg::query1(rt[x],v+1,min(n,v+l),1,n)); s1+=_; } // if(v+l+1<=n) // ans+=(ll)l*seg::query1(rt[x],v+l+1,n,1,n); ans+=(ll)l*(sz[x]-s1); } } sz[x]+=sz[y]; rt[x]=seg::merge(rt[x],rt[y],1,n); } void dfs(int x) { e[x]=new vector<int>(); if(c[x]) { rt[x]=seg::insert(rt[x],c[x],1,n); e[x]->push_back(c[x]); sz[x]=1; } for(auto v:g[x]) { dfs(v); merge(x,v,len[x]); } } void solve() { for(int i=2;i<=cnt;i++) g[fail[i]].push_back(i); dfs(1); } } int main() { open("c"); scanf("%d",&n); ans+=(ll)n*(n-1)/2; for(int i=1;i<=n;i++) a[i]=rd(); n--; for(int i=1;i<=n;i++) a[i]=a[i]-a[i+1]; for(int i=1;i<=n;i++) d[i]=a[i]; sort(d+1,d+n+1); int t=unique(d+1,d+n+1)-d-1; for(int i=1;i<=n;i++) a[i]=lower_bound(d+1,d+t+1,a[i])-d; sam::init(); for(int i=n;i>=1;i--) sam::append(i,a[i]); sam::solve(); printf("%lld\n",ans); return 0; }