[NOI2020]命運 題解
阿新 • • 發佈:2020-08-29
考慮 \(dp\)。
記 \(dp_{x,i}\) 表示：\(x\) 子樹內的所有邊以及 \(x\) 到其父親的邊的狀態（是否選取）都已決定時，目前尚未被滿足的限制中，上端點的最大深度 \(\leq i\) 的方案數（沒有未滿足的限制時視為深度 \(0\)）。
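不難發現轉移時相當於把每個兒子的 \(dp\) 陣列（以及 \(x\) 自身限制對應的初始陣列）對位相乘。之後考慮 \(x\) 到父親的邊：若選取這條邊，所有尚未被滿足的限制都會被滿足，因此要把當前 \(dp\) 陣列的總方案數（即 \(dp_{x,n}\)）加到每一個位置上；若不選取，上端點深度 \(\geq\) 父親深度的未滿足限制再也無法被滿足，需要把這些狀態刪去。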
可以用帶懶標記（tag）的線段樹合併維護。
我的考場程式碼寫的是啟發式合併。
啟發式合併程式碼：
#include <bits/stdc++.h>
using LL = long long;
using namespace std;

// Reads one non-negative decimal integer from stdin, skipping any non-digit
// characters in front of it. `c` is kept as an int so that EOF can be told
// apart from a real character (and isdigit() never receives a negative
// char value); on EOF the skip loop stops instead of spinning forever.
inline int read(){
    int x = 0;
    int c = getchar();
    while (c != EOF && !isdigit(c)) c = getchar();
    while (isdigit(c)) x = x * 10 + (c - '0'), c = getchar();
    return x;
}

const int N = 500050, P = 998244353;

// Forward-star adjacency list: edge ids start at 1, id 0 terminates a list.
int To[N << 1], Ne[N << 1], He[N], _;

// Adds the undirected edge x-y as two directed arcs.
inline void adde(int x, int y){
    ++_; To[_] = y, Ne[_] = He[x], He[x] = _;
    ++_; To[_] = x, Ne[_] = He[y], He[y] = _;
}

// fa[v]: parent of v (0 for the root); dpt[v]: depth of v, dpt[root] = dpt[0] + 1 = 1.
// lim[v]: filled by the caller (main) with the largest dpt[u] over input pairs (u, v).
int n, m, fa[N], dpt[N], lim[N];

// Fills dpt[] for x and fa[]/dpt[] for its whole subtree; fa[x] must already be set.
// NOTE(review): recursion depth equals the tree height (up to n), so this
// relies on a large stack.
inline void dfs(int x){
    dpt[x] = dpt[fa[x]] + 1;
    for (int y, p = He[x]; p; p = Ne[p])
        if ((y = To[p]) ^ fa[x]) fa[y] = x, dfs(y);
}

// ===== Pool of dynamically allocated segment-tree nodes over positions [0, n] =====
// A tree stores a set of leaf positions, each carrying a value. Query/Merge treat
// it as a step function: the value at position i is that of the rightmost stored
// leaf <= i.
// Per node: val = value of the rightmost leaf in the subtree, siz = number of
// stored leaves, mul/add = pending lazy tags meaning "x -> x * mul + add" for
// every value below this node.
const int V = N * 60;

// x = (x + y) mod P, for x, y in [0, P) (2 * P still fits in int).
inline void upd(int &x, int y){ x = (x + y >= P) ? (x + y - P) : (x + y); }

int lc[V], rc[V], val[V], mul[V], add[V], siz[V], cnto;
int stk[V], top;  // free list of recycled node ids; cnto = ids ever handed out

// Returns a reset node, reusing a recycled id when one is available.
inline int New(){
    int o = top ? stk[top--] : ++cnto;
    lc[o] = rc[o] = add[o] = siz[o] = val[o] = 0, mul[o] = 1;
    return o;
}

// Returns the single node x to the free list (its children are NOT freed).
inline void DD(int x){ stk[++top] = x; }

// Returns every node of the subtree rooted at o to the free list.
inline void DDAll(int o){
    if (!o) return;
    DDAll(lc[o]);
    DDAll(rc[o]);
    DD(o);
}

// Lazily multiplies every value in subtree o by v (no-op for the null node).
inline void tmul(int o, int v){
    if (o) mul[o] = (LL)mul[o] * v % P, add[o] = (LL)add[o] * v % P, val[o] = (LL)val[o] * v % P;
}

// Lazily adds v to every value in subtree o (no-op for the null node).
inline void tadd(int o, int v){ if (o) upd(add[o], v), upd(val[o], v); }

// Pushes o's pending tags down to its children (multiply first, then add).
inline void down(int o){
    if (mul[o] ^ 1) tmul(lc[o], mul[o]), tmul(rc[o], mul[o]), mul[o] = 1;
    if (add[o]) tadd(lc[o], add[o]), tadd(rc[o], add[o]), add[o] = 0;
}

// Recomputes o's aggregates from its children.
inline void up(int o){
    val[o] = val[rc[o] ? rc[o] : lc[o]];
    siz[o] = siz[lc[o]] + siz[rc[o]];
}

// Arguments of the recursive operations below are passed through globals.
int pp, vv;

// Removes every stored leaf at a position >= pp from the tree rooted at o and
// returns ALL freed nodes to the pool (previously only the root of a dropped
// subtree was recycled, leaking its descendants). Sets o = 0 if it becomes empty.
inline void Del(int &o, int l, int r){
    if (!o || r < pp) return;
    if (l >= pp){ DDAll(o), o = 0; return; }
    down(o);
    int mid = (l + r) >> 1;
    Del(lc[o], l, mid);
    Del(rc[o], mid + 1, r);
    up(o);
    if (siz[o] == 0) DD(o), o = 0;  // both children are gone at this point
}

// Stores a leaf at position pp with value vv (overwriting an existing one).
inline void Ins(int &o, int l, int r){
    if (!o) o = New();
    if (l == r){ siz[o] = 1, val[o] = vv; return; }
    down(o);
    int mid = (l + r) >> 1;
    if (pp <= mid) Ins(lc[o], l, mid);
    else Ins(rc[o], mid + 1, r);
    up(o);
}

// Multiplies by vv the values of all stored leaves with position in [ll, rr].
int ll, rr;
inline void Mul(int o, int l, int r){
    if (!o) return;
    if (ll <= l && rr >= r){ tmul(o, vv); return; }
    down(o);
    int mid = (l + r) >> 1;
    if (ll <= mid) Mul(lc[o], l, mid);
    if (rr > mid) Mul(rc[o], mid + 1, r);
    up(o);
}

// Adds vv to the values of all stored leaves with position in [ll, rr].
inline void Add(int o, int l, int r){
    if (!o) return;
    if (ll <= l && rr >= r){ tadd(o, vv); return; }
    down(o);
    int mid = (l + r) >> 1;
    if (ll <= mid) Add(lc[o], l, mid);
    if (rr > mid) Add(rc[o], mid + 1, r);
    up(o);
}

// Returns whether a leaf is stored at position pp.
inline bool Is(int o, int l, int r){
    if (!o) return 0;
    if (l == r) return 1;
    down(o);
    int mid = (l + r) >> 1;
    if (pp <= mid) return Is(lc[o], l, mid);
    return Is(rc[o], mid + 1, r);
}

// Finds the rightmost stored leaf with position < pp and writes its value to
// qans. The caller must set qi = -1 first; qi then records the right end of the
// subtree that supplied qans, so subtrees further left are skipped.
int qans, qi;
inline void Query(int o, int l, int r){
    if (!o || qi >= r || l >= pp) return;
    if (r < pp){ qi = r, qans = val[o]; return; }
    down(o);
    int mid = (l + r) >> 1;
    Query(rc[o], mid + 1, r);
    Query(lc[o], l, mid);
}

// Convenience wrappers operating on the tree rooted at rt over [0, n].
inline void radd(int rt, int l, int r, int v){ ll = l, rr = r, vv = v, Add(rt, 0, n); }
inline void rmul(int rt, int l, int r, int v){ ll = l, rr = r, vv = v, Mul(rt, 0, n); }
inline void rins(int &rt, int p, int v){ pp = p, vv = v, Ins(rt, 0, n); }
inline void rdel(int &rt, int p){ pp = p, Del(rt, 0, n); }
inline bool ris(int rt, int p){ pp = p; return Is(rt, 0, n); }

// Appends the stored leaves of subtree o, in increasing position, to
// (ti[], tv[]) via the counter cntt, and returns every node of it to the pool.
int ti[N], tv[N], cntt;
inline void Dfs(int o, int l, int r){
    if (!o) return;
    if (l == r){
        ++cntt;
        ti[cntt] = l, tv[cntt] = val[o];
        DD(o);
        return;
    }
    down(o);
    int mid = (l + r) >> 1;
    Dfs(lc[o], l, mid);
    Dfs(rc[o], mid + 1, r);
    DD(o);
}

// Writes the value of the leftmost stored leaf of subtree o into ans.
int ans;
inline void Ask(int o, int l, int r){
    if (!o) return;
    if (l == r){ ans = val[o]; return; }
    down(o);
    int mid = (l + r) >> 1;
    if (lc[o]) Ask(lc[o], l, mid);
    else Ask(rc[o], mid + 1, r);
}
inline void Dfs2(int o,int l,int r){ if (!o) return; if (l == r){ ++cntt; ti[cntt] = l,tv[cntt] = val[o]; return; } down(o); int mid = l+r>>1; Dfs2(lc[o],l,mid); Dfs2(rc[o],mid+1,r); } inline int Merge(int rt1,int rt2){ if (siz[rt1] < siz[rt2]) swap(rt1,rt2); cntt = 0,Dfs(rt2,0,n); ti[++cntt] = n+1; for (int i = cntt-1; i >= 1; --i){ if (!ris(rt1,ti[i])){ pp = ti[i],qi = -1,Query(rt1,0,n); rins(rt1,ti[i],qans); } rmul(rt1,ti[i],ti[i+1]-1,tv[i]); } return rt1; } int T[N]; inline void dp(int x){ if (lim[x]) rins(T[x],0,0); rins(T[x],lim[x],1); for (int y,p = He[x]; p ; p = Ne[p]) if ((y=To[p])^fa[x]) dp(y),T[x] = Merge(T[x],T[y]); if (x == 1){ Ask(T[1],0,n); return; } radd(T[x],0,n,val[T[x]]); rdel(T[x],dpt[fa[x]]); //cerr << "print " << x <<'\n',Print(T[x]); } int main(){ // freopen("destiny.in","r",stdin); // freopen("destiny.out","w",stdout); int i,x,y; n = read(); for (i = 1; i < n; ++i) x = read(),y = read(),adde(x,y); dfs(1); for (i = 1; i <= n; ++i) lim[i] = 0; m = read(); while (m--) x = read(),y = read(),lim[y] = max(lim[y],dpt[x]); dp(1); cout << ans << '\n'; }