1. 程式人生 > >「九省聯考 2018」秘密襲擊

「九省聯考 2018」秘密襲擊

return opera cout 聯通塊 優化 per clear 復合 inline

「九省聯考 2018」秘密襲擊

解題思路

\(a_i\) 為樹上聯通塊第 \(k\) 大大於等於 \(i\) 的個數,那麽答案就是
\[ \sum _{i=1}^Wi(a_i-a_{i+1})=\sum_{i=1}^W a_i \]
\(dp[u][i][j]\) 表示以 \(u\) 為根的聯通子樹,大於等於 \(i\) 的點有 \(j\) 個的方案數,把最後一維寫成生成函數的形式
\[ f(u,i)=\sum dp[u][i][j]x^j \]
轉移也可以用生成函數的形式表示
\[ f(u,i)=\prod_{u\rightarrow v}(f(v,i)+1)\times \begin{cases} x&d_u \geq i\\1&otherwise\end{cases} \]


用三模數 \(\text{NTT}\) 轉移可以 \(\mathcal O(n^3logn)\) ,後面一維子樹大小有關點分優化一下可以 \(\mathcal O(n^2log^2n)\) ,甚至不如暴力,其實標算也跑不過暴力

這種時候就可以往點值這方向考慮,這裏的操作點值是直接相乘,所以用線段樹合並維護非常方便,另外還需要記一下所有 \(f(u,i)\) 和的多項式的點值。

我們將 \(n+1\) 個點帶進去算出其在每一個 \(\sum_u f(u,i)\) 下的點值,用線段樹合並維護的話需要的操作有:

維護兩個 \((f,g)\) 表示點值以及子樹點值和

  1. \((f,g)\rightarrow(f+1,g)\)
  2. \((f,g)\rightarrow (f,g+f)\)
  3. \((f,g)\rightarrow(f*x,g)\)
  4. \((f,g)\rightarrow(f,g+x)\)

這些操作都可以寫作函數 \(tr(a,b,c,d)\) 的形式,表示 \((f,g)\rightarrow (af+b,g+cf+d)\)

這個函數滿足結合律並且是封閉的,推一下復合就可以合並標記了,然後線段樹合並的時候因為要支持下傳標記,所以要在某個點沒有左右兒子的時候將已經確定的點值轉移到另外一個點上,不然復雜度會掛。

聽說這個套路叫整體 \(\text{DP}\) ,我不太會這一套理論QwQ。

最後需要用拉格朗日插值把系數全部求出來,考慮拉格朗日插值最基本的式子

\[ F(x)=\sum_i y_i\prod_{j\neq i}\dfrac{x-x_j}{x_i-x_j} \]
\(w =\prod (x-x_i)\),就能得到
\[ F(x)=\sum_i y_i\dfrac{w}{x-x_i}\prod_{j\neq i}\dfrac{1}{x_i-x_j} \]
先求出 \(w\) 的每一項的系數,然後每次把 \((x-x_i)\) 除掉,退位維護一下系數乘上其它常數加到答案的多項式上就可以了,這東西也可以分治 \(\text{NTT}\) 優化,不過這題沒有啥必要。

一通操作下來整道題復雜度是 \(\mathcal O (n^2logn)\) ,跑的大概是暴力的 \(10\) 倍,有妹子的 \(\text{txc}\) 天上第二。


code

/*program by mangoyang*/ 
#include<bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int ch = 0, f = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
const int N = 5005, mod = 64123;
struct Node{ 
    int a, b, c, d;
    inline void init(){ a = 1, b = c = d = 0; }
    Node operator * (const Node & A) const{
        return (Node){
            (int) (1ll * A.a * a % mod),
            (int) ((1ll * A.a * b + A.b) % mod),
            (int) ((1ll * A.c * a + c) % mod),
            (int) ((1ll * A.c * b + A.d + d) % mod) 
        };  
    }
};
vector<int> g[N];
int Y[N], d[N], rt[N], n, W, k;
namespace Seg{
    #define mid ((l + r) >> 1)
    Node tag[N*32];
    int lc[N*32], rc[N*32], st[N*32], top, size;
    inline int newnode(){ 
        return ++size, lc[size] = rc[size] = 0, tag[size].init(), size; 
    }
    inline void clear(int x){
        lc[x] = rc[x] = 0, tag[x].init(), st[++top] = x;
    }
    inline void pushdown(int u){
        if(!lc[u]) lc[u] = newnode();
        if(!rc[u]) rc[u] = newnode();
        tag[lc[u]] = tag[lc[u]] * tag[u];
        tag[rc[u]] = tag[rc[u]] * tag[u], tag[u].init();
    }
    inline void change(int &u, int l, int r, int L, int R, Node x){
        if(!u) u = newnode();
        if(l >= L && r <= R) return (void) (tag[u] = tag[u] * x);
        pushdown(u);
        if(L <= mid) change(lc[u], l, mid, L, R, x);
        if(mid < R) change(rc[u], mid + 1, r, L, R, x);
    }
    inline int merge(int x, int y){
        if(!x || !y) return x + y;
        if(!lc[x] && !rc[x]) swap(x, y);
        if(!lc[y] && !rc[y])
            tag[x] = tag[x] * (Node){tag[y].b, 0, 0, tag[y].d};
        else{
            pushdown(x), pushdown(y);
            lc[x] = merge(lc[x], lc[y]);
            rc[x] = merge(rc[x], rc[y]);
        }
        return x;
    }
    inline void getnode(int u, int l, int r, int x){
        if(!u) return;
        if(l == r) return (void) ((Y[x] += tag[u].d) %= mod);
        pushdown(u);
        getnode(lc[u], l, mid, x), getnode(rc[u], mid + 1, r, x);
    }
}
inline void dfs(int u, int fa, int x){
    Seg::change(rt[u], 1, W, 1, W, (Node){0, 1, 0, 0});
    for(int i = 0; i < (int) g[u].size(); i++){
        int v = g[u][i];
        if(v == fa) continue;
        dfs(v, u, x), rt[u] = Seg::merge(rt[u], rt[v]), rt[v] = 0;
    }
    if(d[u]) Seg::change(rt[u], 1, W, 1, d[u], (Node){x, 0, 0, 0});
    Seg::change(rt[u], 1, W, 1, W, (Node){1, 0, 1, 0});
    Seg::change(rt[u], 1, W, 1, W, (Node){1, 1, 0, 0});
}
inline int Pow(int a, int b){
    int ans = 1;
    for(; b; b >>= 1, a = 1ll * a * a % mod)
        if(b & 1) ans = 1ll * ans * a % mod;
    return ans;
}
inline void dec(int *a, int *b, int x){
    static int tmp[N];
    for(int i = 0; i <= n + 1; i++) tmp[i] = a[i];
    for(int i = n + 1; i >= 1; i--){
        b[i-1] = tmp[i];
        (tmp[i-1] += 1ll * x * tmp[i] % mod) %= mod;
    }
}
inline int Lagrange() {
    static int G[N], F[N], inv[N], ans; G[0] = 1;
    for(int i = 1; i <= n; i++) inv[i] = Pow(i, mod - 2);
    for(int i = n + 1; i >= 1;--i)
        for(int j = n + 1; j >= 0; j--){
            G[j] = 1ll * (mod - i) * G[j] % mod;
            if(j) (G[j] += G[j-1]) %= mod;
    }
    for(int i = 1; i <= n + 1; i++){
        dec(G, F, i); int res = 0;
        for(int j = k; j <= n; j++) (res += F[j]) %= mod;
        for(int j = 1; j <= n + 1; j++) if(i != j) {
            if(j < i) res = 1ll * res * inv[i-j] % mod;
            else res = 1ll * res * (mod - inv[j-i]) % mod;
        }   
        res = 1ll * res * Y[i] % mod; (ans += res) %= mod;
    }
    return ans;
}
int main(){
    read(n), read(k), read(W);
    for(int i = 1; i <= n; i++) read(d[i]);
    for(int i = 1, x, y; i < n; i++){
        read(x), read(y);
        g[x].push_back(y), g[y].push_back(x);
    }
    for(int i = 1; i <= n + 1; i++){
        dfs(1, 0, i);
        Seg::getnode(rt[1], 1, W, i);
        rt[1] = Seg::size = 0;
    }
    cout << Lagrange() << endl;
    return 0;
}

「九省聯考 2018」秘密襲擊