1. 程式人生 > >P5002 專心OI - 找祖先

P5002 專心OI - 找祖先

P5002 專心OI - 找祖先

給定一棵有根樹(\(n \leq 10000\)),\(M \leq 50000\) 次詢問, 求以 \(x\)\(LCA\) 的點對個數


錯誤日誌: 看下面


Solution

設點 \(u\) 的子樹大小為 \(size[u]\)
現詢問以 \(u\)\(LCA\) 的點對個數
\(u\) 的子節點為 \(v_{1}, v_{2},...,v_{m}\)
我們現在統計點對中一個點在 \(v_{1}\) 內的答案數
一個點在 \(v_{1}\) 內, 另一個點要麼是 \(u\) 要麼在剩下的子節點子樹內選一個
所以總數為 \(size[u] - size[v_{1}]\)


乘法原理 \(v_{1}\) 子樹中所有答案為 \(size[v] * (size[u] - size[v])\)
每個子節點點都是這樣
最後加上一個 \((u, u)\) 即可

然後這樣 \(85pnts\)
看資料範圍, 最不優情況 \(O(NM)\)
然後發現 \(M > N\)
詢問次數多於點個數?
這提示我們像記憶化那樣記下答案
回頭想想這樣的複雜度為 \(O(N)\) , 這是一棵樹, 邊數和點數不會差太多, 複雜度有保障

Code

#include<iostream>
#include<cstdio>
#include<queue>
#include<cstring>
#include<algorithm>
#define LL long long
#define REP(i, x, y) for(LL i = (x);i <= (y);i++)
using namespace std;
LL RD(){
    LL out = 0,flag = 1;char c = getchar();
    while(c < '0' || c >'9'){if(c == '-')flag = -1;c = getchar();}
    while(c >= '0' && c <= '9'){out = out * 10 + c - '0';c = getchar();}
    return flag * out;
    }
const LL maxn = 200019,INF = 1e9 + 19, M = 1e9 + 7;
LL head[maxn],nume = 1;
struct Node{
    LL v,dis,nxt;
    }E[maxn << 3];
void add(LL u,LL v,LL dis){
    E[++nume].nxt = head[u];
    E[nume].v = v;
    E[nume].dis = dis;
    head[u] = nume;
    }
LL num, root, na;
LL size[maxn], fa[maxn];
LL mem[maxn];
void dfs(LL u, LL F){
    size[u] = 1;
    for(LL i = head[u];i;i = E[i].nxt){
        LL v = E[i].v;
        if(v == F)continue;
        fa[v] = u;
        dfs(v, u);
        size[u] = (size[u] + size[v]) % M;
        }
    }
void init(){
    num = RD(), root = RD(), na = RD();
    REP(i, 1, num - 1){
        LL u = RD(), v = RD();
        add(u, v, 1), add(v, u, 1);
        }
    dfs(root, -1);
    }
void solve(){
    while(na--){
        LL ans = 0, u = RD(), temp = size[u];
        if(mem[u]){printf("%lld\n", mem[u]);continue;}
        for(LL i = head[u];i;i = E[i].nxt){
            LL v = E[i].v;
            if(v == fa[u])continue;
            ans = (ans + ((size[v] * (temp - size[v])) % M + M) % M) % M;
            temp = ((temp - size[v]) % M + M) % M;
            }
        ans = (ans * 2 + 1) % M;
        mem[u] = ans;
        printf("%lld\n", ans);
        }
    }
int main(){
    init();
    solve();
    return 0;
    }