Luogu P3258 [JLOI2014]松鼠的新家
阿新 • • 發佈:2020-08-06
思路
這道題我一開始做的時候並不會樹上差分,然後就卡了好久……
首先是樹上差分,這個東西和普通序列上的差分大同小異。設\(sum[i]\)為差分陣列,那麼\(sum[i]\)表示的是\(i\)這個點到根節點上所有值的和。若要對\(x \sim y\)這條鏈上所有的點都加上\(v\),
那麼就要對\(sum[x]+=v , sum[y]+=v , sum[lca(x,y)]-=v , sum[fa(lca(x,y))]-=v\) 。(這一部分建議自行畫圖理解一下,和普通差分的原理相似)
關於LCA,應該就沒啥好說的,一般就是倍增求LCA和樹剖求LCA(偶爾也會有用RMQ求LCA)。這裡我使用的是樹剖求LCA(比較快嘛)。
最後統計答案,就是進行一遍DFS,每個結點的權值即為該節點的子樹的權值和(差分和字首和互為逆運算)。但是這個題還有一個坑點,就是每次對兩個點進行操作時,\(2 \sim (n-1)\)這個區間
中的所有點就被重複加了,所以在最後要對這些點的權值減去\(1\) 。並且根據題目中所說,在第\(n\)個節點是不需要糖的,所以\(n\)這個結點的權值同樣也要減\(1\)。
Code
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #define MAXN 300005 int n, a[MAXN], cnt; int top[MAXN], son[MAXN], fa[MAXN]; int siz[MAXN], dep[MAXN], sum[MAXN]; class node{ public: int to; node *nxt = NULL; } edge[MAXN << 1], *head[MAXN]; inline int read(void){ int f = 1, x = 0;char ch; do{ch = getchar();if(ch=='-')f = -1;} while (ch < '0' || ch > '9'); do{ x = x * 10 + ch - '0';ch = getchar();} while (ch >= '0' && ch <= '9'); return f * x; } inline void add_edge(int x,int y){ ++cnt; edge[cnt].nxt = head[x]; head[x] = &edge[cnt]; head[x]->to = y; return; } void DFS1(int u){ siz[u] = 1; dep[u] = dep[fa[u]] + 1; for (node *i = head[u]; i != NULL;i = i->nxt){ int v = i->to; if(v==fa[u]) continue; fa[v] = u; DFS1(v); siz[u] += siz[v]; if(!son[u]||siz[son[u]]<siz[v]) son[u] = v; } }//樹剖的第一個DFS void DFS2(int u,int idx){ top[u] = idx; if(son[u]) DFS2(son[u], idx); for (node *i = head[u]; i != NULL;i = i->nxt){ int v = i->to; if(v==fa[u]||v==son[u]) continue; DFS2(v, v); } }//... inline int LCA(int x,int y){ while(top[x]!=top[y]){ if(dep[top[x]]>=dep[top[y]]) x = fa[top[x]]; else y = fa[top[y]]; } return dep[x] < dep[y] ? x : y; }//... void DFS3(int u){ for (node *i = head[u]; i != NULL;i = i->nxt){ int v = i->to; if(v==fa[u]) continue; DFS3(v); sum[u] += sum[v]; } }//統計答案,把差分陣列還原 int main(){ n = read(); for (int i = 1; i <= n;++i) a[i] = read(); for (int i = 1; i < n;++i){ int x = read(), y = read(); add_edge(x, y), add_edge(y, x);//不要忘了是加雙向邊 } DFS1(1), DFS2(1, 1); for (int i = 1; i < n;++i){ ++sum[a[i]], ++sum[a[i + 1]]; int la = LCA(a[i], a[i + 1]); --sum[la], --sum[fa[la]];//差分部分 // std::cout << a[i] << ' ' << a[i + 1]<<' ' << la << '\n'; } // for (int i = 1; i <= n;++i) // std::cout << sum[i] << ' '; // puts(""); DFS3(1);//統計答案 for (int i = 2; i <= n;++i) --sum[a[i]];//減去重複和不需要的點 for (int i = 1; i <= n;++i) printf("%d\n", sum[i]); return 0; }