題解 洛谷 P4218 【[CTSC2010]珠寶商】
首先對特徵字串建 \(SAM\),來實現對子串的匹配。
有一個 \(O(n^2)\) 的暴力,分別以每個點為根進行 \(dfs\),遍歷樹時記錄當前字串在 \(SAM\) 上匹配到的節點即可。
考慮用點分治來解決本題這樣的樹上路徑統計問題。對於當前的分治重心 \(x\),統計連通塊中經過 \(x\) 的路徑的貢獻。經過 \(x\) 的路徑拆分為 \(x\) 子樹內一個點到 \(x\) 的路徑和 \(x\) 到 \(x\) 子樹內一個點的路徑。因為 \(SAM\) 可以用 \(endpos\) 統計以 \(x\) 所對應的字元結束的子串,所以對特徵字串的反串也建出 \(SAM\),把第二種路徑也轉化為第一種路徑來便於統計。
從 \(x\) 出發 \(dfs\) 統計路徑時,到達一個節點後要往當前字串前端加入該節點所對應的字元,直接用 \(SAM\) 是無法處理匹配的。發現在 \(SAM\) 對應的 \(Parent\) 樹上,一個點到其兒子節點,就是在其對應的子串前端加入字元,所以可以處理出 \(Parent\) 樹上每個點加入字元後對應的兒子節點,這樣就可以通過 \(Parent\) 樹來實現前端加入字元來匹配了。這裡其實就是建出了字尾樹。
兩條路徑對應的字串要保證是在特徵字串上是相鄰的,因此在 \(Parent\) 樹上打標記,統計時遍歷整棵 \(Parent\) 樹來使標記下放到兒子,對於特徵字串的每個位置,兩種路徑的方案相乘來貢獻答案。
直接點分治的複雜度是 \(O(n \log n + nm)\) 的,仍然無法接受,考慮結合 \(O(n^2)\) 的暴力進行根號分治。點分治時,連通塊大小 \(\leqslant \sqrt n\) 時採取 \(O(n^2)\) 暴力,大小 \(> \sqrt n\) 時採取統計路徑貢獻。
對這種方法來進行復雜度分析。連通塊大小 \(\leqslant \sqrt n\) 的情況複雜度為 \(O(n \sqrt n)\)。考慮點分治進行到第 \(k\) 層時,最大的連通塊大小為 \(\frac{n}{2^k}\),連通塊個數為 \(2^k\),若限制連通塊大小最小為 \(\sqrt n\)
注意統計路徑貢獻時還需容斥,減去兩條路徑來自同一個子樹的情況。容斥時也需要對每個兒子進行根號分治來保證複雜度。
\(code:\)
#include<bits/stdc++.h>
#define maxn 100010
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
x=0;char c=getchar();bool flag=false;
while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
if(flag)x=-x;
}
int n,m,S,root,tot,goal,rt;
ll ans;
int siz[maxn],ma[maxn];
bool vis[maxn];
char c[maxn],s[maxn];
struct edge
{
int to,nxt;
}e[maxn];
int head[maxn],edge_cnt;
void add(int from,int to)
{
e[++edge_cnt]=(edge){to,head[from]};
head[from]=edge_cnt;
}
struct SAM
{
int tot=1,root=1,las=1;
int fa[maxn],ch[maxn][30],son[maxn][30],len[maxn],siz[maxn],pos[maxn],bel[maxn];
ll tag[maxn];
char s[maxn];
vector<int> ve[maxn];
void insert(int c,int id)
{
int p=las,np=las=++tot;
len[np]=len[p]+1,siz[np]=1,pos[np]=id,bel[id]=np;
while(p&&!ch[p][c]) ch[p][c]=np,p=fa[p];
if(!p) fa[np]=root;
else
{
int q=ch[p][c];
if(len[q]==len[p]+1) fa[np]=q;
else
{
int nq=++tot;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
len[nq]=len[p]+1,fa[nq]=fa[q],fa[q]=fa[np]=nq;
while(ch[p][c]==q) ch[p][c]=nq,p=fa[p];
}
}
}
void dfs(int x)
{
for(int i=0;i<ve[x].size();++i)
{
int y=ve[x][i];
dfs(y),siz[x]+=siz[y];
pos[x]=pos[y],son[x][s[pos[y]-len[x]]]=y;
}
}
void build()
{
for(int i=1;i<=m;++i) insert(s[i],i);
for(int i=2;i<=tot;++i) ve[fa[i]].push_back(i);
dfs(root);
}
void match(int x,int fa,int p,int lenth)
{
if(lenth==len[p]) p=son[p][c[x]];
else if(s[pos[p]-lenth]!=c[x]) p=0;
if(!p) return;
tag[p]++;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(vis[y]||y==fa) continue;
match(y,x,p,lenth+1);
}
}
void update(int x)
{
for(int i=0;i<ve[x].size();++i)
tag[ve[x][i]]+=tag[x],update(ve[x][i]);
}
void clear()
{
for(int i=1;i<=tot;++i) tag[i]=0;
}
}A,B;
void dfs_root(int x,int fa)
{
siz[x]=1,ma[x]=0;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(vis[y]||y==fa) continue;
dfs_root(y,x),siz[x]+=siz[y];
ma[x]=max(ma[x],siz[y]);
}
ma[x]=max(ma[x],tot-siz[x]);
if(ma[x]<ma[root]) root=x;
}
void dfs_get(int x,int fa,int p,int type)
{
p=A.ch[p][c[x]];
if(!p) return;
ans+=A.siz[p]*type;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(vis[y]||y==fa) continue;
dfs_get(y,x,p,type);
}
}
void dfs_del(int x,int fa,int p)
{
if(x!=goal) p=A.ch[p][c[x]];
else
{
p=A.ch[p][c[x]],p=A.ch[p][c[rt]];
if(p) dfs_get(x,0,p,-1);
return;
}
if(!p) return;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(vis[y]||y==fa) continue;
dfs_del(y,x,p);
}
}
void dfs_find(int x,int fa,int type)
{
if(type) dfs_get(x,0,1,1);
else dfs_del(x,0,1);
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(vis[y]||y==fa) continue;
dfs_find(y,x,type);
}
}
void calc(int x,int fa)
{
A.clear(),B.clear();
if(!fa) A.match(x,fa,1,0),B.match(x,fa,1,0);
else A.match(x,fa,A.ch[1][c[fa]],1),B.match(x,fa,B.ch[1][c[fa]],1);
A.update(1),B.update(1);
for(int i=1;i<=m;++i)
{
if(!fa) ans+=A.tag[A.bel[i]]*B.tag[B.bel[m-i+1]];
else ans-=A.tag[A.bel[i]]*B.tag[B.bel[m-i+1]];
}
}
void solve(int x)
{
if(tot<=S)
{
dfs_find(x,0,1);
return;
}
int now=tot;
vis[x]=true,calc(x,0);
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(vis[y]) continue;
root=0,tot=siz[y];
if(siz[y]>siz[x]) tot=now-siz[x];
if(tot<=S) rt=x,goal=y,dfs_find(y,x,0);
else calc(y,x);
dfs_root(y,x),solve(root);
}
}
int main()
{
read(n),read(m),S=sqrt(n);
for(int i=1;i<n;++i)
{
int x,y;
read(x),read(y);
add(x,y),add(y,x);
}
scanf("%s%s",c+1,s+1);
for(int i=1;i<=n;++i) c[i]-='a';
for(int i=1;i<=m;++i) s[i]-='a',A.s[i]=s[i],B.s[m-i+1]=s[i];
A.build(),B.build(),tot=ma[0]=n,dfs_root(1,0),solve(root);
printf("%lld",ans);
return 0;
}