arc 048 d 題解
阿新 • • 發佈:2020-12-03
arc 048 d
假設要從\(s\)到\(t\),中途在\(i\)的位置出去走到一個特殊點,然後再回來走到t。
可以發現從i出去走到的關鍵點一定是距離i最近的那個關鍵點。
首先處理出所有點到最近的關鍵點的距離,記為\(dis_i\)。
則從i走到最近的那個關鍵點再走回i需要花費\(3\times dis_i\)的時間。
然後再分類討論i的位置:
- 在\(s\)到\(lca(s,t)\)的路徑上。則花費的時間為:
也就是說只需要找到路徑$s$到lca上最小的-depth[i]+3dis[i]。
- 在lca到t的路徑上,花費的時間為:
與上面類似,只需要找到在lca到t的路徑上最小的depth[i]+3dis[i]。
由於沒有修改,可以用倍增算最小值。
IH19980412
#include <bits/stdc++.h> using namespace std; typedef pair<int,int> P; #define fi first #define sc second #define mp make_pair #define pb push_back #define mod 1000000007 typedef long long ll; int n,m; string s; vector<int>e[100005]; int db[100005][18]; int d[100005][18]; int ee[100005][18]; vector<int>T; int ds[100005],D[100005]; bool used[100005]; int dfs(int v,int u){ db[v][0] = u; for(int i=0;i<e[v].size();i++){ if(e[v][i] == u) continue; D[e[v][i]] = D[v]+1; dfs(e[v][i],v); } } int rd(int s,int g){ int x = 10000000; for(int i=17;i>=0;i--){ if(db[s][i] != -1 && D[g] <= D[db[s][i]]){ x = min(x,d[s][i]); s = db[s][i]; } } return min(x,d[g][0]); } int re(int s,int g){ int x = 10000000; for(int i=17;i>=0;i--){ if(db[s][i] != -1 && D[g] <= D[db[s][i]]){ x = min(x,ee[s][i]); s = db[s][i]; } } return min(x,ee[g][0]); } int lca(int u,int v){ if(D[u]>D[v]) swap(u,v); for(int i=0;i<18;i++){ if( ((D[v]-D[u])>>i)&1) { v = db[v][i]; } } if(u == v) return u; for(int i=17;i>=0;i--){ if(db[u][i] != db[v][i]){ v = db[v][i]; u = db[u][i]; } } return db[v][0]; } int calc(int u,int v){ if(D[u]>D[v]) swap(u,v);int x=u,y=v; for(int i=0;i<18;i++){ if( ((D[v]-D[u])>>i)&1) { v = db[v][i]; } } int c; if(u == v) c=u; else{ for(int i=17;i>=0;i--){ if(db[u][i] != db[v][i]){ v = db[v][i]; u = db[u][i]; } } c = db[v][0]; } return D[x]+D[y]-2*D[c]; } int main(){ cin >> n >> m; for(int i=1;i<n;i++){ int a,b; scanf("%d%d",&a,&b); e[a].pb(b); e[b].pb(a); } cin >> s; for(int i=0;i<n;i++){ if(s[i] == '1') T.pb(i+1); } priority_queue<P,vector<P>,greater<P> >que; fill(ds,ds+100005,10000000); for(int i=0;i<T.size();i++){ que.push(mp(0,T[i])); ds[T[i]] = 0; } while(!que.empty()){ P p = que.top(); que.pop(); if(ds[p.sc] != p.fi) continue; for(int i=0;i<e[p.sc].size();i++){ if(ds[e[p.sc][i]] > p.fi+1){ que.push(mp(p.fi+1,e[p.sc][i])); ds[e[p.sc][i]] = p.fi+1; } } } memset(db,-1,sizeof(db)); dfs(1,-1); for(int i=1;i<=n;i++){ d[i][0] = -D[i]+3*ds[i]; ee[i][0] = D[i]+3*ds[i]; } for(int j=0;j<17;j++) {for(int i=1;i<=n;i++){ if(db[i][j] == -1){ db[i][j+1] = -1; d[i][j+1] = d[i][j]; ee[i][j+1] = ee[i][j]; } else{ db[i][j+1] = db[db[i][j]][j]; d[i][j+1] = min(d[i][j],d[db[i][j]][j]); ee[i][j+1] = min(ee[i][j],ee[db[i][j]][j]); } }} for(int i=0;i<m;i++){ int s,g; scanf("%d%d",&s,&g); int v = lca(s,g); int dist = calc(s,g); int dist2 = calc(v,g); int L = rd(s,v); int R = re(g,v); L += dist*2-dist2+D[v]; R += dist*2-dist2-D[v]; printf("%d\n",min(dist*2,min(L,R))); } }