1. 程式人生 > >洛谷4248 AHOI2013差異 (字尾陣列SA+單調棧)

洛谷4248 AHOI2013差異 (字尾陣列SA+單調棧)

題目連結

補部落格!

首先我們觀察題目中給的那個求\(ans\)的方法,其實前兩項沒什麼用處,直接\(for\)一遍就求得了

for (int i=1;i<=n;i++) ans=ans+i*(n-1);

那麼我們考慮剩下的部分應該怎麼求解!
首先這裡有一個性質。對於任意兩個字尾\(i,j\),他們的\(lcp\)長度是他們對應的\(rank\)之間的\(height\)\(min\) (左開右閉)

或者這樣說
\(lcp(i,j) = min(height[rank[i]+1],height[rank[i]+2].....,height[rank[j]]) 其中rank[i]<rank[j]\)

那麼對於這個題,我們就可以直接維護出每個\(height\)作為最小值的區間,然後用他的區間個乘上貢獻即可(但是具體這裡求的時候需要仔細想想,因為那個左開右閉的區間,假設右邊能選的端點是\(r[i]-l+1\),那麼合法的右端點實際上是由\(i-l[i]+1\)因為,能覆蓋到\(l[i]\)這個\(height\)的點實際上是\(l[i]-1\)。)

總之就是比較難理解啊

for (int i=1;i<=n;i++) ans=ans-2ll*(r[i]-i+1)*(i-l[i]+1)*height[i];

那麼現在的問題就是應該怎麼求\(l[i]和r[i]\)呢?

QWQ這貌似是單調棧的經典應用?

直接從左到右,從右到左掃兩遍即可.
這裡有一個很好的防止計算重複的方法

就是我們從左到右掃維護的棧是單調的。然後從右到左不單調(非嚴格)

或者說,一遍單調,一遍不單調,即可解決重複的問題了!

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<set>
#define mk makr_pair
#define ll long long
#define int long long
using namespace std;
inline int read()
{
  int x=0,f=1;char ch=getchar();
  while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
  while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
  return x*f;
}
const int maxn = 2e6+1e2;
struct Node{
    int val,pos;
};
int wb[maxn],sa[maxn];
Node s[maxn];
int l[maxn],r[maxn];
int rk[maxn],h[maxn],height[maxn];
int tmp[maxn];
int n,m;
char a[maxn];
int ans;
void getsa()
{
    int *x = rk,*y = tmp;
    int s = 128;
    int p = 0;
    for (int i=1;i<=n;i++) x[i]=a[i],y[i]=i;
    for (int i=1;i<=s;i++) wb[i]=0;
    for (int i=1;i<=n;i++) wb[x[y[i]]]++;
    for (int i=1;i<=s;i++) wb[i]+=wb[i-1];
    for (int i=n;i>=1;i--) sa[wb[x[y[i]]]--] = y[i];
    for (int j=1;p<n;j<<=1)
    {
        p=0;
        for (int i=n-j+1;i<=n;i++) y[++p]=i;
        for (int i=1;i<=n;i++) if (sa[i]>j) y[++p]=sa[i]-j;
        for (int i=1;i<=s;i++) wb[i]=0;
        for (int i=1;i<=n;i++) wb[x[y[i]]]++;
        for (int i=1;i<=s;i++) wb[i]+=wb[i-1];
        for (int i=n;i>=1;i--) sa[wb[x[y[i]]]--] =y[i];
        swap(x,y);
        p=1;
        x[sa[1]]=1;
        for (int i=2;i<=n;i++)
        {
            x[sa[i]] = (y[sa[i-1]]==y[sa[i]] && y[sa[i]+j]==y[sa[i-1]+j]) ? p : ++p;
        }
        s=p;
    }
    for (int i=1;i<=n;i++) rk[sa[i]]=i;
    h[0]=0;
    for (int i=1;i<=n;i++)
    {
        h[i]=max(h[i-1]-1,(long long)0);
        while (i+h[i]<=n && sa[rk[i]-1]+h[i]<=n && a[i+h[i]]==a[sa[rk[i]-1]+h[i]]) h[i]++;
    }
    for (int i=1;i<=n;i++) height[i] = h[sa[i]];
}
int top;
signed main()
{
  scanf("%s",a+1);
  n = strlen(a+1);
  getsa();
  for (int i=1;i<=n;i++) ans=ans+i*(n-1);
  l[1]=1;
  s[++top].val=height[1];
  s[1].pos=1;
  for (int i=2;i<=n;i++)
  {
    while (top>=1 && s[top].val>=height[i]) top--;
    if (!top) l[i]=1;
    else l[i]=s[top].pos+1;
    s[++top].val=height[i];
    s[top].pos=i;
  } 
  memset(s,0,sizeof(s));
  top=1;
  r[n]=n;
  s[top].val=height[n];
  s[top].pos=n;
  for (int i=n-1;i>=1;i--)
  {
    while (top>=1 && s[top].val>height[i]) top--;
    if (!top) r[i]=n;
    else r[i]=s[top].pos-1;
    s[++top].val=height[i];
    s[top].pos=i;
  }
  for (int i=1;i<=n;i++) ans=ans-2ll*(r[i]-i+1)*(i-l[i]+1)*height[i];
  cout<<ans;
  return 0;
}