1. 程式人生 > >Codeforces 916E(思維+dfs序+線段樹+LCA)

Codeforces 916E(思維+dfs序+線段樹+LCA)

題面

傳送門
題目大意:給定初始根節點為1的樹,有3種操作
1.把根節點更換為r
2.將包含u,v的節點的最小子樹(即lca(u,v)的子樹)所有節點的值+x
3.查詢v及其子樹的值之和

分析

看到批量修改子樹,我們想到將樹上操作轉化為區間操作
通過DFS序我們可以實現這一點.
對於每個節點x,我們記錄它在前序遍歷中的位置l[x],再一次回到x時的序號r[x],則x及其子樹的區間為前序遍歷中的[l[x],r[x]]
具體可點選這篇部落格

那麼,3種操作如何進行:
操作1.用一個變數root記錄當前根即可,時間複雜度O(1)
以下求LCA,DFS序,子樹,以及修改等操作都在初始的樹上進行

,再想辦法將它轉換為根不是1的情況
操作2.由於根節點變化,需要分類討論
首先,定義三個點的LCA值lca(u,v,w)為lca(u,v),lca(u,w),lca(v,w)中深度最深的那一個
設修改的點為u,v,根節點為root
(1) 若lca(u,v)在root的子樹中
這裡寫圖片描述
顯然結果和根為1的情況一樣,直接修改即可,時間複雜度O(log2n)

(2)若lca(u,v,root)=root
這裡寫圖片描述
很明顯包含u,v的最小子樹就是整棵樹,所以修改整棵樹,時間複雜度

O(log2n)

(3)若root在lca(u,v,root)的子樹中
這裡寫圖片描述
此時可採用類似容斥原理的方法
先將整棵樹的值+x
再找到root的祖先中離lca(u,v,root)最近的整數w,將w及其子樹(綠色部分)的值-x,剩下的就是包含u,v的最小子樹了(黃色部分)
求w可用樹上倍增,時間複雜度O(log2n)

操作3.
類似操作2的分類討論
設查詢的點為u,根節點為root
(1)若u在root的子樹中,則直接查詢u的子樹
(2)若u=root,查詢整棵樹
(3)若root在u的子樹中
這裡寫圖片描述


先查詢整棵樹的值之和,再找root的祖先中距離u最近的一個v
用整棵樹的值之和-v及子樹的值之和(綠色部分)=所求(黃色部分)

程式碼

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define maxn 100005
#define maxlog 32
using namespace std;
inline int qread(){
    int x=0,sign=1;
    char c=getchar();
    while(c<'0'||c>'9'){
        if(c=='-') sign=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9'){
        x=x*10+c-'0';
        c=getchar();
    }
    return x*sign;
}

int n,q;
int a[maxn];
int root=1;

struct edge{
    int from;
    int to;
    int next;
}E[maxn<<1];
int head[maxn];
int size=0;
void add_edge(int u,int v){
    size++;
    E[size].from=u;
    E[size].to=v;
    E[size].next=head[u];
    head[u]=size;
}

int cnt;
int log2n;
int l[maxn],r[maxn];
int deep[maxn],anc[maxn][maxlog];
void dfs(int x,int fa){
    l[x]=++cnt;
    anc[x][0]=fa;
    for(int i=1;i<=log2n;i++){
        anc[x][i]=anc[anc[x][i-1]][i-1];
    }
    for(int i=head[x];i;i=E[i].next){
        int y=E[i].to;
        if(y!=fa){
            deep[y]=deep[x]+1;
            dfs(y,x);
        } 
    }
    r[x]=cnt;
}

int lca(int x,int y){
    if(deep[x]<deep[y]) swap(x,y);
    for(int i=log2n;i>=0;i--){
        if(deep[anc[x][i]]>=deep[y]){
            x=anc[x][i];
        }
    }
    if(x==y) return x;
    for(int i=log2n;i>=0;i--){
        if(anc[x][i]!=anc[y][i]){
            x=anc[x][i];
            y=anc[y][i];
        }
    }
    return anc[x][0];
}

int tri_lca(int u,int v,int r){
    int l1=lca(u,v);
    int l2=lca(u,r);
    int l3=lca(v,r);
    int max_deep=max(deep[l1],max(deep[l2],deep[l3]));
    if(deep[l1]==max_deep) return l1;
    else if(deep[l2]==max_deep) return l2;
    else return l3;
}

int get_close(int w,int r){
    int x=r;
    for(int i=log2n;i>=0;i--){
        if(deep[anc[x][i]]>deep[w]){
            x=anc[x][i];
        }
    }
    return x;
}
struct node{
    int l;
    int r;
    long long v;
    long long mark;
    int len(){
        return r-l+1;
    }
}tree[maxn<<2];
void push_up(int pos){
    tree[pos].v=tree[pos<<1].v+tree[pos<<1|1].v;
}
void build(int l,int r,int pos){
    tree[pos].l=l;
    tree[pos].r=r;
    tree[pos].v=0;
    tree[pos].mark=0;
    if(l==r) return;
    int mid=(l+r)>>1;
    build(l,mid,pos<<1);
    build(mid+1,r,pos<<1|1); 
}
void push_down(int pos){
    if(tree[pos].mark){
        tree[pos<<1].mark+=tree[pos].mark;
        tree[pos<<1|1].mark+=tree[pos].mark;
        tree[pos<<1].v+=tree[pos].mark*tree[pos<<1].len();
        tree[pos<<1|1].v+=tree[pos].mark*tree[pos<<1|1].len();
        tree[pos].mark=0;
    }
}
void update(int L,int R,long long v,int pos){
    if(L<=tree[pos].l&&R>=tree[pos].r){
        tree[pos].mark+=v;
        tree[pos].v+=(v*tree[pos].len());
        return;
    }
    push_down(pos);
    int mid=(tree[pos].l+tree[pos].r)>>1;
    if(L<=mid) update(L,R,v,pos<<1);
    if(R>mid) update(L,R,v,pos<<1|1);
    push_up(pos);
}
long long query(int L,int R,int pos){
    if(L<=tree[pos].l&&R>=tree[pos].r){
        return tree[pos].v;
    }
    push_down(pos);
    int mid=(tree[pos].l+tree[pos].r)>>1;
    long long ans=0;
    if(L<=mid) ans+=query(L,R,pos<<1);
    if(R>mid) ans+=query(L,R,pos<<1|1);
    return ans;
}

void change(int u,int v,int x){
    int xx=lca(u,v);
    int lca_num=tri_lca(u,v,root);
    if(l[root]<l[xx]||r[root]>r[xx]){
        update(l[xx],r[xx],x,1);
        return;
    }else if(lca_num==root){
        update(1,n,x,1);
        return;
    }else{
        int w2=get_close(lca_num,root);
        update(1,n,x,1);
        update(l[w2],r[w2],-x,1);   
        return;
    }
}

long long sum(int w){
    if(l[root]<l[w]||r[root]>r[w]){
        return query(l[w],r[w],1);
    }else{
        if(w==root){
            return query(1,n,1);
        } 
        int sonw=get_close(w,root);
//      printf("%d\n",query(1,n,1));
//      printf("%d\n",query(l[sonw],r[sonw],1));
        return query(1,n,1)-query(l[sonw],r[sonw],1);
    }
}

int main(){
    int u,v,cmd,x;
    n=qread();
    q=qread();
    for(int i=1;i<=n;i++) a[i]=qread();
    for(int i=1;i<n;i++){
        u=qread();
        v=qread();
        add_edge(u,v);
        add_edge(v,u); 
    }
    deep[1]=1;
    log2n=log2(n)+1;
    dfs(1,0);
    build(1,n,1);
    for(int i=1;i<=n;i++){
        update(l[i],l[i],a[i],1);
    }
    for(int i=1;i<=q;i++){
        cmd=qread();
        if(cmd==1){
            v=qread();
            root=v;
        }else if(cmd==2){
            u=qread();
            v=qread();
            x=qread();
            change(u,v,x);
        }else{
            v=qread();
            printf("%I64d\n",sum(v));
        }
//      printf("debug: sum=%d\n",query(1,n,1));
    }
    return 0;
}