洛谷【P2664】樹上游戲
阿新 • • 發佈:2018-12-17
淺談樹分治:https://www.cnblogs.com/AKMer/p/10014803.html
題目傳送門:https://www.luogu.org/problemnew/show/P2664
對於所有求顏色種類數的問題,我們都可以定義一個方向,使得所有的顏色在最靠這個方向第一次出現的位置有效,而其它位置都是無效的。對於樹分治,我們可以定義這個方向為當前需要遍歷的子樹,反方向就是已經遍歷完的子樹。
對於一個點\(u\),如果從當前重心到他這一條路徑上,該點顏色是第一次出現,那麼它的顏色將給後面的遍歷帶來\(siz[u]\)的貢獻。另外,在遍歷當前子樹時,所有在重心到當前點這條路徑的上的顏色,貢獻都是已經遍歷過的子樹的總結點數。正過來做一遍,反過來做一遍就可以了。對於單獨的從重心到當前點的路徑會被統計兩次,所以要減掉一次。
邊分治重構樹之後不知道怎麼消除新結點的影響,如果有大佬願意教教我請在評論下方回覆。
這題資料貌似比較水,不卡不重構樹的邊分治。
時間複雜度:\(O(nlogn)\)
空間複雜度:\(O(n)\)
點分治版程式碼如下:
#include <cstdio> #include <algorithm> using namespace std; typedef long long ll; const int maxn=1e5+5; bool vis[maxn]; ll ans[maxn],res; int n,tot,mx,rt,N,Siz; int now[maxn],pre[maxn<<1],son[maxn<<1]; int col[maxn],siz[maxn],cnt[maxn],V[maxn],sum[maxn]; int read() { int x=0,f=1;char ch=getchar(); for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1; for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0'; return x*f; } void add(int a,int b) { pre[++tot]=now[a]; now[a]=tot,son[tot]=b; } struct rubbish { bool bo[maxn]; int sta[maxn],top; void clear() { Siz=res=0; while(top) { bo[sta[top]]=0; cnt[sta[top--]]=0; } } void ins(int id) { if(bo[id])return; bo[id]=1,sta[++top]=id; } }R; void find_rt(int fa,int u) { int res=0;siz[u]=1; for(int p=now[u],v=son[p];p;p=pre[p],v=son[p]) if(!vis[v]&&v!=fa)find_rt(u,v),siz[u]+=siz[v],res=max(res,siz[v]); res=max(res,N-siz[u]); if(res<mx)mx=res,rt=u; } void dfs(int fa,int u) { sum[col[u]]++,res+=(sum[col[u]]==1); ans[u]-=res,siz[u]=1; for(int p=now[u],v=son[p];p;p=pre[p],v=son[p]) if(!vis[v]&&v!=fa)dfs(u,v),siz[u]+=siz[v]; sum[col[u]]--,res-=(sum[col[u]]==0); } void query(int fa,int u) { sum[col[u]]++;if(sum[col[u]]==1)res-=cnt[col[u]],res+=Siz+1; ans[u]+=res; for(int p=now[u],v=son[p];p;p=pre[p],v=son[p]) if(!vis[v]&&v!=fa)query(u,v); sum[col[u]]--;if(sum[col[u]]==0)res+=cnt[col[u]],res-=Siz+1; } void solve(int fa,int u) { sum[col[u]]++; if(sum[col[u]]==1) { cnt[col[u]]+=siz[u]; res+=siz[u];R.ins(col[u]); } for(int p=now[u],v=son[p];p;p=pre[p],v=son[p]) if(!vis[v]&&v!=fa)solve(u,v); sum[col[u]]--; } void print() { for(int i=1;i<=n;i++) printf("%lld ",ans[i]); puts(""); } void work(int u,int size) { N=size,mx=rt=n+1,find_rt(0,u); u=rt,vis[u]=1,tot=0; sum[col[u]]++;res++; for(int p=now[u],v=son[p];p;p=pre[p],v=son[p]) if(!vis[v])V[++tot]=v,dfs(u,v); sum[col[u]]--;res--; for(int i=1;i<=tot;i++) { int v=V[i]; sum[col[u]]++,res-=cnt[col[u]],res+=Siz+1; query(u,v); sum[col[u]]--,res+=cnt[col[u]],res-=Siz+1; solve(u,v),Siz+=siz[v]; }R.clear(); for(int i=tot;i;i--) { int v=V[i]; sum[col[u]]++,res-=cnt[col[u]],res+=Siz+1; query(u,v); sum[col[u]]--,res+=cnt[col[u]],res-=Siz+1; solve(u,v),Siz+=siz[v]; }ans[u]+=res+Siz+1-cnt[col[u]];R.clear(); for(int p=now[u],v=son[p];p;p=pre[p],v=son[p]) if(!vis[v])work(v,siz[v]); } int main() { n=read(); for(int i=1;i<=n;i++) col[i]=read(); for(int i=1;i<n;i++) { int a=read(),b=read(); add(a,b),add(b,a); }work(1,n); for(int i=1;i<=n;i++) printf("%lld\n",ans[i]); return 0; }
不重構樹的邊分治版程式碼如下:
#include <cstdio> #include <vector> #include <cstring> #include <algorithm> using namespace std; typedef long long ll; const int maxn=2e5+5; bool vis[maxn]; ll ans[maxn],res; int m,n,tot=1,mx,id,N; int now[maxn],pre[maxn<<1],son[maxn<<1]; int col[maxn],siz[maxn],cnt[maxn],sum[maxn]; vector<int>to[maxn]; vector<int>::iterator it; int read() { int x=0,f=1;char ch=getchar(); for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1; for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0'; return x*f; } void add(int a,int b) { pre[++tot]=now[a]; now[a]=tot,son[tot]=b; } struct Rubbish { bool bo[maxn]; int sta[maxn],top; void clear() { res=0; while(top)cnt[sta[top]]=bo[sta[top]]=0,top--; } void ins(int id) { if(bo[id])return; bo[id]=1,sta[++top]=id; } }R; void find_edge(int fa,int u) { siz[u]=1; for(int p=now[u],v=son[p];p;p=pre[p],v=son[p]) if(!vis[p>>1]&&v!=fa) { find_edge(u,v),siz[u]+=siz[v]; if(abs(N-2*siz[v])<mx) mx=abs(N-2*siz[v]),id=p>>1; } } void dfs(int fa,int u) { siz[u]=1; for(int p=now[u],v=son[p];p;p=pre[p],v=son[p]) if(!vis[p>>1]&&v!=fa)dfs(u,v),siz[u]+=siz[v]; } void solve(int fa,int u) { sum[col[u]]++; if(sum[col[u]]==1) { cnt[col[u]]+=siz[u]; res+=siz[u],R.ins(col[u]); } for(int p=now[u],v=son[p];p;p=pre[p],v=son[p]) if(!vis[p>>1]&&v!=fa)solve(u,v); sum[col[u]]--; } void query(int fa,int u,int num) { sum[col[u]]++; if(sum[col[u]]==1)res+=num,res-=cnt[col[u]]; ans[u]+=res; for(int p=now[u],v=son[p];p;p=pre[p],v=son[p]) if(!vis[p>>1]&&v!=fa)query(u,v,num); sum[col[u]]--; if(sum[col[u]]==0)res-=num,res+=cnt[col[u]]; } void work(int u,int size) { if(size<2)return; N=size,mx=id=m+1,find_edge(0,u),vis[id]=1; int u1=son[id<<1],u2=son[id<<1|1]; dfs(0,u1),dfs(0,u2); solve(0,u1),query(0,u2,siz[u1]),R.clear(); solve(0,u2),query(0,u1,siz[u2]),R.clear(); work(u1,siz[u1]),work(u2,siz[u2]); } int main() { m=n=read(); for(int i=1;i<=n;i++) col[i]=read(); for(int i=1;i<n;i++) { int a=read(),b=read(); add(a,b),add(b,a); } work(1,m); for(int i=1;i<=n;i++)printf("%lld\n",ans[i]+1); return 0; }