1. 程式人生 > >樹上游戲

樹上游戲

題目描述

lrb有一棵樹,樹的每個節點有個顏色。給一個長度為n的顏色序列,定義s(i,j) 為i 到j 的顏色數量

\(Sum_i = \sum_{i=1}^{n}{s_{i,j}}\)

現在他想讓你求出所有的sum[i]

輸入輸出格式

輸入格式:

第一行為一個整數n,表示樹節點的數量

第二行為n個整數,分別表示n個節點的顏色c[1],c[2]……c[n]

接下來n-1行,每行為兩個整數x,y,表示x和y之間有一條邊

輸出格式:

輸出n行,第i行為sum[i]

輸入輸出樣例

輸入樣例#1:

5
1 2 3 2 3
1 2
2 3
2 4
1 5

輸出樣例#1:

10
9
11
9
12

說明

sum[1]=s(1,1)+s(1,2)+s(1,3)+s(1,4)+s(1,5)=1+2+3+2+2=10
sum[2]=s(2,1)+s(2,2)+s(2,3)+s(2,4)+s(2,5)=2+1+2+1+3=9
sum[3]=s(3,1)+s(3,2)+s(3,3)+s(3,4)+s(3,5)=3+2+1+2+3=11
sum[4]=s(4,1)+s(4,2)+s(4,3)+s(4,4)+s(4,5)=2+1+2+1+3=9
sum[5]=s(5,1)+s(5,2)+s(5,3)+s(5,4)+s(5,5)=2+3+3+3+1=12
對於40%的資料,n<=2000

對於100%的資料,1<=n,c[i]<=10^5


題解

點分治

一堆細節

要算每兩個點對之間顏色種類的個數

\(O(n^2)\)很好做,只需要列舉每個點當根,當某種顏色第一次出現時對根的答案貢獻+Size[u]

考慮怎麼優化這個過程

可以使用點分治來處理每個點對答案的貢獻來做到不重不漏

因為點分治每次處理的是經過分治重心的路徑的貢獻

所以我們先處理出從分治重心開始到各個子樹的每個顏色的路徑條數\(val[col[u]]\)

然後對於每個子樹

先減掉該子樹對\(val[]\)的貢獻

然後計算出sum_val\(= \sum_{i=1}^{i<=colnum}{val[i]}\)

對於從u到根(不包含根)的路徑上的出現過的顏色,把他們從sum_val中減掉

再加上當前分治塊的大小減掉當前分治子樹的大小

表示這種顏色對答案的貢獻是跨過分治重心的所有路徑

最後清空一下陣列就好辣

程式碼

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
# define int long long
const int M = 100005 ;
const int INF = 1e9 + 7 ;
using namespace std ;
inline int read() {
    char c = getchar() ; int x = 0 , w = 1 ;
    while(c>'9'||c<'0') { if(c=='-') w = -1 ; c = getchar() ; }
    while(c>='0'&&c<='9') { x = x*10+c-'0' ; c = getchar() ; }
    return x*w ;
}

bool vis[M] , hap[M] ;
int n , col[M] , hea[M] , num ;
int tot , rt , tmin , size[M] ;
int appear[M] , val[M] ;
int Sum[M] , Bsz , sum_val ;
int colnum , tcl[M] , Tag ;


struct E { int Nxt , to ; } edge[M << 1] ;
inline void add_edge(int from , int to) { 
    edge[++num].Nxt = hea[from] ; 
    edge[num].to = to ;hea[from] = num ; 
}
void Getroot(int u , int father) {
    size[u] = 1 ; int Mx = -1 ; 
    for(int i = hea[u] ; i ; i = edge[i].Nxt) {
        int v = edge[i].to ; if(v == father || vis[v]) continue ;
        Getroot(v , u) ; size[u] += size[v] ; Mx = max(Mx , size[v]) ;
    }
    Mx = max(Mx , tot - size[u]) ; if(Mx < tmin) tmin = Mx , rt = u ;
}
void FirDfs(int u , int father) {
    if(!hap[col[u]]) {
        tcl[++colnum] = col[u] ;
        hap[col[u]] = true ;
    }
    size[u] = 1 ; appear[col[u]] ++ ;
    for(int i = hea[u] ; i ; i = edge[i].Nxt) {
        int v = edge[i].to ; if(v == father || vis[v]) continue ;
        FirDfs(v , u) ; size[u] += size[v] ;
    }
    appear[col[u]] -- ;
    if(!appear[col[u]])  val[col[u]] += size[u] ;
}
void Update(int u , int father , int dlt) {
    appear[col[u]] ++ ;
    for(int i = hea[u] ; i ; i = edge[i].Nxt) {
        int v = edge[i].to ; if(v == father || vis[v]) continue ;
        Update(v , u , dlt) ; 
    }
    appear[col[u]] -- ; if(!appear[col[u]]) val[col[u]] += size[u] * dlt ;
}
void GetAns(int u , int father) {
    if(!appear[col[u]]) {
        Tag -= val[col[u]] ;
        Tag += Bsz ;
    }
    appear[col[u]] ++ ;
    Sum[u] += sum_val + Tag ;
    for(int i = hea[u] ; i ; i = edge[i].Nxt) {
        int v = edge[i].to ; if(vis[v] || v == father) continue ;
        GetAns(v , u) ;
    }
    appear[col[u]] -- ;
    if(!appear[col[u]]) {
        Tag += val[col[u]] ;
        Tag -= Bsz ;
    }
}
void Dfs(int u) {
    FirDfs(u , u) ; vis[u] = true ;
    sum_val = 0 ;
    for(int i = 1 ; i <= colnum ; i ++) 
        sum_val += val[tcl[i]] ;
    Sum[u] += sum_val ;
    for(int i = hea[u] ; i ; i = edge[i].Nxt) {
        int v = edge[i].to ; if(vis[v]) continue ;
        appear[col[u]] = 1 ; val[col[u]] -= size[v] ;
        Update(v , u , -1) ; 
        Bsz = size[u] - size[v] ; sum_val = 0 ; 
        for(int j = 1 ; j <= colnum ; j ++) sum_val += val[tcl[j]] ;
        appear[col[u]] = 0 ; Tag = 0 ;
        GetAns(v , u) ; 
        val[col[u]] += size[v] ;
        appear[col[u]] = 1 ;
        Update(v , u , 1) ;
        appear[col[u]] = 0 ;
    }
    for(int i = 1 ; i <= colnum ; i ++) val[tcl[i]] = 0 , hap[tcl[i]] = false ;
    colnum = 0 ;
    for(int i = hea[u] ; i ; i = edge[i].Nxt) {
        int v = edge[i].to ; if(vis[v]) continue ;
        tot = size[u] ; tmin = INF ;
        Getroot(v , u) ; Dfs(rt) ;
    }
}
# undef int
int main() {
# define int long long
    n = read() ;
    for(int i = 1 ; i <= n ; i ++) col[i] = read() ;
    for(int i = 1 , u , v ; i < n ; i ++) {
        u = read() , v = read() ;
        add_edge(u , v) ; add_edge(v , u) ;
    }
    tot = n , tmin = INF ;
    Getroot(1 , 1) ; Dfs(rt) ;
    for(int i = 1 ; i <= n ; i ++) printf("%lld\n",Sum[i]) ;
    return 0 ;
}