[八省聯考2018]制胡竄 (SAM+大討論)
阿新 • • 發佈:2021-09-03
正著做著實不太好做,正難則反,考慮反著做。
把i,j看成在切割字串,我們統計有多少對(i,j)會切割所有與\(s_{l,r}\)相同的串。對於在後綴自動機上表示\(s_{l,r}\)的節點x,x的parent子樹內的endpos節點集合,就是和\(s_{l,r}\)相等的串的最後一個字元的出現位置。我們相當於在s串裡得到了若干個線段,每個線段表示的子串都和\(s_{l,r}\)相等,然後用兩刀把這些串都割了。我們分最左邊的串和最右邊的串是否存在交集進行討論。
如果存在交集,線段數量是m
1.第一刀切串[1,i],第二刀切[i+1,m],方案數\((r_{i+1}-r_{i})(r_{i+1}-l_{m})\)
2.第一刀切[1,m],第二刀在第一刀右面隨便切,是一個等差數列
3.第一刀切在第一個串左邊,第二刀切在交集,一個乘法原理
如果不存在交集
可行的位置收到了限制,我們要求第一刀必須切第一個串,第二刀必須切第m個串,我們討論出第一刀可行的線段編號區間[L,R],再統計方案數。
總之兩種情況都需要維護\(\sum_{i=L}^{R}(r_{i+1}-r_{i})(r_{i+1}-l_{m})\)這個式子,把它拆開。
\[\sum_{i=L}^{R}(r_{i+1}-r_{i})(r_{i+1}-l_{m}) \\=\sum_{i=L}^{R}(\ (r_{i+1}^{2}-r_{i}r_{i+1})-l_{m}(r_{i+1}-r_{i})\ ) \\=\sum_{i=L}^{R}(r_{i+1}^{2}-r_{i}r_{i+1})-l_{m}(r_{R}-r_{L}) \]常用套路,用線段樹合併維護endpos集合,和式第二項維護相鄰兩項的乘積,對應pushup時左區間max和右區間min,我們需要維護一段區間內最大/最小值,再維護和式即可
#include <bits/stdc++.h> #define ll long long #define ull unsigned long long using namespace std; template <typename _T> void read(_T &ret) { ret=0; _T fh=1; char c=getchar(); while(c<'0'||c>'9'){ if(c=='-') fh=-1; c=getchar(); } while(c>='0'&&c<='9'){ ret=ret*10+c-'0'; c=getchar(); } ret=ret*fh; } const int N1=1e5+5, S1=N1*2, M1=S1*70, inf=0x3f3f3f3f; struct EDGE{ int to[S1],nxt[S1],head[S1],cte; void ae(int u,int v) { cte++; to[cte]=v, nxt[cte]=head[u], head[u]=cte; } }e; struct node{ ll sum; int mi,ma; friend node operator + (const node &s1,const node &s2) { return (node){s1.sum+s2.sum-((s2.mi!=inf)?1ll*s1.ma*s2.mi:0ll), min(s1.mi,s2.mi) , max(s1.ma,s2.ma)}; } }; int n,Q; char str[N1]; int idx(char c){ return c-'0'; } struct SEG{ int mi[M1],ma[M1],ls[M1],rs[M1],root[S1],tot; ll sum[M1]; void init(){ mi[0]=inf; } void pushup(int rt) { mi[rt]=min(mi[ls[rt]],mi[rs[rt]]); ma[rt]=max(ma[ls[rt]],ma[rs[rt]]); sum[rt]=sum[ls[rt]]+sum[rs[rt]]; if(mi[rs[rt]]!=inf) sum[rt]-=1ll*ma[ls[rt]]*mi[rs[rt]]; } void ins(int x,int l,int r,int &rt) { if(!rt) rt=++tot; if(l==r){ mi[rt]=ma[rt]=l; sum[rt]=1ll*l*l; return; } int mid=(l+r)>>1; if(x<=mid) ins(x,l,mid,ls[rt]); else ins(x,mid+1,r,rs[rt]); pushup(rt); } //位置互不相同 線上段樹葉節點一定會return 無需額外特判 int merge(int r1,int r2) { if(!r1||!r2) return r1+r2; int rt=++tot; ls[rt]=merge(ls[r1],ls[r2]); rs[rt]=merge(rs[r1],rs[r2]); pushup(rt); return rt; } int lower(int x,int l,int r,int rt) { if(l==r){ if(mi[rt]<=x) return mi[rt]; else return -1; } int mid=(l+r)>>1; if(mi[rs[rt]]<=x) return lower(x,mid+1,r,rs[rt]); else return lower(x,l,mid,ls[rt]); } int upper(int x,int l,int r,int rt) { if(l==r){ if(ma[rt]>=x) return ma[rt]; else return -1; } int mid=(l+r)>>1; if(ma[ls[rt]]>=x) return upper(x,l,mid,ls[rt]); else return upper(x,mid+1,r,rs[rt]); } node query(int L,int R,int l,int r,int rt) { if(L<=l&&r<=R){ return (node){sum[rt],mi[rt],ma[rt]}; } int mid=(l+r)>>1; node ans=(node){0ll,inf,0}; if(L<=mid) ans=(ans+query(L,R,l,mid,ls[rt])); if(R>mid) ans=(ans+query(L,R,mid+1,r,rs[rt])); return ans; } }s; int trs[S1][10],pre[S1],dep[S1],id[S1],tot,la; void init(){ tot=la=1; } void insert(int c,int i) { int p=la,np=++tot,q,nq; la=np; dep[np]=dep[p]+1; s.ins(i,1,n,s.root[np]); id[i]=np; for(;p&&!trs[p][c];p=pre[p]) trs[p][c]=np; if(!p){ pre[np]=1; return; } q=trs[p][c]; if(dep[q]==dep[p]+1) pre[np]=q; else{ pre[nq=++tot]=pre[q]; pre[q]=pre[np]=nq; dep[nq]=dep[p]+1; memcpy(trs[nq],trs[q],sizeof(trs[nq])); for(;p&&trs[p][c]==q;p=pre[p]) trs[p][c]=nq; } } int ff[S1][19]; void dfs(int x) { for(int j=2;j<=18;j++) ff[x][j]=ff[ ff[x][j-1] ][j-1]; for(int j=e.head[x];j;j=e.nxt[j]){ int v=e.to[j]; dfs(v); s.root[x]=s.merge(s.root[x],s.root[v]); } } void build() { for(int i=2;i<=tot;i++) e.ae(pre[i],i), ff[i][0]=i, ff[i][1]=pre[i]; dfs(1); } int main() { // freopen("1.in","r",stdin); read(n); read(Q); scanf("%s",str+1); init(); s.init(); for(int i=1;i<=n;i++) insert(idx(str[i]),i); build(); int l,r,x,len; for(int q=1;q<=Q;q++){ read(l); read(r); len=r-l+1; x=id[r]; // for(;dep[pre[x]]<=len;x=pre[x]) for(int j=18;j>=0;j--) if(dep[ff[x][j]]>=len) x=ff[x][j]; ll ans=1ll*(n-1)*(n-2)/2,tmp=0; int r1=s.mi[s.root[x]], rm=s.ma[s.root[x]], lm=rm-len+1, l1=r1-len+1; if(r1>lm){ //s1與sm有交 tmp+=s.sum[s.root[x]]-1ll*r1*r1-1ll*lm*(rm-r1); tmp+=max(0ll,1ll*(2*n-lm-1-r1)*(r1-lm)/2); tmp+=max(0ll,1ll*(l1-1)*(r1-lm)); }else{ int L=s.lower(lm,1,n,s.root[x]); int R=s.lower(r1+len-2,1,n,s.root[x]), lR=R-len+1; int nxt=s.upper(R+1,1,n,s.root[x]); if(L!=-1 && r!=-1 && L<=R){ node k=s.query(L,R,1,n,s.root[x]); tmp+=k.sum-1ll*L*L-1ll*lm*(R-L); tmp+=1ll*(r1-lR)*(nxt-lm); } } ans-=tmp; printf("%lld\n",ans); } // printf("%llu\n",(sizeof(s)+sizeof(ff)+sizeof(e)+sizeof(trs))/1024/1024); return 0; }