Educational Codeforces Round 129 F. Unique Occurrences
阿新 • • 發佈:2022-05-25
傳送門
\(\texttt{difficulty:2300}\)
題目大意
一棵 \(n(2\le n\le 5\cdot 10^5)\) 個節點的樹,每條邊有權值 \(x(1\le x\le n)\) , \(f(u,v)\) 表示從 \(u\) 到 \(v\) 路徑上所有僅出現過 \(1\) 次的數字的個數,求 \(\sum_{1\le u<v\le n}f(u,v)\) 。
思路
我們考慮每一條邊對於答案的貢獻,其貢獻為所有路徑僅經過改權值的一次的邊的節點對數,我們如果刪掉所有權值為 \(x\) 的邊,那麼原來的樹就會變為若干連通塊,連通塊之間可以通過被刪除的邊重新連成一棵樹,我們記為 \(t_x\)
我們可以通過 \(\texttt{dfs}\) 來解決,先求出以每個節點為根的子樹的大小 \(siz_v\) 。之後再進行一遍 \(\texttt{dfs}\) ,我們對每種權值 \(x\) 分別建立 \(t_x\) ,同時建立 \(n\) 個新的節點作為這些樹的根(不然的話所有的 \(t_x\) 都會共用一個根),並且讓它們的 \(siz\) 都為 \(n\) 。我們可以用所有向上的邊的權值為 \(x\) 的節點表示刪除權值為 \(x\) 的邊後的連通塊,將該節點加入到 \(t_x\) 中,同時我們再對每種權值用棧來記錄其當前的父親節點來進行連邊。
對於連通塊大小的計算,我們在新加入一個節點時,其連通塊的大小就設為 \(siz_v\)
程式碼
#include<bits/stdc++.h> #include<unordered_map> #include<unordered_set> using namespace std; using LL = long long; using ULL = unsigned long long; using PII = pair<LL, LL>; using TP = tuple<int, int, int>; #define all(x) x.begin(),x.end() #define mk make_pair //#define int LL //#define lc p*2 //#define rc p*2+1 #define endl '\n' #define inf 0x3f3f3f3f #define INF 0x3f3f3f3f3f3f3f3f #pragma warning(disable : 4996) #define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0) const double eps = 1e-8; const LL MOD = 1000000007; const LL mod = 998244353; const int maxn = 1000010; struct edge { int to, cost; }; vector<edge>G[maxn]; vector<int>T[maxn]; stack<int>S[maxn]; LL N, siz[maxn], nsiz[maxn], root[maxn], ans = 0; void add_edge(int from, int to, int cost) { G[from].push_back(edge{ to,cost }); G[to].push_back(edge{ from,cost }); } void add_edge(int from, int to) { T[from].push_back(to); T[to].push_back(from); } void dfs1(int v, int p) { siz[v] = 1; for (auto& [to, c] : G[v]) { if (to == p) continue; dfs1(to, v); siz[v] += siz[to]; } } void dfs2(int v, int p, int x) { if (x) { nsiz[v] = siz[v]; add_edge(S[x].top(), v); nsiz[S[x].top()] -= siz[v]; S[x].push(v); } for (auto& [to, c] : G[v]) { if (to == p) continue; dfs2(to, v, c); } if (x) S[x].pop(); } void dfs3(int v, int p) { for (auto&to : T[v]) { if (to == p) continue; ans += nsiz[v] * nsiz[to]; dfs3(to, v); } } void solve() { for (int i = 1; i <= N; i++) root[i] = N + i, S[i].push(N + i), nsiz[root[i]] = N; dfs1(1, 0), dfs2(1, 0, 0); for (int i = 1; i <= N; i++) dfs3(root[i], 0); cout << ans << endl; } int main() { IOS; cin >> N; int u, v, x; for (int i = 1; i < N; i++) cin >> u >> v >> x, add_edge(u, v, x); solve(); return 0; }