1. 程式人生 > 實用技巧 >Codeforces 932G - Palindrome Partition 迴文自動機 dp

Codeforces 932G - Palindrome Partition 迴文自動機 dp

Codeforces 932G - Palindrome Partition

題意

給定一個字串 \(s\),要求將 \(s\) 劃分為 \(t_1,t_2,\dots,t_k\),其中 \(k\) 是偶數,且 \(t_i=t_{k-i+1}\),求這樣的劃分方案數。

\(|S|\le 10^6\)

分析

構造字串\(t=s[1]s[n]s[2]s[n-1]s[3]s[n-2]\dots s[n/2]s[n/2+1]\),問題就等價於求\(t\)的最小偶迴文劃分方案數,對\(t\)構建迴文自動機。根據這篇文章的證明,\(s\) 的所有迴文字尾按照長度排序後,可以劃分成\(\text{log|s|}\)

段等差數列。在迴文自動機上的每個結點\(u\)多維護兩個資訊,\(diff[u]\)\(slink[u]\)\(diff[u]=len[u]-len[fail[u]]\)\(slink[u]\)表示\(u\)一直沿著\(fail\)向上跳第一個結點\(v\),使得\(diff[v]\ne diff[u]\)。用\(g[x]\)維護\(x\)所在的等差數列的\(dp\)值的和,然後不斷的跳\(slink[x]\),用\(g[x]\)更新\(dp[i]\)。對於每個\(i\)\(slink\)最多跳\(log\)次,時間複雜度為\(nlogn\)

Code

#include<bits/stdc++.h>
#define mp make_pair
#define fi first
#define se second
#define pb push_back
#define lson l,mid,p<<1
#define rson mid+1,r,p<<1|1
#define ll long long
using namespace std;
const int inf=1e9;
const int mod=1e9+7;
const int maxn=1e6+10;
char s[maxn],t[maxn];
int T,n,now;
int dp[maxn],g[maxn];
void add(int &x,int y){
    x+=y;
    if(x>mod) x-=mod;
}
struct PAM{
    int son[maxn][26],fail[maxn],len[maxn],dif[maxn],slink[maxn],tot,last;
    int newnode(int x){
        ++tot;
        for(int i=0;i<26;i++) son[tot][i]=0;
        fail[tot]=0;len[tot]=x;
        return tot;
    }
    void init(){
        tot=-1;newnode(0);newnode(-1);
        fail[0]=1;
        last=0;
    }
    int gao(int x){
        while(s[now-len[x]-1]!=s[now]) x=fail[x];
        return x;
    }
    void insert(){
        int p=gao(last);
        if(!son[p][s[now]-'a']){
            int tmp=son[gao(fail[p])][s[now]-'a'];
            son[p][s[now]-'a']=newnode(len[p]+2);
            fail[tot]=tmp;
            dif[tot]=len[tot]-len[tmp];
            if(dif[tot]==dif[fail[tot]]){
                slink[tot]=slink[fail[tot]];
            }else{
                slink[tot]=fail[tot];
            }
        }
        last=son[p][s[now]-'a'];
    }
    int solve(){
        for(now=1;now<=n;now++){
            insert();
            for(int x=last;x>1;x=slink[x]){
                g[x]=dp[now-dif[x]-len[slink[x]]];
                if(dif[x]==dif[fail[x]]) add(g[x],g[fail[x]]);
                if(now%2==0) add(dp[now],g[x]);
            }
        }
        return dp[n];
    }
}P;
int main(){
    scanf("%s",t+1);
    n=strlen(t+1);
    P.init();
    dp[0]=1;
    for(int i=1,j=1;i<=n;i+=2,j++){
        s[i]=t[j];
        s[i+1]=t[n-j+1];
    }
    printf("%d\n",P.solve());
    return 0;
}