codeforces 543d Road Improvement 樹形dp (★ )
阿新 • • 發佈:2019-01-02
題意:
在一棵根節點為1的樹上,一開始所有的路都是壞的,現在你可以修路。
讓你找出以每個節點為首都,到達其他任意一個節點所經過的壞路不超過1條的方案數。
思路:
參考別人的。參考完之後又感覺思想狠簡單。。。
dp1[i]:以i為根的子樹,滿足到達這棵子樹的任意節點的壞路不超過1條的方案數。
dp2[i]:從該點出發,往父親方向的滿足要求的方案數。
首先先看dp1。假設當前節點為u,v為其子節點,則u~v之間的路只有修和不修兩種。
修:方案數有dp1[v]; 不修:方案數只有1,下面的路都是要修的。
則 dp1[u] = (dp1[v1]+1)*(dp1[v2]+1)*……
轉換根節點,現在看dp2,假設要求節點dp2[v1],u為v的父親節點。(dp2[1] = 1)
v1的方案數來自父親節點往上的方案數還有其他兄弟的方案數。
則dp2[v1] = (dp2[u]*(dp1[v2]+1)*(dp1[v3]+1)*……)+1;(加1的原因跟上面一樣)
如果不能理解,就畫個圖感受一下。
1.有的同學可能會想到用乘法逆元來求其他兄弟的方案數,實際上這會WA10;
2.求出每個節點的字首積和字尾積即可避免除法(逆元),如果每次都重新求兄弟方案數積TLE39.
#include <bits/stdc++.h> using namespace std; const int N = 2e5+5; const int MOD = 1e9+7; typedef long long LL; int n; int head[N], cnt = 0; LL dp1[N], dp2[N]; int f[N]; vector <LL> front[N]; vector <LL> back[N]; vector <LL> tt[N]; struct Edge { int v, next; }e[N<<1]; void addEdge(int u, int v) { e[cnt] = (Edge){v, head[u]}; head[u] = cnt++; } inline void cal(LL &a, LL b) { a = a*b%MOD; } void dfs(int u, int par) { dp1[u] = 1; f[u] = par; int cnt = 0; for(int i = head[u];i != -1; i = e[i].next) { if(e[i].v == par) continue; int v = e[i].v; dfs(v, u); cal(dp1[u], dp1[v]+1); tt[u].push_back(dp1[v]+1); cnt++; } LL tmp = 1, tmp2 = 1; for(int i = 0, j = cnt-1;i < cnt; i++, j--) { cal(tmp, tt[u][i]); cal(tmp2, tt[u][j]); front[u].push_back(tmp); back[u].push_back(tmp2); } /* cout<<"u = "<<u<<endl; for(int i = 0;i < cnt; i++) { cout<<front[u][i]<<" "; } cout<<endl; for(int i = 0;i < cnt; i++) { cout<<back[u][i]<<" "; } cout<<endl; */ } /* LL just(int v, int u) { LL ret = 1; for(int i = head[u];i != -1; i = e[i].next) { if(e[i].v == v || e[i].v == f[u]) continue; int tv = e[i].v; cal(ret, dp1[tv]+1); } cal(ret, dp2[u]); ret++; return ret; } */ LL just(int v, int u, int idx) { LL ret = 1; if(front[u].size() > 0) { if(idx > 0) cal(ret, front[u][idx-1]); if(idx < (back[u].size()-1)) { int real = back[u].size()-idx-2; cal(ret, back[u][real]); } } cal(ret, dp2[u]); ret++; return ret; } void dfs2(int u, int par) { int idx = 0; for(int i = head[u];i != -1; i = e[i].next) { if(e[i].v == par) continue; int v = e[i].v; dp2[v] = just(v, u, idx++); dfs2(v, u); } } void solve() { dfs(1, -1); dp2[1] = 1; dfs2(1, -1); for(int i = 1;i <= n; i++) printf("%I64d%c", dp2[i]*dp1[i]%MOD, i == n?'\n':' '); } int main() { scanf("%d", &n); memset(head, -1, sizeof(head)); for(int i = 2;i <= n; i++) { int v; scanf("%d", &v); addEdge(v, i); addEdge(i, v); } solve(); return 0; }