acwing 1140. 最短網路(prim)
阿新 • • 發佈:2022-03-16
分析
我們來根據操作來討論一下,需要維護的值有什麼。
將節點 a 到節點 b 的路徑上的所有點(包括 a 和 b)都染成顏色 c。
很明顯,我們需要維護一下tag,來儲存該區間是否發生了整體被某種顏色覆蓋
這並不困難,我們把眼光放到第二個操作上
詢問節點 a 到節點 b 的路徑上的顏色段數量。
此時,我們很明顯需要維護一個sum,表示該段上不同顏色段的數量。
同時為了維護合併後區間的sum,我們需要維護兩個值lc和rc分別表示左端點的顏色和右端點的顏色。
若合併區間時,左區間的右端點顏色 = 右區間的左端點顏色,則該區間的顏色段數量減1
同時需要注意的是,在從一個條鏈跳到另外一條鏈時,可能會發生顏色連續的事情,從而使答案減1
具體解決方案就是,可以在全域性開一個Lc和Rc變數,用來記此時該條鏈的左端點顏色和右端點顏色。
同時維護兩個變數ans1和ans2,用來分別統計u的上一條鏈的左端點顏色和v的上一條鏈的左端點顏色。
還需要注意的是,當top[u]==top[v],即u,v在同一條鏈時,因為此時區間的兩個端點分別為u,v,需要分別對u,v的上一條鏈的左端點顏色進行對比,若相同則減1
話不多說,直接看程式碼
AC_code
#include<bits/stdc++.h> using namespace std; const int N = 1e5 + 10,M = N*2; struct Node { int l,r,lc,rc,sum,tag; }tr[N<<2]; int h[N],e[M],ne[M],w[N],idx; int sz[N],fa[N],son[N],dep[N]; int id[N],top[N],nw[N],ts; int n,m,Lc,Rc; void add(int a,int b) { e[idx] = b,ne[idx] = h[a],h[a] = idx++; } void dfs1(int u,int pa,int depth) { fa[u] = pa,sz[u] = 1,dep[u] = depth; for(int i=h[u];~i;i=ne[i]) { int j = e[i]; if(j==pa) continue; dfs1(j,u,depth+1); if(sz[son[u]]<sz[j]) son[u] = j; sz[u] += sz[j]; } } void dfs2(int u,int tp) { id[u] = ++ts,nw[ts] = w[u],top[u] = tp; if(!son[u]) return ; dfs2(son[u],tp); for(int i=h[u];~i;i=ne[i]) { int j = e[i]; if(j==son[u]||j==fa[u]) continue; dfs2(j,j); } } void pushup(int u) { tr[u].lc = tr[u<<1].lc,tr[u].rc = tr[u<<1|1].rc; tr[u].sum = tr[u<<1].sum + tr[u<<1|1].sum; if(tr[u<<1].rc==tr[u<<1|1].lc) tr[u].sum--; } void change(Node &u,int k) { u.sum = u.tag = 1; u.lc = u.rc = k; } void pushdown(int u) { if(tr[u].tag) { change(tr[u<<1],tr[u].lc); change(tr[u<<1|1],tr[u].lc); tr[u].tag = 0; } } void build(int u,int l,int r) { if(l==r) { tr[u] = {l,r,nw[l],nw[l],1,0}; return ; } tr[u] = {l,r,nw[l],nw[r],0,0}; int mid = l + r >> 1; build(u<<1,l,mid),build(u<<1|1,mid+1,r); pushup(u); } void modify(int u,int l,int r,int k) { if(l<=tr[u].l&&tr[u].r<=r) { change(tr[u],k); return ; } pushdown(u); int mid = tr[u].l + tr[u].r >> 1; if(l<=mid) modify(u<<1,l,r,k); if(r>mid) modify(u<<1|1,l,r,k); pushup(u); } int query(int u,int l,int r) { if(l<=tr[u].l&&tr[u].r<=r) { if(tr[u].l==l) Lc = tr[u].lc; if(tr[u].r==r) Rc = tr[u].rc; return tr[u].sum; } pushdown(u); int mid = tr[u].l + tr[u].r >> 1; int res = 0,lc = -1,rc = -1; if(l<=mid) res += query(u<<1,l,r),lc = tr[u<<1].rc; if(r>mid) res += query(u<<1|1,l,r),rc = tr[u<<1|1].lc; if(lc!=-1&&rc!=-1&&lc==rc) res--; return res; } int main() { cin>>n>>m; memset(h,-1,sizeof h); for(int i=1;i<=n;i++) cin>>w[i]; for(int i=0;i<n-1;i++) { int a,b;cin>>a>>b; add(a,b),add(b,a); } dfs1(1,-1,1); dfs2(1,1); build(1,1,n); while(m--) { string op;int a,b,c; cin>>op>>a>>b; if(op=="C") { cin>>c; while(top[a]!=top[b]) { if(dep[top[a]]<dep[top[b]]) swap(a,b); modify(1,id[top[a]],id[a],c); a = fa[top[a]]; } if(dep[a]<dep[b]) swap(a,b); modify(1,id[b],id[a],c); } else { int res = 0,ans1 = -1,ans2 = -1; while(top[a]!=top[b]) { if(dep[top[a]]<dep[top[b]]) swap(a,b),swap(ans1,ans2); res += query(1,id[top[a]],id[a]); if(Rc==ans1) res--; ans1 = Lc; a = fa[top[a]]; } if(dep[a]<dep[b]) swap(a,b),swap(ans1,ans2); res += query(1,id[b],id[a]); if(Lc==ans2) res--; if(Rc==ans1) res--; cout<<res<<endl; } } return 0; }