1. 程式人生 > >[BZOJ3451]normal 點分治,NTT

[BZOJ3451]normal 點分治,NTT

mem using mit gist 怎樣 read dac 另一個 pro

[BZOJ3451]normal 點分治,NTT

好久沒更博了,咕咕咕。

BZOJ3451權限題,上darkbzoj交吧。

一句話題意,求隨機點分治的期望復雜度。

考慮計算每個點對的貢獻:如果一個點在點分樹上是另一個點的祖先,那麽這個點對另一個點的貢獻就是1,這樣的話,這個點就必須是這兩個點之間的鏈上的點中在點分樹上深度最淺的點,由於鏈上每個點成為點分樹上最淺的點的概率都是相等的,所以這個點對對最終的期望的貢獻就是\(\frac{1}{dis(i, j) + 1}\),這裏的\(dis(i, j)\)習慣上認為是邊的條數,\(+1\)就變成了點的個數。現在我們要求的就是\(\sum \limits _{i = 1} ^{n} \sum \limits _{j = 1} ^{n} \frac {1} {dis(i, j) + 1}\)

考慮怎樣在樹上統計,明顯需要點分治,每次統計從當前分治中心出發的所有長度的路徑條數,跟先前統計過的子樹合並一下就好了。觀察合並的式子,設\(f[i]\)表示這次合並統計的長度為\(i\)的路徑條數,\(p[i]\)表示以前統計過的子樹中長度為\(i\)的路徑條數,\(q[i]\)表示這次統計的長度為\(i\)的路徑條數,顯然有\(f[k] = \sum \limits _{i + j = k} p[i] * q[j]\),明顯是一個卷積的形式,可以FFT,但是發現\(f\)數組的值肯定不會超過NTT模數,於是直接NTT。時間復雜度\(O(n (\log n) ^ 2\)

實現的時候要註意,每次從分支中心出發統計時,必須先把子樹按深度或大小排一遍序,從小往大處理,不然一個掃把圖就能把你卡成\(n ^ 2 \log n\)

,可以自己畫圖手玩。

#include <cstdio>
#include <cctype>
#include <vector>
#include <cstring>
#include <algorithm>
#define R register
#define I inline
#define D double
#define L long long
#define B 1000000
using namespace std;
const int N = 32777, P = 998244353, G = 3, H = 332748118;
char buf[B], *p1, *p2;
I char gc() { return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, B, stdin), p1 == p2) ? EOF : *p1++; }
I int rd() {
    R int f = 0;
    R char c = gc();
    while (c < 48 || c > 57)
        c = gc();
    while (c > 47 && c < 58)
        f = f * 10 + (c ^ 48), c = gc();
    return f;
}
int s[N], t[N], v[N], p[N], q[N], a[N], b[N], d[N], c[N], h[N], u, r, S, e, F, M;
D o;
vector <int> g[N];
I int max(int x, int y) { return x > y ? x : y; }
I void swp(int &x, int &y) { x ^= y, y ^= x, x ^= y; }
I int cmp(int x, int y) { return t[x] < t[y]; }
void gsz(int x, int f) {
    t[x] = 1;
    for (R int i = 0, y; i < s[x]; ++i)
        if (!v[y = g[x][i]] && y ^ f)
            gsz(y, x), t[x] += t[y];
}
void grt(int x, int f, int a) {
    R int m = 0, i, y;
    for (i = 0; i < s[x]; ++i)
        if (!v[y = g[x][i]] && y ^ f)
            m = max(m, t[y]), grt(y, x, a);
    m = max(m, a - t[x]);
    if (m < u)
        u = m, r = x;
}
void dfs(int x, int f, int d) {
    c[++e] = d;
    for (R int i = 0, y; i < s[x]; ++i)
        if (!v[y = g[x][i]] && y ^ f)
            dfs(y, x, d + 1);
}
I L pwr(L a, L b) {
    L r = 1;
    for (; b; b >>= 1, a = a * a % P)
        if (b & 1)
            r = r * a % P;
    return r;
}
void ntt(int *f, int v) {
    R int i, j, k, t;
    L p, q, o;
    for (i = 0; i < M; ++i)
        if (i < d[i])
            swp(f[i], f[d[i]]);
    for (i = 1; i < M; i <<= 1) {
        t = i << 1, p = pwr(v ? G : H, (P - 1) / t);
        for (j = 0; j < M; j += t)
            for (q = 1, k = 0; k < i; ++k)
                o = q * f[i + j + k] % P, f[i + j + k] = (f[j + k] - o + P) % P, f[j + k] = (f[j + k] + o + P) % P, q = q * p % P;
    }
}
void dac(int x) {
    R int i, j, y, z;
    p[0] = 1, h[0] = 0, u = S, gsz(x, 0), grt(x, 0, t[x]), v[r] = 1, sort(&g[r][0], s[r] + &g[r][0], cmp);
    for (i = 0; i < s[r]; ++i)
        if (!v[y = g[r][i]]) {
            for (M = 1; M <= t[x]; M <<= 1) {}
            e = 0, dfs(y, r, 1), F = pwr(M, P - 2), d[0] = 0;
            for (z = M >> 1, j = 0; j < M; ++j)
                d[j] = (d[j >> 1] >> 1)|((j & 1) ? z : 0);
            for (j = 1; j <= e; ++j)
                ++q[c[j]], h[++h[0]] = c[j];
            memcpy(a, p, M * 4), memcpy(b, q, M * 4);
            ntt(a, 1), ntt(b, 1);
            for (j = 0; j < M; ++j)
                a[j] = 1ll * a[j] * b[j] % P;
            ntt(a, 0);
            for (j = 0; j < M; ++j)
                p[j] += q[j], a[j] = 1ll * a[j] * F * 2 % P, o += (D)a[j] / (j + 1);
            for (j = 1; j <= e; ++j)
                q[c[j]] = 0;
        }
    for (i = 1; i <= h[0]; ++i)
        p[h[i]] = 0;
    for (x = r, i = 0; i < s[x]; ++i)
        if(!v[y = g[x][i]])
            dac(y);
}
int main() {
    R int n = rd(), i, x, y;
    for (S = 1; S <= n; S <<= 1) {}
    for (i = 1; i < n; ++i)
        x = rd() + 1, y = rd() + 1, g[x].push_back(y), g[y].push_back(x);
    for (i = 1; i <= n; ++i)
        s[i] = g[i].size();
    o = n, dac(1), printf("%.4lf", o);
    return 0;
}

碼風也變了。。。

[BZOJ3451]normal 點分治,NTT