【洛谷P5311】成都七中
阿新 • • 發佈:2021-10-01
題目
題目連結:https://www.luogu.com.cn/problem/P5311
給你一棵 \(n\) 個節點的樹,每個節點有一種顏色,有 \(m\) 次查詢操作。
查詢操作給定引數 \(l\ r\ x\),需輸出:
將樹中編號在 \([l,r]\) 內的所有節點保留,\(x\) 所在連通塊中顏色種類數。
每次查詢操作獨立。
\(n,m\leq 10^5\)。
思路
鬼能想到這道題是點分樹啊。
點分樹有一個性質:對於原樹上的一個連通塊,這個連通塊一定存在一個點,使得點分樹上這個點的子樹內,包含了連通塊內所有的點。
反證法。如果不存在這樣的點,設連通塊所有點在點分樹內深度最小的點為 \(x\)
那麼可以把每一個詢問對應到 \(x\) 所在連通塊內點分樹上深度最小的點。
然後對於一個點 \(x\),考慮求出所有對應到他的詢問。可以遍歷點分樹子樹內所有點,對於一個點 \(y\),求出原樹中 \(x\) 到 \(y\) 路徑上點的編號的最小值和最大值。分別記為 \(mn_y\) 和 \(mx_y\)。
然後對於一個詢問 \(l,r\),滿足 \(l\leq mn_y,r\geq mx_y\) 的不同顏色數。這個東西最暴力的做法是把顏色單獨看作一維然後三維數點,算上點分樹的複雜度是 \(O(n\log^3 n)\)
把所有詢問和點都扔到一起,按照 \(mn\)(詢問是 \(l\))從大到小排序。然後依次列舉所有的點(詢問),記錄目前每個顏色的 \(mx\) 的最小值,遇到詢問的時候就只需要查詢每個顏色 \(mx\) 最小值 \(\leq r\) 的數量。樹狀陣列維護即可。
時間複雜度 \(O(n\log^2n)\)。
程式碼
#include <bits/stdc++.h> using namespace std; const int N=100010,Inf=1e9; int n,m,Q,rt,tot,a[N],id[N],ans[N],dep[N],minn[N],head[N],siz[N],maxp[N],fat[N]; bool vis[N]; vector<int> qry[N]; struct edge { int next,to; }e[N*2]; struct node { int l,r,id; }b[N],c[N*2]; void add(int from,int to) { e[++tot]=(edge){head[from],to}; head[from]=tot; } void findrt(int x,int fa,int sum) { siz[x]=1; maxp[x]=0; for (int i=head[x];~i;i=e[i].next) { int v=e[i].to; if (!vis[v] && v!=fa) { findrt(v,x,sum); siz[x]+=siz[v]; maxp[x]=max(maxp[x],siz[v]); } } maxp[x]=max(maxp[x],sum-siz[x]); if (!rt || maxp[x]<maxp[rt]) rt=x; } void dfs1(int x,int sum) { vis[x]=1; for (int i=head[x];~i;i=e[i].next) { int v=e[i].to; if (!vis[v]) { int s=(siz[v]>siz[x])?(sum-siz[x]):siz[v]; rt=0; findrt(v,x,s); fat[rt]=x; dep[rt]=dep[x]+1; dfs1(rt,s); } } } void dfs2(int x,int fa,int d,int mn,int mx) { c[++m]=(node){mn,mx,-a[x]}; for (int i=0;i<(int)qry[x].size();i++) if (qry[x][i] && b[qry[x][i]].l<=mn && b[qry[x][i]].r>=mx) c[++m]=b[qry[x][i]],qry[x][i]=0; for (int i=head[x];~i;i=e[i].next) { int v=e[i].to; if (dep[v]>d && v!=fa) dfs2(v,x,d,min(mn,v),max(mx,v)); } } bool cmp(node x,node y) { if (x.l!=y.l) return x.l>y.l; return x.id<y.id; } bool cmp2(int x,int y) { return dep[x]<dep[y]; } struct BIT { int c[N]; void add(int x,int v) { for (int i=x;i<=n;i+=i&-i) c[i]+=v; } int query(int x) { int ans=0; for (int i=x;i;i-=i&-i) ans+=c[i]; return ans; } }bit; int main() { memset(head,-1,sizeof(head)); scanf("%d%d",&n,&Q); for (int i=1;i<=n;i++) scanf("%d",&a[i]); for (int i=1,x,y;i<n;i++) { scanf("%d%d",&x,&y); add(x,y); add(y,x); } findrt(1,0,n); dfs1(rt,n); for (int i=1,x;i<=Q;i++) { scanf("%d%d%d",&b[i].l,&b[i].r,&x); b[i].id=i; qry[x].push_back(i); } for (int i=1;i<=n;i++) id[i]=i; sort(id+1,id+1+n,cmp2); memset(minn,0x3f3f3f3f,sizeof(minn)); for (int k=1;k<=n;k++) { int i=id[k]; m=0; dfs2(i,0,dep[i],i,i); sort(c+1,c+1+m,cmp); for (int j=1;j<=m;j++) if (c[j].id>0) ans[c[j].id]=bit.query(c[j].r); else if (c[j].r<minn[-c[j].id]) { if (minn[-c[j].id]<Inf) bit.add(minn[-c[j].id],-1); minn[-c[j].id]=c[j].r; bit.add(c[j].r,1); } for (int j=0;j<=m;j++) if (c[j].id<0 && minn[-c[j].id]<Inf) { bit.add(minn[-c[j].id],-1); minn[-c[j].id]=Inf; } } for (int i=1;i<=Q;i++) cout<<ans[i]<<"\n"; return 0; }