NOI2020 命運
阿新 • • 發佈:2021-06-27
題目描述
給定一棵樹 \(T=(V,E)\) 和點對集合 \(\mathcal Q\) ,滿足對於所有 \((u,v)\in \mathcal Q\),都有 \(u\neq v\),並且 \(u\) 是 \(v\) 在樹 \(T\) 上的祖先。
求有多少個不同的函式 \(f:E\to\{0, 1\}\)(將每條邊 \(e\in E\) 的 \(f(e)\) 值置為 \(0\) 或 \(1\)),滿足對於任何 \((u,v)\in\mathcal Q\),都存在 \(u\) 到 \(v\) 路徑上的一條邊 \(e\) 使得 \(f(e)=1\)。
結果對 \(998244353\) 取模。
Solution
考慮 DP。
設 \(f_{u,d}\) 表示節點為 \(u\),沒滿足限制的最大深度為 \(d\) 的方案數。
轉移列舉 \((u,v)\in E\),則
\[f_{u,d}=\sum_{i=0}^{dep_u}{f_{v,i}}f_{u,d}+\sum_{i=0}^{d}{f_{v,i}}f_{u,d}+\sum_{i=0}^{d-1}{f_{u,i}}f_{v,d} \]設 \(s_{u,d}=\sum_{i=0}^{d}f_{u,i}\),則
\[f_{u,d}=f_{u,d}(s_{v,dep_u}+s_{v,d})+f_{v,d}s_{u,d-1} \]線段樹合併優化即可。時間複雜度為 \(O(n\log n)\)
Code
#include<cstdio> #include<vector> using namespace std; const int maxn=1000010,MLY=998244353; template<class T>inline T Max(const T &a,const T &b){return a>b?a:b;} vector<int> mp[maxn]; struct SegmentTree{ int l,r,lc,rc,sum,mul; SegmentTree(){mul=1;} }tr[maxn*10]; int Fir[maxn],Nxt[maxn],Too[maxn],tot; int cnt,dep[maxn],root[maxn],n,m; inline void add(int a,int b){ Too[++tot]=b;Nxt[tot]=Fir[a];Fir[a]=tot; } inline void pushmul(int u,int d){ tr[u].sum=1ll*tr[u].sum*d%MLY; tr[u].mul=1ll*tr[u].mul*d%MLY; } inline void pushdown(int u){ pushmul(tr[u].lc,tr[u].mul);pushmul(tr[u].rc,tr[u].mul); tr[u].mul=1; } inline void pushup(int u){ tr[u].sum=(tr[tr[u].lc].sum+tr[tr[u].rc].sum)%MLY; } inline int modify(int u,int l,int r,int p){ u=++cnt;tr[u].sum=1; if(l==r)return u; int mid=(l+r)>>1; if(p<=mid)tr[u].lc=modify(tr[u].lc,l,mid,p); else tr[u].rc=modify(tr[u].rc,mid+1,r,p); return u; } inline int query(int u,int l,int r,int ql,int qr){ if(!u||ql<=l&&r<=qr)return tr[u].sum; int mid=(l+r)>>1,ans=0; pushdown(u); if(ql<=mid)ans=query(tr[u].lc,l,mid,ql,qr); if(mid<qr)ans=(ans+query(tr[u].rc,mid+1,r,ql,qr))%MLY; return ans; } inline int merge(int p,int q,int l,int r,int &s1,int &s2){ if(!p&&!q)return 0; if(!p){ s1=(s1+tr[q].sum)%MLY; pushmul(q,s2); return q; } if(!q){ s2=(s2+tr[p].sum)%MLY; pushmul(p,s1); return p; } if(l==r){ int tp=tr[p].sum,tq=tr[q].sum; s1=(s1+tq)%MLY; tr[p].sum=(1ll*tr[p].sum*s1+1ll*tr[q].sum*s2)%MLY; s2=(s2+tp)%MLY; return p; } pushdown(p);pushdown(q); int mid=(l+r)>>1; tr[p].lc=merge(tr[p].lc,tr[q].lc,l,mid,s1,s2); tr[p].rc=merge(tr[p].rc,tr[q].rc,mid+1,r,s1,s2); pushup(p); return p; } inline void dfs(int u,int f){ dep[u]=dep[f]+1; int mxdep=0; for(auto t:mp[u])mxdep=Max(mxdep,dep[t]); root[u]=modify(root[u],0,n,mxdep); for(int i=Fir[u];i;i=Nxt[i]){ int v=Too[i]; if(v!=f){ dfs(v,u); int s1=query(root[v],0,n,0,dep[u]),s2=0; root[u]=merge(root[u],root[v],0,n,s1,s2); } } } int main(){ scanf("%d",&n); for(int i=1;i<n;++i){ int x,y; scanf("%d%d",&x,&y); add(x,y);add(y,x); } scanf("%d",&m); for(int i=1;i<=m;++i){ int u,v; scanf("%d%d",&u,&v); mp[v].push_back(u); } dfs(1,1); printf("%d",query(root[1],0,n,0,0)); return 0; }