1. 程式人生 > >『題解』[NOI2016]優秀的拆分

『題解』[NOI2016]優秀的拆分

如果一個字串可以被拆分為\(AABB\)的形式,其中$A和 B是任意非空字串,則我們稱該字串的這種拆分是優秀的。

例如,對於字串\(aabaabaa\),如果令\(A=aab\)\(B=a\),我們就找到了這個字串拆分成\(AABB\)的一種方式。

一個字串可能沒有優秀的拆分,也可能存在不止一種優秀的拆分。比如我們令\(A=a\)\(B=baa\),也可以用 AABB表示出上述字串;但是,字串\(abaabaa\)就沒有優秀的拆分。

現在給出一個長度為\(n\)的字串\(S\),我們需要求出,在它所有子串的所有拆分方式中,優秀拆分的總個數。這裡的子串是指字串中連續的一段。

以下事項需要注意:

  • 出現在不同位置的相同子串,我們認為是不同的子串,它們的優秀拆分均會被記入答案。

  • 在一個拆分中,允許出現\(A=B\)。例如\(cccc\)存在拆分\(A=B=c\)

  • 字串本身也是它的一個子串。

對於\(AABB\),我們完全可以只考慮\(AA\),這樣令\(f[i]\)表示以i結尾的\(AA\)數量,\(g[i]\)表示以\(i\)開頭的\(AA\)數量,那麼顯然就是\(sigma(f[i]g[i+1])\)

對於\(AA\)怎麼求,大體的思路和URAL1297:Palindrome求迴文串是一樣的,就是通過比較字尾的公共字首來得到AA的長度,進而求出這段區間內\(f[i]g[i]\)

的值。

但是這樣顯然是\(O(n^{2})\)的。

我們用分塊的思想,列舉\(l\),將字串分成\(l\)大小的塊,則長度為\(l\)\(AA\)一定最少跨過兩個塊,於是對於塊邊界點,求一次公共字首和字尾,拼在一起就是我們所要的答案,複雜度調和級數\(O(n×log_{2}^{n})\)

注意,為了讓複雜度正確,我們對區間的\(f\)\(g\)差分。

程式碼:

#include<cstdio>
#include<cmath>
#include<iostream>
#include<vector>
#include<cstring>
#include<algorithm>
#include<cctype>
using namespace std;
typedef long long ll;
const int N=2e6+10;
char s[N];
int n,m,rk[N],height[N],sa[N],w[N],cas,dp[N][21],lg[N];
int f[N],g[N];
inline int qpow(int a)
{
    return 1<<a;
}
inline bool pan(int *x,int i,int j,int k)
{
    int ti=i+k<n?x[i+k]:-1;
    int tj=j+k<n?x[j+k]:-1;
    return ti==tj&&x[i]==x[j];
}
void SA_init()
{
    int *x=rk,*y=height,r=256;
    for(int i=0; i<r; i++)w[i]=0;
    for(int i=0; i<n; i++)w[s[i]]++;
    for(int i=1; i<r; i++)w[i]+=w[i-1];
    for(int i=n-1; i>=0; i--)sa[--w[s[i]]]=i;
    r=1;
    x[sa[0]]=0;
    for(int i=1; i<n; i++)
        x[sa[i]]=s[sa[i]]==s[sa[i-1]]?r-1:r++;
    for(int k=1; r<n; k<<=1)
    {
        int yn=0;
        for(int i=n-k; i<n; i++)y[yn++]=i;
        for(int i=0; i<n; i++)
            if(sa[i]>=k)y[yn++]=sa[i]-k;
        for(int i=0; i<r; i++)w[i]=0;
        for(int i=0; i<n; i++)w[x[y[i]]]++;
        for(int i=1; i<r; i++)w[i]+=w[i-1];
        for(int i=n-1; i>=0; i--)sa[--w[x[y[i]]]]=y[i];
        swap(x,y);
        r=1;
        x[sa[0]]=0;
        for(int i=1; i<n; i++)
            x[sa[i]]=pan(y,sa[i],sa[i-1],k)?r-1:r++;
    }
}
inline void height_init()
{
    int i,j,k=0;
    for(int i=1; i<=n; i++)rk[sa[i]]=i;
    for(int i=0; i<n; i++)
    {
        if(k)k--;
        j=sa[rk[i]-1];
        while(s[i+k]==s[j+k])k++;
        height[rk[i]]=k;
    }
}
void st_init()
{
    for(int i=1; i<=n; i++)
    {
        dp[i-1][0]=height[i];
        lg[i]=lg[i-1];
        if((1<<lg[i]+1)==i)lg[i]++;
    }
    for(int j=1; j<=lg[n]; j++)
    {
        for(int i=0; i<n; i++)
        {
            if(i+qpow(j)-1>=n)break;
            dp[i][j]=min(dp[i][j-1],dp[i+qpow(j-1)][j-1]);
        }
    }
}
int lcp(int a,int b)
{
    int l=rk[a],r=rk[b];
    if(r<l)swap(l,r);
    l--;
    r--;
    if(r<0)return 0;
    l++;
    int len=r-l+1;
    int k=lg[len];
    int h=qpow(k);
    return min(dp[l][k],dp[r-h+1][k]);
}
int main()
{
    scanf("%d",&cas);
    while(cas--)
    {
        memset(f,0,sizeof(f));
        memset(g,0,sizeof(g));
        cin>>s;
        m=strlen(s),n=2*m+1;
        s[m]='$';
        for(int i=m+1; i<n; i++)
        {
            s[i]=s[n-i-1];
        }
        s[n++]=0;
        SA_init();
        n--;
        height_init();
        st_init();
        for(int l=1; l<=m/2; l++)
        {
            for(int i=0,j=l; j<m; i+=l,j+=l)
            {
                int p=min(l,lcp(i,j));
                int s=min(l,lcp(n-i-1,n-j-1));
                if(p+s-1>=l)
                {
                    f[j-s+l]++;
                    f[j+p]--;
                    g[i-s+1]++;
                    g[i+p-l+1]--;
                }
            }
        }
        ll ans=0;
        for(int i=1; i<m; i++)
        {
            f[i]+=f[i-1];
            g[i]+=g[i-1];
        }
        for(int i=0; i<m-1; i++)
        {
            ans+=(ll)f[i]*g[i+1];
        }
        printf("%lld\n",ans);
    }
    return 0;
}