2020 牛客多校7 C A National Pandemic(樹鏈剖分)
題意
有\(T\)組樣例,第二行給\(n\),\(m\),分別為點的個數和詢問個數,接下來\(n-1\)行為邊,之後是\(m\)行詢問
詢問有三種,\(1\) \(x\) \(w\), 表示點\(x\)增加\(w\),其他所有點增減\(w-dis(x,y)\)(點\(y\)樹上到\(x\)的距離)
\(2\) \(x\),表示將點\(x\)的權值變為\(min(F(x), 0)\)
\(3\) \(x\),表示詢問\(x\)的權值\(F(x)\)
解法
對整顆樹進行樹鏈剖分(對各點做標記的部分只有兩個dfs而已,不是很難),之後將重新標號的點對映到線段樹上,並用線段樹維護每個點的權值(除了2操作可以使用另外一個數組來維護每個點被減去的權值),操作1我們只需要維護一個全域性變數,然後對x到根上所有的點的點權全部+2,這樣我們就可以假裝所有的1操作都是在根上操作的了( ̄▽ ̄)/
先拿例題中的圖舉例
首先將樹剖成三條鏈
經過操作\(1\) \(1\) \(5\)之後,全域性變數變為\(4\),\(1\)的點權變為\(2\),又經過\(2\) \(1\), \(1\)的點權修正值為\(-5\)
再經過\(1\) \(2\) \(7\)之後,\(1\), \(2\)
這樣操作的時候,比如你要求點\(3\)的值,就是\(全域性權重-dep[3] * 2 + queryPath(1,3) = 9\), 因為有兩次操作\(1\),而權值公式為\(w-dis(x,y)\),所以減去兩次\(dep[3]\),最後我們補上多減的值就是了。
又比如我們詢問\(2\), 權值如上計算是\(11\),按照最基本的式子原本應該是\(全域性權重-cnt1*dep[2]\), 而因為\(1\),\(2\)都曾進行過一次操作,我們應該將操作點移到根後點\(2\)失去的貢獻加回來,也就是每經過一個點,我們就將權重\(+2\)
程式碼
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll N = 5e4 + 7;
int t;
int n, m, u, v, op, w;
int head[N], cnte = 0, idx = 0;
int dep[N], dfn[N], fa[N], top[N], wson[N], siz[N], id[N];
ll addw[N], global_wt;
struct Edge{
int to, nxt, val;
} edge[N << 1];
void addedge(int u, int v)
{
edge[++cnte].nxt = head[u];
edge[cnte].to = v;
edge[cnte].val = 0;
head[u] = cnte;
}
struct segment_tree
{
#define lson rt << 1
#define rson rt << 1 | 1
ll sum[N << 2], add[N << 2];
void udp(int rt, ll v, int l, int r)
{
sum[rt] += (r - l + 1ll) * v;
add[rt] += v;
}
void build(int l, int r, int rt)
{
sum[rt] = add[rt] = 0;
if (l == r) return;
int mid = (l + r) >> 1;
build(l, mid, lson);
build(mid + 1, r, rson);
}
void push_down(int rt, int l, int r)
{
if(add[rt])
{
int mid = (l + r) >> 1;
udp(lson, add[rt], l, mid);
udp(rson, add[rt], mid + 1, r);
add[rt] = 0;
}
}
void update(int L, int R, ll c, int l, int r, int rt)
{
if (L == l && r == R )
{
udp(rt, c, l, r);
return;
}
int mid = (l + r) >> 1;
push_down(rt, l, r);
if (L <= mid) update(L, min(R,mid), c, l, mid, lson);
if (R > mid) update(max(L,mid + 1), R, c, mid + 1, r, rson);
sum[rt] = sum[lson] + sum[rson];
}
ll query(int L, int R, int l, int r, int rt)
{
if (L == l && r == R) return sum[rt];
int mid = (l + r) >> 1;
push_down(rt, l, r);
ll res = 0;
if (R <= mid) res += query(L, R, l, mid, lson);
else if (L > mid) res += query(L, R, mid + 1, r, rson);
else return query(L, mid, l, mid, lson) + query(mid + 1, R, mid + 1, r, rson);
}
} tree;
void init()
{
cnte = 0;
idx = 0;
global_wt = 0;
fill(head, head + N, 0);
fill(fa, fa + N, 0);
fill(siz, siz + N, 0);
fill(wson, wson + N, 0);
fill(addw, addw + N, 0);
}
void dfs1(int u)
{
siz[u] = 1;
for (int i = head[u]; i;i = edge[i].nxt)
{
int v = edge[i].to;
if(v == fa[u]) continue;
fa[v] = u;
dep[v] = dep[u] + 1;
dfs1(v);
siz[u] += siz[v];
if(siz[v] > siz[wson[u]]) wson[u] = v;
}
}
void dfs2(int u, int chain)
{
id[u] = ++idx, dfn[idx] = u;
top[u] = chain;
if(wson[u] != 0) dfs2(wson[u], chain);
for (int i = head[u]; i; i = edge[i].nxt)
{
int v = edge[i].to;
if(v == fa[u] || v == wson[u]) continue;
dfs2(v, v);
}
}
void udpPath(int u, int v, int val)
{
while(top[u] != top[v])
{
if(dep[top[u]] < dep[top[v]]) swap(u, v);
tree.update(id[top[u]], id[u], val, 1, n, 1);
u = fa[top[u]];
}
if(dep[u] < dep[v]) swap(u, v);
tree.update(id[v], id[u], val, 1, n, 1);
}
ll queryPath(int u, int v)
{
ll res = 0;
while(top[u] != top[v])
{
if(dep[top[u]] < dep[top[v]]) swap(u, v);
res += tree.query(id[top[u]], id[u], 1, n, 1);
u = fa[top[u]];
}
if(dep[u] < dep[v]) swap(u, v);
res += tree.query(id[v], id[u], 1, n, 1);
return res;
}
void solve()
{
scanf("%d %d", &n, &m);
for (int i = 1; i < n; i++)
{
scanf("%d %d", &u, &v);
addedge(u, v);
addedge(v, u);
}
tree.build(1, n, 1);
dep[1] = 1;
dfs1(1);
dfs2(1, 1);
int cnt1 = 0;
while(m--)
{
scanf("%d", &op);
if(op == 1)
{
scanf("%d %d", &u, &w);
global_wt += w;
global_wt -= dep[u];
cnt1++;
udpPath(1, u, 2);
}
else if(op == 2)
{
scanf("%d", &u);
ll weight = addw[u] - 1ll * cnt1 * dep[u] + global_wt + queryPath(1, u);
if(weight > 0) addw[u] -= weight;
}
else
{
scanf("%d", &u);
ll weight = addw[u] - 1ll * cnt1 * dep[u] + global_wt + queryPath(1, u);
printf("%lld\n", weight);
}
}
}
int main()
{
scanf("%d", &t);
while(t--)
{
init();
solve();
}
return 0;
}