1. 程式人生 > 其它 >acwing 1140. 最短網路(prim)

acwing 1140. 最短網路(prim)

P2486 [SDOI2011]染色

分析

我們來根據操作來討論一下,需要維護的值有什麼。

將節點 a 到節點 b 的路徑上的所有點(包括 a 和 b)都染成顏色 c。

很明顯,我們需要維護一下tag,來儲存該區間是否發生了整體被某種顏色覆蓋

這並不困難,我們把眼光放到第二個操作上

詢問節點 a 到節點 b 的路徑上的顏色段數量。

此時,我們很明顯需要維護一個sum,表示該段上不同顏色段的數量。

同時為了維護合併後區間的sum,我們需要維護兩個值lcrc分別表示左端點的顏色和右端點的顏色。

若合併區間時,左區間的右端點顏色 = 右區間的左端點顏色,則該區間的顏色段數量減1

同時需要注意的是,在從一個條鏈跳到另外一條鏈時,可能會發生顏色連續的事情,從而使答案減1

具體解決方案就是,可以在全域性開一個LcRc變數,用來記此時該條鏈的左端點顏色和右端點顏色。

同時維護兩個變數ans1ans2,用來分別統計u的上一條鏈的左端點顏色和v的上一條鏈的左端點顏色。

還需要注意的是,當top[u]==top[v],即u,v在同一條鏈時,因為此時區間的兩個端點分別為u,v,需要分別對u,v的上一條鏈的左端點顏色進行對比,若相同則減1

話不多說,直接看程式碼

AC_code

#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10,M = N*2;
struct Node
{
    int l,r,lc,rc,sum,tag;
}tr[N<<2];
int h[N],e[M],ne[M],w[N],idx;
int sz[N],fa[N],son[N],dep[N];
int id[N],top[N],nw[N],ts;
int n,m,Lc,Rc;

void add(int a,int b)
{
    e[idx] = b,ne[idx] = h[a],h[a] = idx++;
}

void dfs1(int u,int pa,int depth)
{
    fa[u] = pa,sz[u] = 1,dep[u] = depth;
    for(int i=h[u];~i;i=ne[i])
    {
        int j = e[i];
        if(j==pa) continue;
        dfs1(j,u,depth+1);
        if(sz[son[u]]<sz[j]) son[u] = j;
        sz[u] += sz[j];
    }
}

void dfs2(int u,int tp)
{
    id[u] = ++ts,nw[ts] = w[u],top[u] = tp;
    if(!son[u]) return ;
    dfs2(son[u],tp);
    for(int i=h[u];~i;i=ne[i])
    {
        int j = e[i];
        if(j==son[u]||j==fa[u]) continue;
        dfs2(j,j);
    }
}

void pushup(int u)
{
    tr[u].lc = tr[u<<1].lc,tr[u].rc = tr[u<<1|1].rc;
    tr[u].sum = tr[u<<1].sum + tr[u<<1|1].sum;
    if(tr[u<<1].rc==tr[u<<1|1].lc) tr[u].sum--;
}


void change(Node &u,int k)
{
    u.sum = u.tag = 1;
    u.lc = u.rc = k;
}

void pushdown(int u)
{
    if(tr[u].tag)
    {
        change(tr[u<<1],tr[u].lc);
        change(tr[u<<1|1],tr[u].lc);
        tr[u].tag = 0;
    }
}

void build(int u,int l,int r)
{
    if(l==r)
    {
        tr[u] = {l,r,nw[l],nw[l],1,0};
        return ;
    }
    tr[u] = {l,r,nw[l],nw[r],0,0};
    int mid = l + r >> 1;
    build(u<<1,l,mid),build(u<<1|1,mid+1,r);
    pushup(u);
}

void modify(int u,int l,int r,int k)
{
    if(l<=tr[u].l&&tr[u].r<=r)
    {
        change(tr[u],k);
        return ;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if(l<=mid) modify(u<<1,l,r,k);
    if(r>mid) modify(u<<1|1,l,r,k);
    pushup(u);
}

int query(int u,int l,int r)
{
    if(l<=tr[u].l&&tr[u].r<=r) 
    {
        if(tr[u].l==l) Lc = tr[u].lc;
        if(tr[u].r==r) Rc = tr[u].rc;
        return tr[u].sum;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    int res = 0,lc = -1,rc = -1;
    if(l<=mid) res += query(u<<1,l,r),lc = tr[u<<1].rc;
    if(r>mid) res += query(u<<1|1,l,r),rc = tr[u<<1|1].lc;
    if(lc!=-1&&rc!=-1&&lc==rc) res--;
    return res;
}

int main()
{
    cin>>n>>m;
    memset(h,-1,sizeof h);
    for(int i=1;i<=n;i++) cin>>w[i];
    for(int i=0;i<n-1;i++)
    {
        int a,b;cin>>a>>b;
        add(a,b),add(b,a);
    }
    dfs1(1,-1,1);
    dfs2(1,1);
    build(1,1,n);
    while(m--)
    {
        string op;int a,b,c;
        cin>>op>>a>>b;
        if(op=="C")
        {
            cin>>c;
            while(top[a]!=top[b])
            {
                if(dep[top[a]]<dep[top[b]]) swap(a,b);
                modify(1,id[top[a]],id[a],c);
                a = fa[top[a]];
            }
            if(dep[a]<dep[b]) swap(a,b);
            modify(1,id[b],id[a],c);
        }
        else 
        {
            int res = 0,ans1 = -1,ans2 = -1;
            while(top[a]!=top[b])
            {
                if(dep[top[a]]<dep[top[b]]) swap(a,b),swap(ans1,ans2);
                res += query(1,id[top[a]],id[a]);
                if(Rc==ans1) res--;
                ans1 = Lc;
                a = fa[top[a]];
            }
            if(dep[a]<dep[b]) swap(a,b),swap(ans1,ans2);
            res += query(1,id[b],id[a]);
            if(Lc==ans2) res--;
            if(Rc==ans1) res--;
            cout<<res<<endl;
        }
    }
    return 0;
}