1. 程式人生 > 實用技巧 >GMOJ 6898. 【2020.11.27提高組模擬】第二題

GMOJ 6898. 【2020.11.27提高組模擬】第二題

題目描述

題解

這題考場時想了特別久,花了很多時間,但是隻想出了\(O(n^2)\)的做法,滿分做法其實不難。

容易發現,一個點如果能變成黑的,當且僅當這個點是黑色或者子樹中有兩個節點是黑色的。

進一步可以發現,對於這個子樹中的點,他們對於這個點的要求是:除自己外的子樹有黑色點,或者這個點是黑色。

也就是隻要這個點能變成黑色,就能對其子樹中所有點產生貢獻。

不妨用一個樹形dp,將每個點變成黑色的最短時間計算出來。然後按編號從小到大加入點,如果在當前編號下某個點可以變成黑色,就將這個點子樹中的點(出這個點以外)對答案的貢獻加一,同時將當前最新變色的點對答案的貢獻加一。然後就能知道每個點首次加進去的答案了。

但是還需要統計對之前的點的貢獻,因為這些點也可能被當前的區間修改,所以再開一個線段樹維護一下就行了。

程式碼

#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 800010
#define ll long long
#define mo 1000000007
using namespace std;
int n,i,j,bfs[N],fa[N],dfn[N],son[N],num,root,q[N],tot,f[N],g[N],stk[N][2],top;
int x,y;
ll sum,ans;
struct node{
	ll lazy,sum,pk;
}tr[N*5];
struct edge{
	int to,next;
}e[N];
struct pl{
	int time,val;
}dp[N];
int read(){
	int x=0;
	char ch=getchar();
	while (ch<'0'||ch>'9') ch=getchar();
	while (ch>='0'&&ch<='9'){
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return x;
}
void insert_edge(int x,int y){
	tot++;
	e[tot].to=y;
	e[tot].next=q[x];
	q[x]=tot;
}
int cmp(pl x,pl y){
	return x.time<y.time;
}
void update(int x,int l,int r){
	ll p=tr[x].lazy;
	if (l<r){
		int mid=(l+r)/2;
		tr[x*2].sum=(tr[x*2].sum+(mid-l+1)*p)%mo;
		tr[x*2+1].sum=(tr[x*2+1].sum+(r-mid)*p)%mo;
		tr[x*2].lazy=(tr[x*2].lazy+p)%mo;tr[x*2+1].lazy=(tr[x*2+1].lazy+p)%mo;
	}
	tr[x].lazy=0;
}
void change(int x,int l,int r,int l1,int r1){
	if (l1>r1) return;
	update(x,l,r);
	if (l1<=l&&r1>=r){
		tr[x].sum=(tr[x].sum+r-l+1)%mo;
		tr[x].lazy=tr[x].lazy+1;
		return;
	}
	int mid=(l+r)/2;
	if (l1<=mid) change(x*2,l,mid,l1,r1);
	if (r1>mid) change(x*2+1,mid+1,r,l1,r1);
	tr[x].sum=(tr[x*2].sum+tr[x*2+1].sum)%mo;
}
void changepk(int x,int l,int r,int k){
	if (l==r){
		tr[x].pk++;
		return;
	}
	int mid=(l+r)/2;
	if (k<=mid) changepk(x*2,l,mid,k);
	else changepk(x*2+1,mid+1,r,k);
	tr[x].pk=(tr[x*2].pk+tr[x*2+1].pk)%mo;
}
ll gets(int x,int l,int r,int k){
	update(x,l,r);
	if (l==r) return tr[x].sum;
	int mid=(l+r)/2;
	if (k<=mid) return gets(x*2,l,mid,k);
	else return gets(x*2+1,mid+1,r,k);
}
ll getpk(int x,int l,int r,int l1,int r1){
	if (l1>r1) return 0;
	if (l1<=l&&r1>=r) return tr[x].pk;
	int mid=(l+r)/2,pk=0;
	if (l1<=mid) pk=(pk+getpk(x*2,l,mid,l1,r1))%mo;
	if (r1>mid) pk=(pk+getpk(x*2+1,mid+1,r,l1,r1))%mo;
	return pk;
}
int main(){
	freopen("dierti.in","r",stdin);
	freopen("dierti.out","w",stdout);
	n=read();
	for (i=1;i<=n;i++){
		fa[i]=read();
		if (fa[i]==0) root=i;
		else insert_edge(fa[i],i);
	}
	top=1;stk[1][0]=root;stk[1][1]=q[root];
	while (top){
		x=stk[top][0];
		if (stk[top][1]==q[x]){
			f[x]=g[x]=1e9;
			dfn[x]=++num;
		}
		if (!stk[top][1]){
			g[fa[x]]=min(g[fa[x]],min(f[x],x));
			if (g[fa[x]]<f[fa[x]]) swap(f[fa[x]],g[fa[x]]);
			son[x]=num;
			top--;
			continue;
		}
		for (i=stk[top][1];i;i=e[i].next){
			y=e[i].to;
			stk[top][1]=e[i].next;
			stk[++top][0]=y;stk[top][1]=q[y];
			break;
		}
	}
	for (i=1;i<=n;i++) dp[i].time=min(i,g[i]),dp[i].val=i;
	sort(dp+1,dp+n+1,cmp);
	j=0;
	ans=1;
	for (i=1;i<=n;i++){
		change(1,1,n,dfn[i],dfn[i]);
		while (j+1<=n&&dp[j+1].time<=i){
			j++;
			change(1,1,n,dfn[dp[j].val]+1,son[dp[j].val]);
			sum=(sum+getpk(1,1,n,dfn[dp[j].val]+1,son[dp[j].val]))%mo;
		}
		sum=(sum+gets(1,1,n,dfn[i]))%mo;
		changepk(1,1,n,dfn[i]);
		ans=ans*sum%mo;
	}
	printf("%lld\n",ans);
	return 0;
}