1. 程式人生 > >bzoj4231 回憶樹(AC自動機+fail樹+KMP(+樹狀陣列))

bzoj4231 回憶樹(AC自動機+fail樹+KMP(+樹狀陣列))

bzoj4231 回憶樹

題意:
回憶樹是樹。
具體來說,是n個點n-1條邊的無向連通圖,點標號為1~n,每條邊上有一個字元(出於簡化目的,我們認為只有小寫字母)。
對一棵回憶樹來說,回憶當然是少不了的。
一次回憶是這樣的:你想起過往,觸及心底…唔,不對,我們要說題目。
這題中我們認為回憶是這樣的:給定2個點u,v(u可能等於v)和一個非空字串s,問從u到v的簡單路徑上的所有邊按照到u的距離從小到大的順序排列後,邊上的字元依次拼接形成的字串中給定的串s出現了多少次。

資料範圍
n<=100000,m<=100000,詢問串的總長<=300000

題解:
好題…
要求字串在一條鏈上匹配多少次,對回憶樹建樹是沒有辦法的,不能像bzoj3926那樣每個葉子提出來建一棵Trie。
於是這道題是離線,對要查詢的串(的正串和反串)建AC自動機。

同樣,A包含串B多少次,就是A在AC自動機上的每個節點,有多少在B結尾節點的fail樹子樹中。
於是,DFS原樹的同時,在AC自動機上匹配,

由於鏈是要拐彎的,同時還有方向,這個不好處理,於是把一條鏈拆成三部分:
拐彎處:長度為2|T|,用KMP暴力匹配
剩下的兩段一個是正著匹配,一個是倒著,於是詢問串需要把正反串都插入AC自動機,

於是詢問都變成了從根到某個點路徑的一部分(直了…),就可以一邊dfs一邊處理了。

就如同天天愛跑步的處理方式,在對應的起始端點push進這個詢問,以及是需要加還是減,查詢的是fail樹的哪個子樹。

然後在DFS原樹同時在AC自動機上轉移時,進入這個點把AC自動機上對應點+1,離開時-1,
詢問就是查詢fail樹子樹權值和。

程式碼:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cstring>
#include<queue>
using namespace std;
const int N=110005;
const int M=300015;
queue<int> Q;
vector<int> V[N],ID[N];
char str[M];
namespace AC
{
    int ch[2*M][26
],root,tail,head[2*M],nxt[2*M],to[2*M],num,fail[2*M]; int in[2*N],out[2*M],inc,C[2*M]; inline void build(int u,int v) {num++; to[num]=v; nxt[num]=head[u]; head[u]=num; inc=0;} inline void init(){root=1,tail=1,num=0;} inline int insert(int opt) { int len=strlen(str); int tmp=root; for(int j=0,i;j<len;j++) { if(opt==1) i=j; else i=len-j-1; int c=str[i]-'a'; if(!ch[tmp][c]) ch[tmp][c]=++tail; tmp=ch[tmp][c]; } return tmp; } inline void getfail() { for(int i=0;i<26;i++) if(ch[root][i]) fail[ch[root][i]]=root,build(root,ch[root][i]),Q.push(ch[root][i]); else ch[root][i]=root; while(!Q.empty()) { int top=Q.front(); Q.pop(); for(int i=0;i<26;i++) { if(!ch[top][i]) ch[top][i]=ch[fail[top]][i]; else { int u=ch[top][i]; fail[u]=ch[fail[top]][i]; build(fail[u],u); Q.push(u); } } } } inline void dfs(int u) { inc++; in[u]=inc; for(int i=head[u];i;i=nxt[i]) dfs(to[i]); out[u]=inc; } inline void add(int x,int d){for(int i=x;i<=inc;i=i+(i&(-i))) C[i]+=d;} inline int query(int x){int ret=0;for(int i=x;i;i=i-(i&(-i))) ret+=C[i]; return ret;} } int head[N],to[2*N],w[2*N],nxt[2*N],num=0,n,m,ans[N],fa[N],dep[N],_w[N],size[N],son[N],top[N],dfn=0,seq[N],loc[N]; int nx[M],s[M],t[M],pos[N][2]; inline void build(int u,int v,int ww) { num++; w[num]=ww; to[num]=v; nxt[num]=head[u]; head[u]=num; } inline void dfs(int u,int f) { dep[u]=dep[f]+1; fa[u]=f; size[u]=1; for(int i=head[u];i;i=nxt[i]) { if(to[i]==f) continue; _w[to[i]]=w[i]; dfs(to[i],u); size[u]+=size[to[i]]; if(size[son[u]]<size[to[i]]) son[u]=to[i]; } } inline void dfs1(int u,int f,int tp) { loc[u]=++dfn,seq[dfn]=u; top[u]=tp; if(son[u]) dfs1(son[u],u,tp); for(int i=head[u];i;i=nxt[i]) { if(to[i]==f||to[i]==son[u]) continue; dfs1(to[i],u,to[i]); } } inline int getlca(int u,int v) { while(top[u]!=top[v]) { if(dep[top[u]]<dep[top[v]]) swap(u,v); u=fa[top[u]]; } return dep[u]<dep[v]?u:v; } inline int getpoint(int u,int d) { while(dep[u]-dep[top[u]]<d){d-=(dep[u]-dep[top[u]]+1),u=fa[top[u]];} return seq[loc[u]-d]; } inline void kmp(int u,int v,int lca,int id) { int lent=strlen(str),lens=0; int x=getpoint(u,dep[u]-min(dep[u],dep[lca]+lent-1)); int y=getpoint(v,dep[v]-min(dep[v],dep[lca]+lent-1)); lens=dep[x]-dep[lca]+dep[y]-dep[lca]; int tmp=x,i=0,j; while(tmp!=lca) s[i++]=_w[tmp],tmp=fa[tmp]; tmp=y,i=1; while(tmp!=lca) s[lens-i]=_w[tmp],tmp=fa[tmp],i++; for(int i=0;i<lent;i++) t[i]=str[i]-'a'; nx[0]=-1; i=0,j=-1; while(i<lent) { if(j==-1||t[i]==t[j]) {i++; j++; nx[i]=j;} else j=nx[j]; } i=0,j=0; int ret=0; while(i<lens) { if(j==-1||s[i]==t[j]) { i++; j++; if(j==lent){ret++; j=nx[j];} } else j=nx[j]; } pos[id][0]=AC::insert(1); pos[id][1]=AC::insert(-1); ans[id]=ret; if(u!=x) { ID[x].push_back(-id); ID[u].push_back(id); V[x].push_back(pos[id][1]); V[u].push_back(pos[id][1]); } if(v!=y) { ID[y].push_back(-id); ID[v].push_back(id); V[y].push_back(pos[id][0]); V[v].push_back(pos[id][0]); } } inline void dfs2(int u,int f,int x) { AC::add(AC::in[x],1); int sz=V[u].size(); for(int i=0;i<sz;i++) { int ret=AC::query(AC::out[V[u][i]])-AC::query(AC::in[V[u][i]]-1); if(ID[u][i]>0) ans[ID[u][i]]+=ret; else ans[-ID[u][i]]-=ret; } for(int i=head[u];i;i=nxt[i]) { int v=to[i]; if(v==f) continue; dfs2(v,u,AC::ch[x][w[i]]); } AC::add(AC::in[x],-1); } int main() { scanf("%d%d",&n,&m); for(int i=1;i<n;i++) { int u,v; scanf("%d%d",&u,&v); scanf("%s",str); build(u,v,str[0]-'a'); build(v,u,str[0]-'a'); } dfs(1,1); dfs1(1,1,1); AC::init(); for(int i=1;i<=m;i++) { int u,v; scanf("%d%d",&u,&v); scanf("%s",str); if(u==v) continue; int lca=getlca(u,v); kmp(u,v,lca,i); } AC::getfail(); AC::dfs(1); dfs2(1,1,1); for(int i=1;i<=m;i++) printf("%d\n",ans[i]); return 0; }

然後這是我原來寫的倍增版本的程式碼(改正後),沒有namespace套namespace這種鬼畜玩意兒:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cstring>
#include<queue>
using namespace std;
const int N=100005;
const int M=300015;
const int P=17;
queue<int> Q;
vector<int> V[N],ID[N];
char str[M];
struct AC_
{
    int ch[2*M][26],root,tail,head[2*M],nxt[2*M],to[2*M],num,fail[2*M];
    int in[2*N],out[2*M],inc,C[2*M];
    void build(int u,int v) {num++; to[num]=v; nxt[num]=head[u]; head[u]=num; inc=0;}
    void init(){root=1,tail=1,num=0;}
    int insert(int opt)
    {
        int len=strlen(str); int tmp=root;
        for(int j=0,i;j<len;j++)
        {
            if(opt==1) i=j; else i=len-j-1;
            int c=str[i]-'a';
            if(!ch[tmp][c]) ch[tmp][c]=++tail;
            tmp=ch[tmp][c];
        }
        return tmp;
    }
    void getfail()
    {
        for(int i=0;i<26;i++) if(ch[root][i]) fail[ch[root][i]]=root,build(root,ch[root][i]),Q.push(ch[root][i]); else ch[root][i]=root;
        while(!Q.empty())
        {
            int top=Q.front(); Q.pop();
            for(int i=0;i<26;i++)
            {
                if(!ch[top][i]) ch[top][i]=ch[fail[top]][i];
                else
                {
                    int u=ch[top][i];
                    fail[u]=ch[fail[top]][i];
                    build(fail[u],u);
                    Q.push(u);
                }
            }
        }       
    }
    void dfs(int u)
    {
        inc++; in[u]=inc;
        for(int i=head[u];i;i=nxt[i]) dfs(to[i]);
        out[u]=inc;
    }
    void add(int x,int d){for(int i=x;i<=inc;i=i+(i&(-i))) C[i]+=d;}
    inline int query(int x){int ret=0;for(int i=x;i;i=i-(i&(-i))) ret+=C[i]; return ret;}
}AC;
int head[N],to[2*N],w[2*N],nxt[2*N],num=0,n,m,ans[N],anc[N][P+3],dep[N],_w[N];
int nx[M],s[M],t[M],pos[N][2];
void build(int u,int v,int ww)
{
    num++;
    w[num]=ww;
    to[num]=v;
    nxt[num]=head[u];
    head[u]=num;
}
void dfs1(int u,int f)
{
    dep[u]=dep[f]+1; anc[u][0]=f;
    for(int i=1;i<P;i++) anc[u][i]=anc[anc[u][i-1]][i-1];
    for(int i=head[u];i;i=nxt[i]) if(to[i]!=f) _w[to[i]]=w[i],dfs1(to[i],u);
}
inline int getlca(int u,int v)
{
    if(dep[u]<dep[v]) swap(u,v);
    int d=dep[u]-dep[v];
    for(int i=0;d;d>>=1,i++) if(d&1) u=anc[u][i];
    if(u==v) return u;
    for(int i=P-1;i>=0;i--)
    if(anc[u][i]!=anc[v][i]) u=anc[u][i],v=anc[v][i];
    return anc[u][0];
}
inline int getpoint(int u,int d) {for(int i=0;d;d>>=1,i++) if(d&1) u=anc[u][i]; return u;}
void kmp(int u,int v,int lca,int id)
{
    int lent=strlen(str),lens=0;
    int x=getpoint(u,dep[u]-min(dep[u],dep[lca]+lent-1));
    int y=getpoint(v,dep[v]-min(dep[v],dep[lca]+lent-1));
    lens=dep[x]-dep[lca]+dep[y]-dep[lca];
    int tmp=x,i=0,j; while(tmp!=lca) s[i++]=_w[tmp],tmp=anc[tmp][0];
    tmp=y,i=1; while(tmp!=lca) s[lens-i]=_w[tmp],tmp=anc[tmp][0],i++;
    for(int i=0;i<lent;i++) t[i]=str[i]-'a';
    nx[0]=-1; i=0,j=-1;
    while(i<lent) 
    {
        if(j==-1||t[i]==t[j]) {i++; j++; nx[i]=j;}
        else j=nx[j];
    }
    i=0,j=0; int ret=0;
    while(i<lens)
    {
        if(j==-1||s[i]==t[j]) 
        {
            i++; j++;
            if(j==lent){ret++; j=nx[j];}
        }
        else j=nx[j];
    }
    pos[id][0]=AC.insert(1); pos[id][1]=AC.insert(-1);
    ans[id]=ret;
    if(u!=x)
    {
        ID[x].push_back(-id); ID[u].push_back(id);
        V[x].push_back(pos[id][1]); V[u].push_back(pos[id][1]);
    }
    if(v!=y)
    {
        ID[y].push_back(-id); ID[v].push_back(id);
        V[y].push_back(pos[id][0]); V[v].push_back(pos[id][0]);
    }
}
void dfs2(int u,int f,int x)
{
    AC.add(AC.in[x],1); 
    int sz=V[u].size(); 
    for(int i=0;i<sz;i++)
    {   
        int ret=AC.query(AC.out[V[u][i]])-AC.query(AC.in[V[u][i]]-1);
        if(ID[u][i]>0) ans[ID[u][i]]+=ret;
        else ans[-ID[u][i]]-=ret;
    }
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==f) continue;
        dfs2(v,u,AC.ch[x][w[i]]);
    }
    AC.add(AC.in[x],-1);
}

int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;i++)
    {
        int u,v; scanf("%d%d%s",&u,&v,str);
        build(u,v,str[0]-'a'); build(v,u,str[0]-'a');
    }
    dfs1(1,1); AC.init();
    for(int i=1;i<=m;i++)
    {
        int u,v; scanf("%d%d",&u,&v); scanf("%s",str);
        if(u==v) continue; int lca=getlca(u,v); 
        kmp(u,v,lca,i);
    }
    AC.getfail();   
    AC.dfs(1);
    dfs2(1,1,1);
    for(int i=1;i<=m;i++) printf("%d\n",ans[i]);
    return 0;
}

昨晚上腦子有點不清楚這裡寫圖片描述

自己沒寫對拍就一直交,卡了一晚上評測,非常抱歉,感謝大家的不殺之恩。

最後查出來的錯:
1.求lca時 swap(u,u);
2.kmp忘記 j=nxt[j];
3.自己nxt[N]nx[N]兩個陣列搞混了。

最初以為是倍增慢了,沒有去查無限迴圈的錯,改成鏈剖還是T,才想到是不是哪裡寫掛了,
長程式碼一定要靜態查錯+對拍。