1. 程式人生 > 實用技巧 >題解「GMOJ6898 【2020.11.27提高組模擬】第二題」

題解「GMOJ6898 【2020.11.27提高組模擬】第二題」

題面

GMOJ6898 【2020.11.27提高組模擬】第二題

題解

考慮非根結點 \(u\) 和它的父親 \(fa_u\) ,若 \(fa_u\) 除了 \(u\) 所在的子樹的其他子樹中有其他黑點 \(v\) ,那麼稱 \(fa_u\)\(v\) 覆蓋,並且 \(fa_u\) 會對以 \(u\) 為根的子樹中的黑色節點產生貢獻。對於每一個結點 \(u\) ,找出最早覆蓋 \(fa_u\) 的結點 \(v\) ,統計出 \(n-1\) 個點對 \((u,v)\)

從小到大加入黑點,查詢當前加入的結點 \(v\) 覆蓋了哪些結點,以及對應的點對 \((u,v)\) ,對 \(u\) 的子樹貢獻加一,子樹加用樹狀陣列維護即可。處理完這些點對後,在樹狀陣列上查詢 \(f_v\)

,計入答案中。因為還要知道 \(v\) 對之前幾個黑點產生了貢獻,所以還要用一個樹狀陣列統計子樹內黑點個數。

\(\text{Code}:\)

#include <cctype>
#include <cstring>
#include <cstdio>
#include <algorithm>
#define INF 0x3f3f3f3f
using namespace std;
typedef long long lxl;
const int maxn=8e5+5;
const lxl mod=1e9+7;

#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1<<21],*p1=buf,*p2=buf;
template <typename T>
inline void read(T &x)
{
	x=0;T f=1;char ch=getchar();
	while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
	while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
	x*=f;
}

struct edge
{
	int u,v,next;
	edge(int u,int v,int next):
		u(u),v(v),next(next){}
	edge(){}
}e[maxn];

int head[maxn],ecnt;

inline void add(int u,int v)
{
	e[ecnt]=edge(u,v,head[u]);
	head[u]=ecnt++;
}

int n,fa[maxn];
int dfn[maxn],idx[maxn],siz[maxn],dfs_cnt;
int Min[maxn],sMin[maxn],son[maxn];
pair<int,int> pii[maxn];
int pcnt;

void dfs(int u)
{
	dfn[u]=++dfs_cnt;
	idx[dfs_cnt]=u;
	siz[u]=1;
	Min[u]=u;
	for(int i=head[u];~i;i=e[i].next)
	{
		int v=e[i].v;
		dfs(v);
		siz[u]+=siz[v];
		if(Min[v]<Min[u]) sMin[u]=Min[u],Min[u]=Min[v],son[u]=v;
		else if(Min[v]<sMin[u]) sMin[u]=Min[v];
	}
}

namespace BIT
{
	int sum[maxn];
	inline int lowbit(int x) {return x&-x;}
	inline void add(int x,int d)
	{
		for(int i=x;i<=n;i+=lowbit(i))
			sum[i]+=d;
	}
	inline int query(int x)
	{
		int res=0;
		for(int i=x;i>=1;i-=lowbit(i))
			res+=sum[i];
		return res;
	}
	inline int query(int l,int r)
	{
		return query(r)-query(l-1);
	}
}

namespace Segment_Tree
{
	int sum[maxn];
	inline int lowbit(int x) {return x&-x;}
	inline void add(int x,int d)
	{
		for(int i=x;i<=n;i+=lowbit(i))
			sum[i]+=d;
	}
	inline int query(int x)
	{
		int res=0;
		for(int i=x;i>=1;i-=lowbit(i))
			res+=sum[i];
		return res;
	}
	inline void modify(int l,int r,int d)
	{
		add(l,d);add(r+1,-d);
	}
}

int main()
{
	freopen("dierti.in","r",stdin);
	freopen("dierti.out","w",stdout);
	read(n);
	memset(head,-1,sizeof(head));
	int rt;
	for(int i=1;i<=n;++i)
	{
		read(fa[i]);
		if(fa[i]) add(fa[i],i);
		else rt=i;
	}
	dfs(rt);
	for(int i=1;i<=n;++i) if(fa[i])
		pii[++pcnt]=make_pair(son[fa[i]]==i?sMin[fa[i]]:Min[fa[i]],i);
	sort(pii+1,pii+pcnt+1);
	int ans=1,sum=0;
	for(int i=1,p=1;i<=n;++i)
	{
		while(p<=pcnt&&pii[p].first<=i)
		{
			int u=pii[p].second;
			Segment_Tree::modify(dfn[u],dfn[u]+siz[u]-1,1);
			sum+=BIT::query(dfn[u],dfn[u]+siz[u]-1);
			if(sum>=mod) sum-=mod;
			++p;
		}
		BIT::add(dfn[i],1);
		sum+=Segment_Tree::query(dfn[i])+1;
		if(sum>=mod) sum-=mod;
		ans=1ll*ans*sum%mod;
	}
	printf("%d\n",ans);
	return 0;
}