P3384 樹鏈剖分模板題
阿新 • • 發佈:2020-08-15
#include <bits/stdc++.h>
using namespace std;

// Inclusive counting loops (ascending / descending). The macro must not end
// in ';', otherwise the statement written after it is no longer the loop body.
#define rep(i, a, n) for (int i = (a); i <= (n); ++i)
#define per(i, a, n) for (int i = (n); i >= (a); --i)

using ll = long long;

const int N = 2e6 + 5;  // capacity of every per-node array below
// const ll mod = 1e9 + 7;
int mod;  // modulus for all stored sums; read from the input in main()
const double Pi = acos(-1.0);
const int INF = 0x3f3f3f3f;
const int G = 3, Gi = 332748118;

// Returns a^b modulo the global `mod` by binary exponentiation.
ll qpow(ll a, ll b) {
    ll res = 1;
    while (b) {
        if (b & 1) res = (res * a) % mod;
        a = (a * a) % mod;
        b >>= 1;
    }
    return res;
}

// Greatest common divisor (Euclid).
ll gcd(ll a, ll b) { return b ? gcd(b, a % b) : a; }

// Least common multiple. Dividing before multiplying keeps the intermediate
// value no larger than the result, so it cannot overflow unnecessarily.
ll lcm(ll a, ll b) { return a / gcd(a, b) * b; }

// Comparator for descending order.
bool cmp(int a, int b) { return a > b; }

//
int n, m, r;  // node count, operation count, root node
int head[N], cnt = 0;  // adjacency-list heads (-1 = empty) and next free edge slot

// One directed edge of the adjacency list.
struct node {
    int to, nxt, c;
} edge[N << 1];

// Segment-tree node covering positions [l, r]; `val` is the range sum and
// `lz` the pending per-element increment, both kept in [0, mod).
struct Tree {
    int l, r, val, lz;
} tree[N * 4];

int val[N], tval[N];  // node weights by node id / by dfs position
int son[N], siz[N], dfn[N], dep[N], top[N], fa[N], rnk[N];
int res = 0, tot = 0;  // tot: number of dfs positions handed out so far

// Canonical residue of x in [0, mod). The arithmetic is 64-bit so callers can
// pass products and sums that would overflow int.
static inline int residue(ll x) {
    x %= mod;
    if (x < 0) x += mod;
    return static_cast<int>(x);
}

// Inserts the undirected edge u-v as two directed edges.
void add(int u, int v) {
    edge[cnt].to = v, edge[cnt].nxt = head[u], head[u] = cnt++;
    edge[cnt].to = u, edge[cnt].nxt = head[v], head[v] = cnt++;
}

// Pushes the pending increment of `index` down to both children.
void pushdown(int index) {
    const int tag = tree[index].lz;
    if (!tag) return;
    for (int child : {index << 1, index << 1 | 1}) {
        Tree &c = tree[child];
        const ll len = c.r - c.l + 1;
        // len * tag does not fit in int, so the product is formed in 64 bits.
        c.val = residue(c.val + len * tag);
        c.lz = residue(static_cast<ll>(c.lz) + tag);
    }
    tree[index].lz = 0;
}

// Recomputes the sum of `index` from its two children.
void pushup(int index) {
    tree[index].val =
        residue(static_cast<ll>(tree[index << 1].val) + tree[index << 1 | 1].val);
}

// Builds the segment tree over dfs positions [l, r] from tval[].
void Build(int l, int r, int index) {
    tree[index].l = l, tree[index].r = r;
    tree[index].lz = 0;
    if (l == r) {
        tree[index].val = residue(tval[l]);
        return;
    }
    int mid = (l + r) >> 1;
    Build(l, mid, index << 1);
    Build(mid + 1, r, index << 1 | 1);
    pushup(index);
}

// Adds `val` to every position in [l, r] (range update with lazy tags).
void updata(int l, int r, int index, int val) {
    Tree &cur = tree[index];
    if (cur.l >= l && cur.r <= r) {
        const int delta = residue(val);
        const ll len = cur.r - cur.l + 1;
        // The tag is reduced as well, so repeated updates cannot overflow it.
        cur.lz = residue(static_cast<ll>(cur.lz) + delta);
        cur.val = residue(cur.val + len * delta);
        return;
    }
    pushdown(index);
    int mid = (cur.l + cur.r) >> 1;
    if (l <= mid) updata(l, r, index << 1, val);
    if (r > mid) updata(l, r, index << 1 | 1, val);
    pushup(index);
}

// Returns the sum of positions [l, r] modulo `mod`.
int query(int l, int r, int index) {
    if (l <= tree[index].l && tree[index].r <= r) {
        return residue(tree[index].val);
    }
    pushdown(index);
    int mid = (tree[index].l + tree[index].r) >> 1;
    ll ans = 0;  // two residues may add up to more than int can hold
    if (l <= mid) ans += query(l, r, index << 1);
    if (r > mid) ans += query(l, r, index << 1 | 1);
    return residue(ans);
}

// --------------------------------
// Sum of the node weights on the tree path between x and y, modulo `mod`.
int qRange(int x, int y) {
    ll total = 0;
    while (top[x] != top[y]) {
        // Always lift the endpoint whose chain top is deeper.
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        total += query(dfn[top[x]], dfn[x], 1);
        x = fa[top[x]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    total += query(dfn[x], dfn[y], 1);
    return residue(total);
}

// Adds c to every node on the tree path between x and y.
void updRange(int x, int y, int c) {
    c = residue(c);
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        updata(dfn[top[x]], dfn[x], 1, c);
        x = fa[top[x]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    updata(dfn[x], dfn[y], 1, c);
}

// Sum of all node weights in the subtree rooted at x, modulo `mod`.
int qSon(int x) {
    // A subtree occupies the contiguous dfs positions dfn[x] .. dfn[x]+siz[x]-1.
    return query(dfn[x], dfn[x] + siz[x] - 1, 1);
}

// Adds val to every node in the subtree rooted at x.
void updSon(int x, int val) {
    updata(dfn[x], dfn[x] + siz[x] - 1, 1, val);
}

// First pass: fills dep[], fa[], siz[] and the heavy child son[] (0 = none).
void dfs1(int u, int pre) {
    dep[u] = dep[pre] + 1;
    fa[u] = pre;
    siz[u] = 1;
    int maxx = -1;
    for (int i = head[u]; i != -1; i = edge[i].nxt) {
        int v = edge[i].to;
        if (v == pre) continue;
        dfs1(v, u);
        siz[u] += siz[v];
        if (siz[v] > maxx) {
            maxx = siz[v];
            son[u] = v;
        }
    }
}

// Second pass: assigns dfs positions and chain tops; topu is the topmost node
// of the chain currently being laid out.
void dfs2(int u, int topu) {
    dfn[u] = ++tot;
    tval[tot] = val[u];
    top[u] = topu;
    rnk[tot] = u;
    if (!son[u]) return;
    // The heavy child goes first so each chain gets consecutive positions.
    dfs2(son[u], topu);
    for (int i = head[u]; i != -1; i = edge[i].nxt) {
        int v = edge[i].to;
        if (v == son[u] || v == fa[u]) continue;
        dfs2(v, v);  // every light child starts a new chain
    }
}

// Reads "n m r mod", the n node weights, n-1 edges and m operations from
// stdin, and prints one line per query operation (types 2 and 4).
int main() {
    if (scanf("%d%d%d%d", &n, &m, &r, &mod) != 4) return 0;
    // Reject headers that would index out of bounds, make Build() recurse
    // forever (n < 1) or divide by zero (mod < 1).
    if (n < 1 || n >= N || r < 1 || r > n || mod < 1) return 0;
    cnt = 0;
    head[0] = -1;
    for (int i = 1; i <= n; ++i) {
        if (scanf("%d", &val[i]) != 1) return 0;
        head[i] = -1;
    }
    for (int i = 1; i < n; ++i) {
        int x, y;
        if (scanf("%d%d", &x, &y) != 2) return 0;
        add(x, y);
    }
    dfs1(r, r);
    dfs2(r, r);
    Build(1, n, 1);
    while (m--) {
        int k, x, y, z;
        if (scanf("%d", &k) != 1) break;
        if (k == 1) {  // add z to every node on the path x..y
            if (scanf("%d%d%d", &x, &y, &z) != 3) break;
            updRange(x, y, z);
        } else if (k == 2) {  // sum of node weights on the path x..y
            if (scanf("%d%d", &x, &y) != 2) break;
            printf("%d\n", qRange(x, y));
        } else if (k == 3) {  // add y to every node in the subtree of x
            if (scanf("%d%d", &x, &y) != 2) break;
            updSon(x, y);
        } else {  // sum of node weights in the subtree of x
            if (scanf("%d", &x) != 1) break;
            printf("%d\n", qSon(x));
        }
    }
    return 0;
}