樹的重心
顯然如果我們直接列舉斷哪條邊,剩下的兩顆樹的重心編號非常不好求,一個常見的想法是我們反過來,考慮每個節點有多少種情況作為樹的重心,計算每個點對答案的貢獻。
下面我們就需要大力分類討論了。假設我們現在考慮計算貢獻的節點為 \(x\),令以其兒子為根的子樹樹的最大的子樹大小為 \(mx_x\),次大值為 \(se_x\),以 \(x\) 為根的子樹大小為 \(S_x\),假設 \(x\) 的一個祖先 \(f\) 斷開了其與父親的邊。那麼如果 \(x\) 要作為以 \(f\) 為根的樹的重心,那麼需要滿足下面一個條件:
\[\begin{cases} S_f - S_x \le \lfloor \frac{S_f}{2} \rfloor\\ mx_x \le \lfloor \frac{S_f}{2} \rfloor \end{cases}\\\]
因為我們需要統計的是滿足條件的 \(y\) 的形式,因此我們儘量將原式化簡成關於 \(y\) 的一個範圍。實際上,上面那個柿子能改寫成 \(2 \times S_f - 2 \times S_x \le S_f\),移項可得 \(S_f \le 2 \times S_x\),類似地下面那個柿子有 \(S_f \ge 2 \times mx_x\),於是我們只需要統計一個點 \(x\) 到根的路徑上有多少個點滿足 \(2 \times mx_x \le S_f \le 2 \times S_x\) 即可。這個我們可以使用樹狀陣列來做,進入每個點是將 \(S_x\) 加入樹狀陣列,出去時刪除,那麼遞迴到每個點時的樹狀陣列就是加入 \(x\)
再考慮第二種情況,斷掉 \(x\) 子樹內的一個點 \(y\) 與其父親的連邊的情況。又要分兩種情況,當 \(y\) 在 \(x\) 的重兒子內時,需要滿足:
\[\begin{cases} se_x \le \lfloor \frac{S_1 - S_y}{2} \rfloor\\ mx_x - S_y \le \lfloor \frac{S_1 - S_y}{2} \rfloor\\ S_1 - S_x \le \lfloor \frac{S_1 - S_y}{2} \rfloor \end{cases}\]
同理上面的化簡可得:\(2 \times mx_x - S_1 \le S_y \le \min\{S_1 - 2 \times se_x, 2 \times S_x - S_1\}\)
再考慮如果 \(y\) 在 \(x\) 的非重兒子內部,那麼我們需要滿足:
\[\begin{cases} mx_x \le \lfloor \frac{S_1 - S_y}{2} \rfloor\\ S_1 - S_x \le \lfloor \frac{S_1 - S_y}{2} \rfloor \end{cases}\]
可以得出 \(S_y \le \min\{S_1 - 2 \times mx_x, 2 \times S_x - S_1\}\),同樣是查詢子樹資訊,用線段樹合併解決。
我們再考慮最後一種情況,當不為 \(x\) 的祖先且不在 \(x\) 子樹內的一個點 \(y\) 斷掉了其與父親的連邊時,需要滿足:
\[\begin{cases} S_1 - S_y - S_x \le \lfloor \frac{S_1 - S_y}{2} \rfloor\\ mx_x \le \lfloor \frac{S_1 - S_y}{2} \rfloor \end{cases}\]
化簡可得:\(S_1 - 2 \times S_x \le S_y \le S_1 - 2 \times mx_x\)。
這個時候我被難住了,我們怎麼知道不包含 \(x\) 的祖先和其子樹內的點滿足條件的所有 \(y\) 呢?其實很簡單,上面我們已經知道了怎麼統計在 \(x\) 的子樹和其子樹內滿足條件的方法了,我們直接考慮容斥即可。考慮使用全域性滿足條件的點數減去在 \(x\) 的祖先和其子樹內滿足條件的點數即可,這個全域性滿足條件的點可以使用 \(ton\) 求出。
程式碼細節比較多,想清楚再碼。
#include<bits/stdc++.h>
using namespace std;
#define N 300000 + 5
#define M 6000000 + 5
#define ls t[p].l
#define rs t[p].r
#define mid (l + r) / 2
#define lowbit(x) (x & (-x))
#define rep(i, l, r) for(int i = l; i <= r; ++i)
#define Next(i, u) for(int i = h[u]; i; i = e[i].next)
struct edge{
int v, next;
}e[N << 1];
struct tree{
int l, r, sum;
}t[M];
long long ans;
int T, n, u, v, tot, cnt, h[N], c[N], s[N], rt[N], mx[N], se[N], ton[N];
int read(){
char c; int x = 0, f = 1;
c = getchar();
while(c > '9' || c < '0'){ if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * f;
}
void Add(int u, int v){
e[++cnt].v = v, e[cnt].next = h[u], h[u] = cnt;
e[++cnt].v = u, e[cnt].next = h[v], h[v] = cnt;
}
void add(int p, int k){
for(int i = p; i <= n; i += lowbit(i)) c[i] += k;
}
int ask(int p){
int ans = 0; p = min(p, n);
for(int i = p; i >= 1; i -= lowbit(i)) ans += c[i];
return ans;
}
void update(int &p, int l, int r, int x, int y, int k){
if(!p) p = ++tot; t[p].sum += k;
if(l == r) return;
if(mid >= x) update(ls, l, mid, x, y, k);
if(mid < y) update(rs, mid + 1, r, x, y, k);
}
int query(int p, int l, int r, int x, int y){
if(!p || l > r) return 0;
if(l >= x && r <= y) return t[p].sum;
int ans = 0;
if(mid >= x) ans += query(ls, l, mid, x, y);
if(mid < y) ans += query(rs, mid + 1, r, x, y);
return ans;
}
void Merge(int &p, int k, int l, int r){
if(!p || !k){ p = p + k; return;}
if(l == r){ t[p].sum += t[k].sum; return;}
Merge(ls, t[k].l, l, mid), Merge(rs, t[k].r, mid + 1, r);
t[p].sum = t[ls].sum + t[rs].sum;
}
void dfs1(int u, int fa){
s[u] = 1;
Next(i, u){
int v = e[i].v; if(v == fa) continue;
dfs1(v, u), s[u] += s[v];
if(s[v] >= mx[u]) se[u] = mx[u], mx[u] = s[v];
else if(s[v] > se[u]) se[u] = s[v];
}
}
void dfs2(int u, int fa){
int tmp = ask(2 * s[u]) - ask(2 * mx[u] - 1) - (ask(s[1] - 2 * mx[u]) - ask(s[1] - 2 * s[u] - 1));
if(mx[u] <= s[u] / 2 && u != 1) ++tmp;
if(s[1] >= 2 * mx[u] && s[1] <= 2 * s[u] && u != 1) --tmp;
add(s[u], 1);
Next(i, u){
int v = e[i].v; if(v == fa) continue;
dfs2(v, u), Merge(rt[u], rt[v], 1, n);
if(s[v] == mx[u]) tmp += query(rt[v], 1, n, max(1, 2 * mx[u] - s[1]), min(s[1] - 2 * se[u], 2 * s[u] - s[1]));
else tmp += query(rt[v], 1, n, 1, min(s[1] - 2 * mx[u], 2 * s[u] - s[1]));
}
add(s[u], -1), update(rt[u], 1, n, s[u], s[u], 1);
tmp -= query(rt[u], 1, n, max(1, s[1] - 2 * s[u]), s[1] - 2 * mx[u]);
if(s[1] - 2 * mx[u] >= 1) tmp += ton[s[1] - 2 * mx[u]];
if(s[1] - 2 * s[u] >= 1) tmp -= ton[s[1] - 2 * s[u] - 1];
ans += 1ll * tmp * u;
}
int main(){
T = read();
while(T--){
memset(h, 0, sizeof(h)), memset(rt, 0, sizeof(rt));
memset(s, 0, sizeof(s)), memset(mx, 0, sizeof(mx));
memset(se, 0, sizeof(se)), memset(ton, 0, sizeof(ton));
rep(i, 1, tot) t[i].l = t[i].r = t[i].sum = 0;
tot = cnt = ans = 0;
n = read();
rep(i, 1, n - 1) u = read(), v = read(), Add(u, v);
dfs1(1, 0);
rep(i, 1, n) ++ton[s[i]];
rep(i, 1, n) ton[i] = ton[i - 1] + ton[i];
dfs2(1, 0);
printf("%lld\n", ans);
}
return 0;
}