1. 程式人生 > >codeforces 543d Road Improvement 樹形dp (★ )

codeforces 543d Road Improvement 樹形dp (★ )

題意:

在一棵根節點為1的樹上,一開始所有的路都是壞的,現在你可以修路。

讓你找出以每個節點為首都,到達其他任意一個節點所經過的壞路不超過1條的方案數。

思路:

參考別人的。參考完之後又感覺思想狠簡單。。。

dp1[i]:以i為根的子樹,滿足到達這棵子樹的任意節點的壞路不超過1條的方案數。

dp2[i]:從該點出發,往父親方向的滿足要求的方案數。

首先先看dp1。假設當前節點為u,v為其子節點,則u~v之間的路只有修和不修兩種。

修:方案數有dp1[v];   不修:方案數只有1,下面的路都是要修的。

則 dp1[u] = (dp1[v1]+1)*(dp1[v2]+1)*……

轉換根節點,現在看dp2,假設要求節點dp2[v1],u為v的父親節點。(dp2[1] = 1)

v1的方案數來自父親節點往上的方案數還有其他兄弟的方案數。

則dp2[v1] = (dp2[u]*(dp1[v2]+1)*(dp1[v3]+1)*……)+1;(加1的原因跟上面一樣)

如果不能理解,就畫個圖感受一下。

1.有的同學可能會想到用乘法逆元來求其他兄弟的方案數,實際上這會WA10;

2.求出每個節點的字首積和字尾積即可避免除法(逆元),如果每次都重新求兄弟方案數積TLE39.

#include <bits/stdc++.h>
using namespace std;

const int N = 2e5+5;
const int MOD = 1e9+7;
typedef long long LL;

int n;
int head[N], cnt = 0;
LL dp1[N], dp2[N];
int f[N];
vector <LL> front[N];
vector <LL> back[N];
vector <LL> tt[N];

struct Edge {
    int v, next;
}e[N<<1];

void addEdge(int u, int v) {
    e[cnt] = (Edge){v, head[u]};
    head[u] = cnt++;
}
    
inline void cal(LL &a, LL b) {
    a = a*b%MOD;
}
void dfs(int u, int par) {
    dp1[u] = 1;
    f[u] = par;
    int cnt = 0;
    for(int i = head[u];i != -1; i = e[i].next) {
        if(e[i].v == par) continue;
        int v = e[i].v;
        dfs(v, u);
        cal(dp1[u], dp1[v]+1);
        tt[u].push_back(dp1[v]+1);
        cnt++;
    }
    LL tmp = 1, tmp2 = 1;
    for(int i = 0, j = cnt-1;i < cnt; i++, j--) {
        cal(tmp, tt[u][i]);
        cal(tmp2, tt[u][j]);
        front[u].push_back(tmp);
        back[u].push_back(tmp2);
    }
    /*
    cout<<"u = "<<u<<endl;
    for(int i = 0;i < cnt; i++) {
        cout<<front[u][i]<<" ";
    }
    cout<<endl;
    for(int i = 0;i < cnt; i++) {
        cout<<back[u][i]<<" ";
    }
    cout<<endl;
    */
}
        
/*
LL just(int v, int u) {
    LL ret = 1;
    for(int i = head[u];i != -1; i = e[i].next) {
        if(e[i].v == v || e[i].v == f[u]) continue;
        int tv = e[i].v;
        cal(ret, dp1[tv]+1);
    }
    cal(ret, dp2[u]);
    ret++;
    return ret;
}
*/
    
LL just(int v, int u, int idx) {
    LL ret = 1;
    if(front[u].size() > 0) {
        if(idx > 0) cal(ret, front[u][idx-1]);
        if(idx < (back[u].size()-1)) {
            int real = back[u].size()-idx-2;
            cal(ret, back[u][real]);
        }
    }
    cal(ret, dp2[u]);
    ret++;
    return ret;
}
        
void dfs2(int u, int par) {
    int idx = 0;
    for(int i = head[u];i != -1; i = e[i].next) {
        if(e[i].v == par) continue;
        int v = e[i].v;
        dp2[v] = just(v, u, idx++);
        dfs2(v, u);
    }
}
        
void solve() {
    dfs(1, -1);
    dp2[1] = 1;
    dfs2(1, -1);
    for(int i = 1;i <= n; i++)
        printf("%I64d%c", dp2[i]*dp1[i]%MOD, i == n?'\n':' ');
}

int main() {
    scanf("%d", &n);
    memset(head, -1, sizeof(head));
    for(int i = 2;i <= n; i++) {
        int v;
        scanf("%d", &v);
        addEdge(v, i);
        addEdge(i, v);
    }
    solve();
    return 0;
}