ICPC2018焦作站 H. Can You Solve the Harder Problem?
阿新 • • 發佈:2020-11-25
題意:
給出一個數字串,問所有本質不同的子串的最大值之和
如果沒有本質不同的要求,就是用單調棧求出每個數字前後第一個大於它的位置,掃一遍計算即可
現在要本質不同,用字尾陣列
按字典序依次計算每個字尾的貢獻
對於已經按字典序從小到大排好序的字尾i-1和i來說
以i為子串左端點,[i,height[i]]之間的為子串右端點 的子串已經在以i-1位左端點的字串中統計過了
所以對於每一個字尾i,新增的是以i為左端點,以[i+height[i],n]為右端點的子串的答案
每個字尾的答案分為兩部分
設區間[i,i+height[i]-1]之間最大的是第j個數,第j個數後面第一個更大的是第k個數
第一部分為以i為左端點,以[k,n]之間的為右端點的答案
它的答案等同於以k為左端點,以[k,n]之間的為右端點的答案
這可以提前對每一個k都求出來
求法:
假設第i個數後面第一個更大的是第nxt個數,那麼以i為左端點的貢獻就是(nxt-i)* 第i個數
從右往左掃一遍,可以求出i為左端點,以[i,n]之間的為右端點的答案
第二部分為以i為左端點,以[i+height[i],k-1]為右端點的子串答案
這一共有k-1-(i+height[i])個區間,他們的最大值都是第j個數
找這個j的方法有很多
可以根據nxt倍增
求nxt可以用單調棧
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; #define N 200002 #define M 1000002 typedef long long LL; int n,k,mx,a[N],v[M],p,q,sa[2][N],rk[2][N],h[N]; int st[N],top; int nxt[N][20],qp,bit[20]; LL sum[N]; void mul(int *sa,int *rk,int *SA,int *RK) { for(int i=1;i<=n;i++) v[rk[sa[i]]]=i; for(int i=n;i;i--)if(sa[i]>k) SA[v[rk[sa[i]-k]]--]=sa[i]-k; for(int i=n-k+1;i<=n;i++) SA[v[rk[i]]--]=i; for(int i=1;i<=n;i++) RK[SA[i]]=RK[SA[i-1]]+(rk[SA[i]]!=rk[SA[i-1]]||rk[SA[i]+k]!=rk[SA[i-1]+k]); } void presa() { p=0; q=1; qp=0; for(int i=0;i<=mx;++i) v[i]=0; for(int i=1;i<=n;i++) v[a[i]]++; for(int i=1;i<=mx;i++) v[i]+=v[i-1]; for(int i=1;i<=n;i++) sa[p][v[a[i]]--]=i; for(int i=1;i<=n;i++) rk[p][sa[p][i]]=rk[p][sa[p][i-1]]+(a[sa[p][i-1]]!=a[sa[p][i]]); for(k=1;k<n;k<<=1,swap(p,q),++qp) mul(sa[p],rk[p],sa[q],rk[q]); for(int i=1,kk=0;i<=n;i++) { int j=sa[p][rk[p][i]-1]; while(a[i+kk]==a[j+kk]) kk++; h[rk[p][i]]=kk; if(kk) kk--; } } void solve() { st[top=1]=n+1; for(int i=n;i;--i) { while(top && a[i]>=a[st[top]]) top--; nxt[i][0]=st[top]; st[++top]=i; } nxt[n+1][0]=n+1; for(int i=1;i<=qp;++i) for(int j=1;j<=n+1;++j) nxt[j][i]=nxt[nxt[j][i-1]][i-1]; for(int i=1;i<=n;++i) sum[i]=1ll*(nxt[i][0]-i)*a[i]; for(int i=n;i;--i) if(nxt[i][0]<=n) sum[i]+=sum[nxt[i][0]]; LL ans=sum[sa[p][1]]; int s,pos,now,to; for(int i=2;i<=n;++i) { s=sa[p][i]+h[i]; now=sa[p][i]; for(int j=qp;j>=0;--j) if(nxt[now][j]<s) now=nxt[now][j]; to=nxt[now][0]; if(to<=n) ans+=sum[to]; ans+=1ll*(to-s)*a[now]; } printf("%lld\n",ans); } int main() { int T; scanf("%d",&T); bit[0]=1; for(int i=1;i<=19;++i) bit[i]=bit[i-1]<<1; while(T--) { scanf("%d",&n); for(int i=1;i<=n;++i) { scanf("%d",&a[i]); mx=max(mx,a[i]); } a[n+1]=2e6; presa(); //for(int i=1;i<=n;i++) printf("%d ",sa[p][i]);puts(""); //for(int i=2;i<=n;i++) printf("%d ",h[i]); solve(); } }