【GMOJ5363】生命之樹
阿新 • • 發佈:2020-11-02
題目
題目連結:https://gmoj.net/senior/#main/show/5363
思路
這個異或很煩,二進位制拆位搞掉。
然後兩個節點能造成貢獻當且僅當他們這一位下為 \(1\)。維護兩棵 Trie,存二進位制下這一位為 \(0/1\) 的所有字串。
對於一個節點 \(x\),先計算它所有子樹答案,然後發現需要將子樹中所有節點的字串扔到 Trie 中。暴力扔顯然是不可以的,dsu on tree 搞一下即可。
計算答案的時候就在 Tire 的路勁上求一下即可。
時間複雜度 \(O(n\log n\log |S|)\)。
程式碼
#include <bits/stdc++.h> using namespace std; typedef long long ll; const int N=500010,LG=18; int n,t,tot,head[N],a[N],pos[N],size[N],son[N]; ll ans[N],ans2[N]; char s[N],ss[N]; struct edge { int next,to; }e[N]; void add(int from,int to) { e[++tot].to=to; e[tot].next=head[from]; head[from]=tot; } struct Trie { int tot,c[N][27],size[N]; void ins(int k) { int p=1; size[p]++; for (int i=pos[k];i<pos[k+1];i++) { if (!c[p][s[i]-'a'+1]) c[p][s[i]-'a'+1]=++tot; p=c[p][s[i]-'a'+1]; size[p]++; } } ll query(int k) { int p=1; ll sum=0; for (int i=pos[k];i<pos[k+1];i++) { p=c[p][s[i]-'a'+1]; sum+=size[p]; } return sum; } void clr(int x) { for (int i=1;i<=26;i++) if (c[x][i]) clr(c[x][i]),c[x][i]=0; size[x]=0; } }trie[2]; void dfs3(int x,int fa,int rt,int val) { int id=((a[x]&val)!=0); ans[rt]+=trie[id^1].query(x)*val; trie[id].ins(x); for (int i=head[x];~i;i=e[i].next) { int v=e[i].to; if (v!=fa) dfs3(v,x,rt,val); } } void dfs1(int x,int fa) { size[x]=pos[x+1]-pos[x]; for (int i=head[x];~i;i=e[i].next) { int v=e[i].to; if (v!=fa) { dfs1(v,x); size[x]+=size[v]; if (size[v]>size[son[x]]) son[x]=v; } } } void dfs2(int x,int fa,int val,bool flag) { for (int i=head[x];~i;i=e[i].next) { int v=e[i].to; if (v!=fa && v!=son[x]) dfs2(v,x,val,0); } if (son[x]) { dfs2(son[x],x,val,1); ans[x]+=ans[son[x]]; } for (int i=head[x];~i;i=e[i].next) { int v=e[i].to; if (v!=fa && v!=son[x]) dfs3(v,x,x,val); } int id=((a[x]&val)!=0); ans[x]+=trie[id^1].query(x)*val; trie[id].ins(x); if (!flag) { trie[0].clr(1); trie[0].tot=1; trie[1].clr(1); trie[1].tot=1; } } int main() { freopen("tree.in","r",stdin); freopen("tree.out","w",stdout); memset(head,-1,sizeof(head)); scanf("%d",&n); for (int i=1;i<=n;i++) scanf("%d",&a[i]); pos[1]=1; for (int i=1;i<=n;i++) { scanf("%s",ss+1); int len=strlen(ss+1); for (int j=pos[i];j<pos[i]+len;j++) s[j]=ss[j-pos[i]+1]; pos[i+1]=pos[i]+len; } for (int i=1,x,y;i<n;i++) { scanf("%d%d",&x,&y); add(x,y); add(y,x); } dfs1(1,0); trie[0].tot=trie[1].tot=1; for (int i=0;i<=LG;i++) { dfs2(1,0,(1<<i),0); for (int j=1;j<=n;j++) ans2[j]+=ans[j],ans[j]=0; } for (int i=1;i<=n;i++) printf("%lld\n",ans2[i]); return 0; }