1. 程式人生 > 其它 >【洛谷P5311】成都七中

【洛谷P5311】成都七中

題目

題目連結:https://www.luogu.com.cn/problem/P5311
給你一棵 \(n\) 個節點的樹,每個節點有一種顏色,有 \(m\) 次查詢操作。
查詢操作給定引數 \(l\ r\ x\),需輸出:
將樹中編號在 \([l,r]\) 內的所有節點保留,\(x\) 所在連通塊中顏色種類數。
每次查詢操作獨立。
\(n,m\leq 10^5\)

思路

鬼能想到這道題是點分樹啊。
點分樹有一個性質:對於原樹上的一個連通塊,這個連通塊一定存在一個點,使得點分樹上這個點的子樹內,包含了連通塊內所有的點。
反證法。如果不存在這樣的點,設連通塊所有點在點分樹內深度最小的點為 \(x\)

,那麼必然存在另一個連通塊內的點 \(y\) 不在 \(x\) 的子樹內,那麼原樹從 \(x\)\(y\) 的路徑上,一定存在一個點,在點分樹上的深度小於 \(x\) 的深度。矛盾。
那麼可以把每一個詢問對應到 \(x\) 所在連通塊內點分樹上深度最小的點。
然後對於一個點 \(x\),考慮求出所有對應到他的詢問。可以遍歷點分樹子樹內所有點,對於一個點 \(y\),求出原樹中 \(x\)\(y\) 路徑上點的編號的最小值和最大值。分別記為 \(mn_y\)\(mx_y\)
然後對於一個詢問 \(l,r\),滿足 \(l\leq mn_y,r\geq mx_y\) 的不同顏色數。這個東西最暴力的做法是把顏色單獨看作一維然後三維數點,算上點分樹的複雜度是 \(O(n\log^3 n)\)
,無法接受。
把所有詢問和點都扔到一起,按照 \(mn\)(詢問是 \(l\))從大到小排序。然後依次列舉所有的點(詢問),記錄目前每個顏色的 \(mx\) 的最小值,遇到詢問的時候就只需要查詢每個顏色 \(mx\) 最小值 \(\leq r\) 的數量。樹狀陣列維護即可。
時間複雜度 \(O(n\log^2n)\)

程式碼

#include <bits/stdc++.h>
using namespace std;

const int N=100010,Inf=1e9;
int n,m,Q,rt,tot,a[N],id[N],ans[N],dep[N],minn[N],head[N],siz[N],maxp[N],fat[N];
bool vis[N];
vector<int> qry[N];

struct edge
{
	int next,to;
}e[N*2];

struct node
{
	int l,r,id;
}b[N],c[N*2];

void add(int from,int to)
{
	e[++tot]=(edge){head[from],to};
	head[from]=tot;
}

void findrt(int x,int fa,int sum)
{
	siz[x]=1; maxp[x]=0;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (!vis[v] && v!=fa)
		{
			findrt(v,x,sum);
			siz[x]+=siz[v]; maxp[x]=max(maxp[x],siz[v]);
		}
	}
	maxp[x]=max(maxp[x],sum-siz[x]);
	if (!rt || maxp[x]<maxp[rt]) rt=x;
}

void dfs1(int x,int sum)
{
	vis[x]=1;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (!vis[v])
		{
			int s=(siz[v]>siz[x])?(sum-siz[x]):siz[v];
			rt=0; findrt(v,x,s);
			fat[rt]=x; dep[rt]=dep[x]+1;
			dfs1(rt,s);
		}
	}
}

void dfs2(int x,int fa,int d,int mn,int mx)
{
	c[++m]=(node){mn,mx,-a[x]};
	for (int i=0;i<(int)qry[x].size();i++)
		if (qry[x][i] && b[qry[x][i]].l<=mn && b[qry[x][i]].r>=mx)
			c[++m]=b[qry[x][i]],qry[x][i]=0;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (dep[v]>d && v!=fa)
			dfs2(v,x,d,min(mn,v),max(mx,v));
	}
}

bool cmp(node x,node y)
{
	if (x.l!=y.l) return x.l>y.l;
	return x.id<y.id;
}

bool cmp2(int x,int y)
{
	return dep[x]<dep[y];
}

struct BIT
{
	int c[N];
	
	void add(int x,int v)
	{
		for (int i=x;i<=n;i+=i&-i)
			c[i]+=v;
	}
	
	int query(int x)
	{
		int ans=0;
		for (int i=x;i;i-=i&-i)
			ans+=c[i];
		return ans;
	}
}bit;

int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d%d",&n,&Q);
	for (int i=1;i<=n;i++)
		scanf("%d",&a[i]);
	for (int i=1,x,y;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		add(x,y); add(y,x);
	}
	findrt(1,0,n);
	dfs1(rt,n);
	for (int i=1,x;i<=Q;i++)
	{
		scanf("%d%d%d",&b[i].l,&b[i].r,&x);
		b[i].id=i;
		qry[x].push_back(i);
	}
	for (int i=1;i<=n;i++) id[i]=i;
	sort(id+1,id+1+n,cmp2);
	memset(minn,0x3f3f3f3f,sizeof(minn));
	for (int k=1;k<=n;k++)
	{
		int i=id[k]; m=0;
		dfs2(i,0,dep[i],i,i);
		sort(c+1,c+1+m,cmp);
		for (int j=1;j<=m;j++)
			if (c[j].id>0)
				ans[c[j].id]=bit.query(c[j].r);
			else if (c[j].r<minn[-c[j].id])
			{
				if (minn[-c[j].id]<Inf) bit.add(minn[-c[j].id],-1);
				minn[-c[j].id]=c[j].r;
				bit.add(c[j].r,1);
			}
		for (int j=0;j<=m;j++)
			if (c[j].id<0 && minn[-c[j].id]<Inf)
			{
				bit.add(minn[-c[j].id],-1);
				minn[-c[j].id]=Inf;
			}
	}
	for (int i=1;i<=Q;i++)
		cout<<ans[i]<<"\n";
	return 0;
}