1. 程式人生 > 其它 >CF324E Xenia and Tree

CF324E Xenia and Tree

有一顆樹,初始$1$號點為紅色,其餘為藍色,有兩種操作:1.修改一個藍點為紅點。2.查詢每個點最近紅點的距離

題目傳送門


題意

有一顆樹,初始\(1\)號點為紅色,其餘為藍色,有兩種操作:1.修改一個藍點為紅點。2.查詢每個點最近紅點的距離


題解

應該是較板題
最近根號寫多了,首先想到的是: 每根號個詢問重構一次,不超過根號就暴力列舉這個點和那根號個點的距離。
發現可能是對的,dp可以在\(O(n)\)複雜度處理出答案,每\(O(\sqrt{n})\) dp一次即可。
考慮需要\(O(1)\)LCA,於是學了一手。
如何dp?對於每個點,先dp它的兒子,然後他的答案就等於它兒子或父親中答案最小的加1.
注意,要呼叫兩遍這個dp。
你可以這樣理解,這是一個混合dp,它實際上分兩步進行:
第一步: 處理出每個節點向下最近的距離,這個可以從下往上dp一遍,
第二步:處理每個點向上經過它父親的最近距離,由於根節點不能向上,所以答案不變,從上往下dp是對的。
我比較懶,寫在一起, 呼叫兩遍。

實現

#include <iostream>
#include <cstdio>
#include <vector>
#include <cmath>
using namespace std;

int read(){
    int num=0, flag=1; char c=getchar();
    while(!isdigit(c) && c!='-') c=getchar();
    if(c == '-') c=getchar(), flag=-1;
    while(isdigit(c)) num=num*10+c-'0', c=getchar();
    return num*flag;
}

int min(int a, int b){return a<b?a:b;}

const int N = 2e6+100;
const int M = 21;
const int inf = 0x3f3f3f3f;
int n, m, sqr, fa[N], dep[N], lg[N], dfn[N], f[N][M], tot=0;
int col[N], ans[N];
vector<int> p[N];
vector<int> op;

void dfs(int x){
    dfn[x] = ++tot, dep[x] = dep[fa[x]] + 1;
    f[tot][0] = x;
    for(auto i : p[x]){
        if(i == fa[x]) continue;
        fa[i] = x;
        dfs(i);
        f[++tot][0] = x;
    }
}

int mindep(int x, int y){
    return dep[x]<dep[y]?x:y;
}

void pre(){
    for(int i=2; i<N; i++) lg[i] = lg[i>>1] + 1;
    for(int i=1; i<M; i++){
        for(int j=1; j<=tot; j++){
            f[j][i] = mindep(f[j][i-1], f[j+(1<<(i-1))][i-1]);
        }
    }
}

int rmq(int l, int r){
	if(l > r) swap(l, r); 
    return mindep(f[l][lg[r-l+1]], f[r-(1<<lg[r-l+1])+1][lg[r-l+1]]);
}

int lca(int x, int y){
    return rmq(dfn[x], dfn[y]);
}

int getDist(int x, int y){
    return dep[x]+dep[y] - 2*dep[lca(x, y)];
}

void solve(int x){
    if(col[x]) ans[x]=0;
    ans[x] = min(ans[x], ans[fa[x]]+1);
    for(auto i : p[x]){
        if(i == fa[x]) continue;
        solve(i);
        ans[x] = min(ans[x], ans[i]+1);
    }
}

int main(){
    n=read(), m=read(), sqr=sqrt(m); 
    for(int i=1; i<=n; i++) ans[i]=inf;
    for(int i=1; i<n; i++){
        int u=read(), v=read();
        p[u].push_back(v), p[v].push_back(u);
    }
    dfs(1); pre();

    op.push_back(1);
    col[1] = 1;
    while(m--){
        int type=read(), x=read();
        if(type == 1){
            op.push_back(x);
            col[x] = 1;
        }else{
            if(op.size() > sqr){
                solve(1);
                solve(1); 
                op.clear();
            }

            for(auto i : op){
                ans[x] = min(ans[x], getDist(x, i));
            }

            printf("%d\n", ans[x]);
        }
    }
    return 0;
}