codeforces1467E Distinctive Roots in a Tree
技術標籤:資料結構
題目:
給定一棵
n
n
n個結點的樹,每個點有一個權值
a
i
a_i
ai,問有多少個點
u
u
u滿足從
u
u
u出發到其他所有點的路徑中都不存在兩個權值相同的點。
(
1
≤
n
≤
2
×
1
0
5
,
1
≤
a
i
≤
1
0
9
)
(1 \le n \le 2 \times10^5,1 \le a_i \le 10^9)
(1≤n≤2×105,1≤ai≤109)
題解:
一開始拿到這個題我想的是用換根
d
p
dp
dp,但發現有一些資訊維護不了。
這道題應該從性質入手,先隨便確定一個根,考慮有兩個權值相同的點的情況,設為
u
,
v
u,v
(1)
u
,
v
u,v
u,v存在祖先關係
不妨設
u
u
u是
v
v
v的祖先,那麼顯然可能滿足題意的點只能是在
v
v
v所在的那棵以
u
u
u的某個子節點為根的子樹中的而又不在以
v
v
v為根的子樹中的點。覺得不清楚可以看圖:
(2)
u
,
v
u,v
u,v不存在祖先關係
那麼可以確定在以
u
u
u為根的子樹中和以
v
v
v為根節點的子樹中的點是不滿足題意的。
那麼上述的判法將所有不滿足題意的點都判掉了嗎?答案是肯定的,因為如果一個點不滿足題意,那麼這個點到其他所有結點的路徑中一定會出現至少兩者之一的情況,所以所有不合法的點都會被判掉。
正確性得到證明以後,來考慮具體的演算法實現。我們可以用打01標記的方式來維護點的合法性,由於有多次更新操作和單次查詢操作,所以可以用樹上差分來維護。我們用
f
u
=
0
/
1
f_u=0/1
複雜度: O ( n l o g n ) O(nlogn) O(nlogn)
程式碼:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<string>
#include<bitset>
#include<sstream>
#include<ctime>
//#include<chrono>
//#include<random>
//#include<unordered_map>
using namespace std;
#define ll long long
#define ls o<<1
#define rs o<<1|1
#define pii pair<int,int>
#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define sz(x) (int)(x).size()
#define all(x) (x).begin(),(x).end()
const double pi=acos(-1.0);
const double eps=1e-6;
const int mod=1e9+7;
const int INF=0x3f3f3f3f;
const int maxn=2e5+5;
ll read(){
ll x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int n,ans;
vector<int>g[maxn];
int f[maxn],a[maxn];
map<int,int>num,sta;
void dfs(int u,int fa){
int base=sta[a[u]];
sta[a[u]]++;
for(auto v:g[u]){
if(v==fa)continue;
int tmp=sta[a[u]];
dfs(v,u);
if(tmp!=sta[a[u]]){
f[1]++;
f[v]--;
}
}
int diff=sta[a[u]]-base;
if(diff!=num[a[u]]){
f[u]++;
}
}
void solve(int u,int fa,int x){
if(!x){
++ans;
}
for(auto v:g[u]){
if(v==fa)continue;
solve(v,u,x+f[v]);
}
}
int main(void){
// freopen("in.txt","r",stdin);
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
num[a[i]]++;
}
int u,v;
for(int i=1;i<=n-1;i++){
scanf("%d%d",&u,&v);
g[u].pb(v);
g[v].pb(u);
}
dfs(1,0);
ans=0;
solve(1,0,f[1]);
printf("%d\n",ans);
return 0;
}