bzoj4231 回憶樹(AC自動機+fail樹+KMP(+樹狀陣列))
bzoj4231 回憶樹
題意:
回憶樹是樹。
具體來說,是n個點n-1條邊的無向連通圖,點標號為1~n,每條邊上有一個字元(出於簡化目的,我們認為只有小寫字母)。
對一棵回憶樹來說,回憶當然是少不了的。
一次回憶是這樣的:你想起過往,觸及心底…唔,不對,我們要說題目。
這題中我們認為回憶是這樣的:給定2個點u,v(u可能等於v)和一個非空字串s,問從u到v的簡單路徑上的所有邊按照到u的距離從小到大的順序排列後,邊上的字元依次拼接形成的字串中給定的串s出現了多少次。
資料範圍
n<=100000,m<=100000,詢問串的總長<=300000
題解:
好題…
要求字串在一條鏈上匹配多少次,對回憶樹建樹是沒有辦法的,不能像bzoj3926那樣每個葉子提出來建一棵Trie。
於是這道題是離線,對要查詢的串(的正串和反串)建AC自動機。
同樣,A包含串B多少次,就是A在AC自動機上的每個節點,有多少在B結尾節點的fail樹子樹中。
於是,DFS原樹的同時,在AC自動機上匹配,
由於鏈是要拐彎的,同時還有方向,這個不好處理,於是把一條鏈拆成三部分:
拐彎處:長度為2|T|,用KMP暴力匹配
剩下的兩段一個是正著匹配,一個是倒著,於是詢問串需要把正反串都插入AC自動機,
於是詢問都變成了從根到某個點路徑的一部分(直了…),就可以一邊dfs一邊處理了。
就如同天天愛跑步的處理方式,在對應的起始端點push進這個詢問,以及是需要加還是減,查詢的是fail樹的哪個子樹。
然後在DFS原樹同時在AC自動機上轉移時,進入這個點把AC自動機上對應點+1,離開時-1,
詢問就是查詢fail樹子樹權值和。
程式碼:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cstring>
#include<queue>
using namespace std;
const int N=110005;
const int M=300015;
queue<int> Q;
vector<int> V[N],ID[N];
char str[M];
namespace AC
{
int ch[2*M][26 ],root,tail,head[2*M],nxt[2*M],to[2*M],num,fail[2*M];
int in[2*N],out[2*M],inc,C[2*M];
inline void build(int u,int v) {num++; to[num]=v; nxt[num]=head[u]; head[u]=num; inc=0;}
inline void init(){root=1,tail=1,num=0;}
inline int insert(int opt)
{
int len=strlen(str); int tmp=root;
for(int j=0,i;j<len;j++)
{
if(opt==1) i=j; else i=len-j-1;
int c=str[i]-'a';
if(!ch[tmp][c]) ch[tmp][c]=++tail;
tmp=ch[tmp][c];
}
return tmp;
}
inline void getfail()
{
for(int i=0;i<26;i++) if(ch[root][i]) fail[ch[root][i]]=root,build(root,ch[root][i]),Q.push(ch[root][i]); else ch[root][i]=root;
while(!Q.empty())
{
int top=Q.front(); Q.pop();
for(int i=0;i<26;i++)
{
if(!ch[top][i]) ch[top][i]=ch[fail[top]][i];
else
{
int u=ch[top][i];
fail[u]=ch[fail[top]][i];
build(fail[u],u);
Q.push(u);
}
}
}
}
inline void dfs(int u)
{
inc++; in[u]=inc;
for(int i=head[u];i;i=nxt[i]) dfs(to[i]);
out[u]=inc;
}
inline void add(int x,int d){for(int i=x;i<=inc;i=i+(i&(-i))) C[i]+=d;}
inline int query(int x){int ret=0;for(int i=x;i;i=i-(i&(-i))) ret+=C[i]; return ret;}
}
int head[N],to[2*N],w[2*N],nxt[2*N],num=0,n,m,ans[N],fa[N],dep[N],_w[N],size[N],son[N],top[N],dfn=0,seq[N],loc[N];
int nx[M],s[M],t[M],pos[N][2];
inline void build(int u,int v,int ww)
{
num++;
w[num]=ww;
to[num]=v;
nxt[num]=head[u];
head[u]=num;
}
inline void dfs(int u,int f)
{
dep[u]=dep[f]+1; fa[u]=f; size[u]=1;
for(int i=head[u];i;i=nxt[i])
{
if(to[i]==f) continue;
_w[to[i]]=w[i];
dfs(to[i],u);
size[u]+=size[to[i]];
if(size[son[u]]<size[to[i]]) son[u]=to[i];
}
}
inline void dfs1(int u,int f,int tp)
{
loc[u]=++dfn,seq[dfn]=u; top[u]=tp;
if(son[u]) dfs1(son[u],u,tp);
for(int i=head[u];i;i=nxt[i])
{
if(to[i]==f||to[i]==son[u]) continue;
dfs1(to[i],u,to[i]);
}
}
inline int getlca(int u,int v)
{
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]]) swap(u,v);
u=fa[top[u]];
}
return dep[u]<dep[v]?u:v;
}
inline int getpoint(int u,int d)
{
while(dep[u]-dep[top[u]]<d){d-=(dep[u]-dep[top[u]]+1),u=fa[top[u]];}
return seq[loc[u]-d];
}
inline void kmp(int u,int v,int lca,int id)
{
int lent=strlen(str),lens=0;
int x=getpoint(u,dep[u]-min(dep[u],dep[lca]+lent-1));
int y=getpoint(v,dep[v]-min(dep[v],dep[lca]+lent-1));
lens=dep[x]-dep[lca]+dep[y]-dep[lca];
int tmp=x,i=0,j; while(tmp!=lca) s[i++]=_w[tmp],tmp=fa[tmp];
tmp=y,i=1; while(tmp!=lca) s[lens-i]=_w[tmp],tmp=fa[tmp],i++;
for(int i=0;i<lent;i++) t[i]=str[i]-'a';
nx[0]=-1; i=0,j=-1;
while(i<lent)
{
if(j==-1||t[i]==t[j]) {i++; j++; nx[i]=j;}
else j=nx[j];
}
i=0,j=0; int ret=0;
while(i<lens)
{
if(j==-1||s[i]==t[j])
{
i++; j++;
if(j==lent){ret++; j=nx[j];}
}
else j=nx[j];
}
pos[id][0]=AC::insert(1); pos[id][1]=AC::insert(-1);
ans[id]=ret;
if(u!=x)
{
ID[x].push_back(-id); ID[u].push_back(id);
V[x].push_back(pos[id][1]); V[u].push_back(pos[id][1]);
}
if(v!=y)
{
ID[y].push_back(-id); ID[v].push_back(id);
V[y].push_back(pos[id][0]); V[v].push_back(pos[id][0]);
}
}
inline void dfs2(int u,int f,int x)
{
AC::add(AC::in[x],1);
int sz=V[u].size();
for(int i=0;i<sz;i++)
{
int ret=AC::query(AC::out[V[u][i]])-AC::query(AC::in[V[u][i]]-1);
if(ID[u][i]>0) ans[ID[u][i]]+=ret;
else ans[-ID[u][i]]-=ret;
}
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v==f) continue;
dfs2(v,u,AC::ch[x][w[i]]);
}
AC::add(AC::in[x],-1);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
{
int u,v; scanf("%d%d",&u,&v); scanf("%s",str);
build(u,v,str[0]-'a'); build(v,u,str[0]-'a');
}
dfs(1,1); dfs1(1,1,1); AC::init();
for(int i=1;i<=m;i++)
{
int u,v; scanf("%d%d",&u,&v); scanf("%s",str);
if(u==v) continue; int lca=getlca(u,v);
kmp(u,v,lca,i);
}
AC::getfail();
AC::dfs(1);
dfs2(1,1,1);
for(int i=1;i<=m;i++) printf("%d\n",ans[i]);
return 0;
}
然後這是我原來寫的倍增版本的程式碼(改正後),沒有namespace套namespace這種鬼畜玩意兒:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cstring>
#include<queue>
using namespace std;
const int N=100005;
const int M=300015;
const int P=17;
queue<int> Q;
vector<int> V[N],ID[N];
char str[M];
struct AC_
{
int ch[2*M][26],root,tail,head[2*M],nxt[2*M],to[2*M],num,fail[2*M];
int in[2*N],out[2*M],inc,C[2*M];
void build(int u,int v) {num++; to[num]=v; nxt[num]=head[u]; head[u]=num; inc=0;}
void init(){root=1,tail=1,num=0;}
int insert(int opt)
{
int len=strlen(str); int tmp=root;
for(int j=0,i;j<len;j++)
{
if(opt==1) i=j; else i=len-j-1;
int c=str[i]-'a';
if(!ch[tmp][c]) ch[tmp][c]=++tail;
tmp=ch[tmp][c];
}
return tmp;
}
void getfail()
{
for(int i=0;i<26;i++) if(ch[root][i]) fail[ch[root][i]]=root,build(root,ch[root][i]),Q.push(ch[root][i]); else ch[root][i]=root;
while(!Q.empty())
{
int top=Q.front(); Q.pop();
for(int i=0;i<26;i++)
{
if(!ch[top][i]) ch[top][i]=ch[fail[top]][i];
else
{
int u=ch[top][i];
fail[u]=ch[fail[top]][i];
build(fail[u],u);
Q.push(u);
}
}
}
}
void dfs(int u)
{
inc++; in[u]=inc;
for(int i=head[u];i;i=nxt[i]) dfs(to[i]);
out[u]=inc;
}
void add(int x,int d){for(int i=x;i<=inc;i=i+(i&(-i))) C[i]+=d;}
inline int query(int x){int ret=0;for(int i=x;i;i=i-(i&(-i))) ret+=C[i]; return ret;}
}AC;
int head[N],to[2*N],w[2*N],nxt[2*N],num=0,n,m,ans[N],anc[N][P+3],dep[N],_w[N];
int nx[M],s[M],t[M],pos[N][2];
void build(int u,int v,int ww)
{
num++;
w[num]=ww;
to[num]=v;
nxt[num]=head[u];
head[u]=num;
}
void dfs1(int u,int f)
{
dep[u]=dep[f]+1; anc[u][0]=f;
for(int i=1;i<P;i++) anc[u][i]=anc[anc[u][i-1]][i-1];
for(int i=head[u];i;i=nxt[i]) if(to[i]!=f) _w[to[i]]=w[i],dfs1(to[i],u);
}
inline int getlca(int u,int v)
{
if(dep[u]<dep[v]) swap(u,v);
int d=dep[u]-dep[v];
for(int i=0;d;d>>=1,i++) if(d&1) u=anc[u][i];
if(u==v) return u;
for(int i=P-1;i>=0;i--)
if(anc[u][i]!=anc[v][i]) u=anc[u][i],v=anc[v][i];
return anc[u][0];
}
inline int getpoint(int u,int d) {for(int i=0;d;d>>=1,i++) if(d&1) u=anc[u][i]; return u;}
void kmp(int u,int v,int lca,int id)
{
int lent=strlen(str),lens=0;
int x=getpoint(u,dep[u]-min(dep[u],dep[lca]+lent-1));
int y=getpoint(v,dep[v]-min(dep[v],dep[lca]+lent-1));
lens=dep[x]-dep[lca]+dep[y]-dep[lca];
int tmp=x,i=0,j; while(tmp!=lca) s[i++]=_w[tmp],tmp=anc[tmp][0];
tmp=y,i=1; while(tmp!=lca) s[lens-i]=_w[tmp],tmp=anc[tmp][0],i++;
for(int i=0;i<lent;i++) t[i]=str[i]-'a';
nx[0]=-1; i=0,j=-1;
while(i<lent)
{
if(j==-1||t[i]==t[j]) {i++; j++; nx[i]=j;}
else j=nx[j];
}
i=0,j=0; int ret=0;
while(i<lens)
{
if(j==-1||s[i]==t[j])
{
i++; j++;
if(j==lent){ret++; j=nx[j];}
}
else j=nx[j];
}
pos[id][0]=AC.insert(1); pos[id][1]=AC.insert(-1);
ans[id]=ret;
if(u!=x)
{
ID[x].push_back(-id); ID[u].push_back(id);
V[x].push_back(pos[id][1]); V[u].push_back(pos[id][1]);
}
if(v!=y)
{
ID[y].push_back(-id); ID[v].push_back(id);
V[y].push_back(pos[id][0]); V[v].push_back(pos[id][0]);
}
}
void dfs2(int u,int f,int x)
{
AC.add(AC.in[x],1);
int sz=V[u].size();
for(int i=0;i<sz;i++)
{
int ret=AC.query(AC.out[V[u][i]])-AC.query(AC.in[V[u][i]]-1);
if(ID[u][i]>0) ans[ID[u][i]]+=ret;
else ans[-ID[u][i]]-=ret;
}
for(int i=head[u];i;i=nxt[i])
{
int v=to[i];
if(v==f) continue;
dfs2(v,u,AC.ch[x][w[i]]);
}
AC.add(AC.in[x],-1);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
{
int u,v; scanf("%d%d%s",&u,&v,str);
build(u,v,str[0]-'a'); build(v,u,str[0]-'a');
}
dfs1(1,1); AC.init();
for(int i=1;i<=m;i++)
{
int u,v; scanf("%d%d",&u,&v); scanf("%s",str);
if(u==v) continue; int lca=getlca(u,v);
kmp(u,v,lca,i);
}
AC.getfail();
AC.dfs(1);
dfs2(1,1,1);
for(int i=1;i<=m;i++) printf("%d\n",ans[i]);
return 0;
}
昨晚上腦子有點不清楚
自己沒寫對拍就一直交,卡了一晚上評測,非常抱歉,感謝大家的不殺之恩。
最後查出來的錯:
1.求lca時 swap(u,u);
2.kmp忘記 j=nxt[j];
3.自己nxt[N]
和nx[N]
兩個陣列搞混了。
最初以為是倍增慢了,沒有去查無限迴圈的錯,改成鏈剖還是T,才想到是不是哪裡寫掛了,
長程式碼一定要靜態查錯+對拍。