題解 Count on a tree II/【模板】樹分塊
阿新 • • 發佈:2021-10-19
Description
給出一個大小為 \(n\) 的樹,每個點有點權,有 \(m\) 次查詢,每次查詢 \(u\to v\) 的不同點權個數。強制線上。
\(n\le 4\times 10^4,m\le 10^5\)
Solution
不知道這是不是正宗的樹分塊。
我們考慮假如我們能取出約 \(\Theta(n/B)\) 個點,使得任意一個點到其最近的一個點距離都 \(\le B\),那麼我們就可以提前處理任意兩兩這些點預處理資訊,再把剩下的距離手動加進去。這樣的話我們複雜度就可以做到 \(\Theta(n^2/B^2\times t+qB)\) 。其中 \(t\) 是一次預處理的複雜度。
我們發現其實 \(B=\sqrt n\) 的時候最優。那我們怎麼取呢?我們發現直接在 \(\text{dep}\equiv 0\pmod{B}\) 的時候就好了,因為顯然。
然後這個題目的話我們可以用 bitset 來維護一下就好了。稍微有點卡空間和卡時間,\(B=800\) 的時候比較優秀。
Code
#include <bits/stdc++.h> using namespace std; #define Int register int #define MAXN 40005 #define MAXM 205 template <typename T> inline void read (T &t){t = 0;char c = getchar();int f = 1;while (c < '0' || c > '9'){if (c == '-') f = -f;c = getchar();}while (c >= '0' && c <= '9'){t = (t << 3) + (t << 1) + c - '0';c = getchar();} t *= f;} template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);} template <typename T> inline void write (T x){if (x < 0){x = -x;putchar ('-');}if (x > 9) write (x / 10);putchar (x % 10 + '0');} template <typename T> inline void chkmax (T &a,T b){a = max (a,b);} template <typename T> inline void chkmin (T &a,T b){a = min (a,b);} vector <int> h[MAXN]; int n,m,uni,cnt,ind,B,dfn[MAXN],siz[MAXN],val[MAXN],tmp[MAXN],pat[MAXN],dep[MAXN],tur[MAXN],par[MAXN][21]; void dfs1 (int u,int fa){ dep[u] = dep[fa] + 1,par[u][0] = fa,dfn[u] = ++ ind,siz[u] = 1; if (dep[u] % B == 0) tur[u] = ++ cnt,tmp[cnt] = u; for (Int i = 1;i <= 20;++ i) par[u][i] = par[par[u][i - 1]][i - 1]; for (Int v : h[u]) if (v ^ fa) dfs1 (v,u),siz[u] += siz[v]; } int hav[MAXM][MAXM],have; bitset <MAXN> Now,con[MAXM][MAXM]; void dfs2 (int st,int u,int fa){ int lst = Now[val[u]];have += (!Now[val[u]]),Now[val[u]] = 1; if (tur[u]) con[st][tur[u]] = Now,hav[st][tur[u]] = have; for (Int v : h[u]) if (v ^ fa) dfs2 (st,v,u); Now[val[u]] = lst,have -= (!lst); } int getlca (int u,int v){ if (dep[u] < dep[v]) swap (u,v); for (Int i = 20,dis = dep[u] - dep[v];~i;-- i) if (dis >> i & 1) u = par[u][i]; if (u == v) return u; else{ for (Int i = 20;~i;-- i) if (par[u][i] ^ par[v][i]) u = par[u][i],v = par[v][i]; return par[u][0]; } } bool checkin (int u,int v){return dfn[u] <= dfn[v] && dfn[v] <= dfn[u] + siz[u] - 1;} int query (int u,int v){ int lca = getlca (u,v),tmp1 = 0,tmp2 = 0,tmpu = u,tmpv = v; while(dep[tmpu] >= dep[lca]){ if (tur[tmpu]){ tmp1 = tur[tmpu]; break; } tmpu = par[tmpu][0]; } while (dep[tmpv] >= dep[lca]){ if (tur[tmpv]){ tmp2 = tur[tmpv]; break; } tmpv = par[tmpv][0]; } if (tmp1 && tmp2){ int ans = hav[tmp1][tmp2];Now = con[tmp1][tmp2]; for (Int f1 = u;f1 != tmpu;f1 = par[f1][0]) ans += (!Now[val[f1]]),Now[val[f1]] = 1; for (Int f2 = v;f2 != tmpv;f2 = par[f2][0]) ans += (!Now[val[f2]]),Now[val[f2]] = 1; return ans; } else if (!tmp1 && !tmp2){ Now.reset ();int ans = 0; for (Int f1 = u;dep[f1] >= dep[lca];f1 = par[f1][0]) ans += (!Now[val[f1]]),Now[val[f1]] = 1; for (Int f2 = v;dep[f2] >= dep[lca];f2 = par[f2][0]) ans += (!Now[val[f2]]),Now[val[f2]] = 1; return ans; } else{ if (tmp2) swap (tmp1,tmp2),swap (u,v),swap (tmpu,tmpv);tmp2 = 0,tmpv = 0; for (Int i = 1;i <= cnt;++ i) if (checkin (tmp[i],tmpu) && checkin (lca,tmp[i])){ if (tmp2 == 0 || dep[tmp[tmp2]] > dep[tmp[i]]) tmpv = tmp[i],tmp2 = i; } int ans = hav[tmp1][tmp2];Now = con[tmp1][tmp2]; for (Int f1 = tmpv;dep[f1] >= dep[lca];f1 = par[f1][0]) ans += (!Now[val[f1]]),Now[val[f1]] = 1; for (Int f2 = v;dep[f2] >= dep[lca];f2 = par[f2][0]) ans += (!Now[val[f2]]),Now[val[f2]] = 1; for (Int f3 = u;f3 != tmpu;f3 = par[f3][0]) ans += (!Now[val[f3]]),Now[val[f3]] = 1; return ans; } } signed main(){ read (n,m),B = 800; for (Int i = 1;i <= n;++ i) read (val[i]),tmp[i] = val[i]; sort (tmp + 1,tmp + n + 1),uni = unique (tmp + 1,tmp + n + 1) - tmp - 1; for (Int i = 1;i <= n;++ i) val[i] = lower_bound (tmp + 1,tmp + n + 1,val[i]) - tmp; for (Int i = 2,u,v;i <= n;++ i) read (u,v),h[u].push_back (v),h[v].push_back (u); cnt = 0,dfs1 (1,0); for (Int i = 1;i <= cnt;++ i) dfs2 (i,tmp[i],0); int lstans = 0; while (m --> 0){ int u,v;read (u,v),u ^= lstans; write (lstans = query (u,v)),putchar ('\n'); } return 0; }