1. 程式人生 > 實用技巧 >Colorful Tree(查詢用最少邊數使相同點聯通,單點修改)

Colorful Tree(查詢用最少邊數使相同點聯通,單點修改)

題:https://ac.nowcoder.com/acm/contest/7831/H

題意:給定n個點的樹,每個節點都有顏色;

  • 詢問[Q,y]:求把所有y顏色的節點聯通起來用的最少的邊數。
  • 更新[U,x,y]:將x節點的顏色改為y;

分析:

  • 對於詢問,我們可以假象為有倆個點作為總邊,剩餘顏色的點就連上這條總邊;
  • 關鍵在於如何確定這條總邊,讓其他相同顏色的點x連上這條總邊的點y,dis(x,y)不與其他“邊”重合;
  • 這條總邊就可以用最小的dfs序的點和最大的dfs序的點連線的邊,就可以保證statement.2;
  • 關於更新答案可以發現,插入一個點x的影響只與和x的dfs序相鄰的點(l和r)有關,+dis(l,x)+dis(r,x)-dis(l,r)這部分就可以用過載排序的set來維護;
  • 刪除就上面表示式去反號,簡單看出統計的答案的“分支”是2倍的,而總邊只有1倍,那麼查詢就是(ans[col[y]]+總邊長度)/2;
#include<bits/stdc++.h>
using namespace std;
#define pb push_back
#define MP make_pair
#define lson root<<1,l,midd
#define rson root<<1|1,midd+1,r
typedef long long ll;
const int mod=1e9+7;
const int M=1e5+5;
const
int inf=0x3f3f3f3f; const ll INF=1e18; vector<int>g[M]; int tot; int dfn[M],f[M],deep[M],sz[M],son[M],top[M],ans[M],col[M]; struct cmp{ bool operator() (const int &x,const int &y)const{ return dfn[x]<dfn[y]; } }; set<int,cmp>st[M]; void dfs1(int u){///cout<<u<<"!!"<<endl;
deep[u]=deep[f[u]]+1; sz[u]=1; for(auto v:g[u]){ if(v!=f[u]){ f[v]=u; dfs1(v); sz[u]+=sz[v]; if(!son[u]||sz[v]>sz[son[u]]) son[u]=v; } } } void dfs2(int u,int tp){ dfn[u]=++tot; top[u]=tp; if(son[u]) dfs2(son[u],tp); for(auto v:g[u]){ if(v!=f[u]&&v!=son[u]) dfs2(v,v); } } int LCA(int u,int v){ while(top[u]!=top[v]){ if(deep[top[u]]<deep[top[v]]) swap(u,v); u=f[top[u]]; } if(deep[u]>deep[v]) swap(u,v); return u; } int dis(int u,int v){ return deep[u]+deep[v]-2*deep[LCA(u,v)]; } void add(int x){ st[col[x]].insert(x); auto it=st[col[x]].find(x); int l=0,r=0; ++it; if(it!=st[col[x]].end()){ r=*it; } it--; if(it!=st[col[x]].begin()){ it--; l=*it; } if(l) ans[col[x]]+=dis(l,x); if(r) ans[col[x]]+=dis(r,x); if(l&&r) ans[col[x]]-=dis(l,r); } void del(int x){ auto it=st[col[x]].find(x); it++; int l=0,r=0; if(it!=st[col[x]].end()){ r=*it; } it--; if(it!=st[col[x]].begin()){ it--; l=*it; } if(l) ans[col[x]]-=dis(l,x); if(r) ans[col[x]]-=dis(r,x); if(l&&r) ans[col[x]]+=dis(l,r); st[col[x]].erase(x); } char s[2]; int main(){ int n; scanf("%d",&n); for(int u,v,i=1;i<n;i++){ scanf("%d%d",&u,&v); g[u].pb(v); g[v].pb(u); } dfs1(1); dfs2(1,1); for(int i=1;i<=n;i++){ scanf("%d",&col[i]); add(i); } int m; scanf("%d",&m); while(m--){ scanf("%s",s); int x,y; if(s[0]=='U'){ scanf("%d%d",&x,&y); del(x); col[x]=y; add(x); } else{ scanf("%d",&y); if(st[y].size()==0) puts("-1"); else printf("%d\n",(ans[y]+dis(*st[y].begin(),*st[y].rbegin()))/2); } } return 0; }
View Code