AcWing 355. 異象石
阿新 • • 發佈:2021-01-09
題目連結: AcWing 355. 異象石
題目大意:
一棵大小為 \(n\) 的樹,\(m\) 次操作,有三種操作:
- "\(+x\)" 在節點 \(x\) 處出現了異象石
- "\(-x\)" 節點 \(x\) 處的異象石消失
- " \(?\) " 詢問在樹上將所有異象石連通所需邊的最小權值和
\(1\leq n,m\leq 10^5\) ,邊權 \(1\leq z\leq 10^9\) 。
思路:
這道題有一個不好想到的結論,類比於樹的邊權和為 \(dfs\) 經過所有邊的權值和的一半,結論如下:
我們按照時間戳從小到大排序,將出現異象石的節點首尾相連排成一圈,則相鄰節點的距離之和即為答案的一半。
可以結合這棵樹理解一下,黑圈的是出現異象石的節點,粗邊即聯通異象石的邊集:
有這個結論之後接下來的就簡單了,首先 \(dfs\) 求出 \(dfn_i\) ,使用set維護出現異象石的節點序列,設節點 \(x\) 的前後驅分別為 \(u,v\) ,插入 \(x\) 即 \(ans+=Dis(u,x)+Dis(x,v)-Dis(u,v)\) ,刪除類似。
時間複雜度 \(O(nlogn)\) 。
實現細節:
- 倍增求 \(LCA\) 的時候不要把 \(dep\) 和 \(dis\) 搞混了(可能就我會犯這種錯誤)。
- 注意維護set中的首尾相連,當 set 加入 \(x\) 前是空的時,不用更新 \(ans\)
Code:
#include<iostream> #include<cstdio> #include<cstring> #include<set> #define N 100100 #define LOG 17 #define int long long using namespace std; inline int read(){ int s=0,w=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();} while(ch>='0'&&ch<='9')s=(s<<3)+(s<<1)+(ch^48),ch=getchar(); return s*w; } int head[N],to[N*2],nxt[N*2]; int cnt,len[N*2]; int dep[N],fa[N][17],dfn[N],rev[N]; int dis[N]; set<int> q; void init(){ cnt=-1; memset(head,-1,sizeof(head)); } void add_e(int a,int b,int l,bool id){ nxt[++cnt]=head[a]; head[a]=cnt; to[cnt]=b; len[cnt]=l; if(id)add_e(b,a,l,0); } void dfs(int x,int fath){ dfn[x]=++cnt; rev[cnt]=x; fa[x][0]=fath; for(int i=1;i<LOG;i++){ fa[x][i]=fa[fa[x][i-1]][i-1]; } for(int i=head[x];~i;i=nxt[i]){ if(to[i]==fath)continue; dep[to[i]]=dep[x]+1; dis[to[i]]=dis[x]+len[i]; dfs(to[i],x); } } int lca(int a,int b){ if(dep[a]<dep[b])swap(a,b); for(int i=16;i>=0;i--) if(fa[a][i]&&dep[fa[a][i]]>=dep[b])a=fa[a][i]; if(a==b)return a; for(int i=16;i>=0;i--){ if(fa[a][i]!=fa[b][i])a=fa[a][i],b=fa[b][i]; } return fa[a][0]; } int Dis(int a,int b){ return dis[a]+dis[b]-2*dis[lca(a,b)]; } int get(int k,int ud){ set<int>::iterator it; if(ud==0){ it=q.lower_bound(k); if(it==q.begin())return rev[*(--q.end())]; else return rev[*(--it)]; }else{ it=q.upper_bound(k); if(it==q.end())return rev[*(q.begin())]; else return rev[*it]; } } signed main(){ int n,m; int x,y,z; cin>>n; init(); for(int i=1;i<n;i++){ x=read(),y=read(),z=read(); add_e(x,y,z,1); } cnt=0; dfs(1,0); cin>>m; char c; int in,ans=0; for(int i=0;i<m;i++){ c=getchar(); while(c!='+'&&c!='-'&&c!='?')c=getchar(); switch(c){ case '?':printf("%lld\n",ans/2);break; case '+':{ in=read(); q.insert(dfn[in]); if(q.size()==1)continue; int u=get(dfn[in],0),v=get(dfn[in],1); ans+=Dis(u,in)+Dis(in,v)-Dis(u,v); break; } case '-':{ in=read(); int u=get(dfn[in],0),v=get(dfn[in],1); q.erase(dfn[in]); ans+=Dis(u,v)-Dis(u,in)-Dis(v,in); } } } return 0; }