GMOJ 6898. 【2020.11.27提高組模擬】第二題
阿新 • • 發佈:2020-11-27
題目描述
題解
這題考場時想了特別久,花了很多時間,但是隻想出了\(O(n^2)\)的做法,滿分做法其實不難。
容易發現,一個點如果能變成黑的,當且僅當這個點是黑色或者子樹中有兩個節點是黑色的。
進一步可以發現,對於這個子樹中的點,他們對於這個點的要求是:除自己外的子樹有黑色點,或者這個點是黑色。
也就是隻要這個點能變成黑色,就能對其子樹中所有點產生貢獻。
不妨用一個樹形dp,將每個點變成黑色的最短時間計算出來。然後按編號從小到大加入點,如果在當前編號下某個點可以變成黑色,就將這個點子樹中的點(出這個點以外)對答案的貢獻加一,同時將當前最新變色的點對答案的貢獻加一。然後就能知道每個點首次加進去的答案了。
但是還需要統計對之前的點的貢獻,因為這些點也可能被當前的區間修改,所以再開一個線段樹維護一下就行了。
程式碼
#include<cstdio> #include<cstring> #include<algorithm> #define N 800010 #define ll long long #define mo 1000000007 using namespace std; int n,i,j,bfs[N],fa[N],dfn[N],son[N],num,root,q[N],tot,f[N],g[N],stk[N][2],top; int x,y; ll sum,ans; struct node{ ll lazy,sum,pk; }tr[N*5]; struct edge{ int to,next; }e[N]; struct pl{ int time,val; }dp[N]; int read(){ int x=0; char ch=getchar(); while (ch<'0'||ch>'9') ch=getchar(); while (ch>='0'&&ch<='9'){ x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x; } void insert_edge(int x,int y){ tot++; e[tot].to=y; e[tot].next=q[x]; q[x]=tot; } int cmp(pl x,pl y){ return x.time<y.time; } void update(int x,int l,int r){ ll p=tr[x].lazy; if (l<r){ int mid=(l+r)/2; tr[x*2].sum=(tr[x*2].sum+(mid-l+1)*p)%mo; tr[x*2+1].sum=(tr[x*2+1].sum+(r-mid)*p)%mo; tr[x*2].lazy=(tr[x*2].lazy+p)%mo;tr[x*2+1].lazy=(tr[x*2+1].lazy+p)%mo; } tr[x].lazy=0; } void change(int x,int l,int r,int l1,int r1){ if (l1>r1) return; update(x,l,r); if (l1<=l&&r1>=r){ tr[x].sum=(tr[x].sum+r-l+1)%mo; tr[x].lazy=tr[x].lazy+1; return; } int mid=(l+r)/2; if (l1<=mid) change(x*2,l,mid,l1,r1); if (r1>mid) change(x*2+1,mid+1,r,l1,r1); tr[x].sum=(tr[x*2].sum+tr[x*2+1].sum)%mo; } void changepk(int x,int l,int r,int k){ if (l==r){ tr[x].pk++; return; } int mid=(l+r)/2; if (k<=mid) changepk(x*2,l,mid,k); else changepk(x*2+1,mid+1,r,k); tr[x].pk=(tr[x*2].pk+tr[x*2+1].pk)%mo; } ll gets(int x,int l,int r,int k){ update(x,l,r); if (l==r) return tr[x].sum; int mid=(l+r)/2; if (k<=mid) return gets(x*2,l,mid,k); else return gets(x*2+1,mid+1,r,k); } ll getpk(int x,int l,int r,int l1,int r1){ if (l1>r1) return 0; if (l1<=l&&r1>=r) return tr[x].pk; int mid=(l+r)/2,pk=0; if (l1<=mid) pk=(pk+getpk(x*2,l,mid,l1,r1))%mo; if (r1>mid) pk=(pk+getpk(x*2+1,mid+1,r,l1,r1))%mo; return pk; } int main(){ freopen("dierti.in","r",stdin); freopen("dierti.out","w",stdout); n=read(); for (i=1;i<=n;i++){ fa[i]=read(); if (fa[i]==0) root=i; else insert_edge(fa[i],i); } top=1;stk[1][0]=root;stk[1][1]=q[root]; while (top){ x=stk[top][0]; if (stk[top][1]==q[x]){ f[x]=g[x]=1e9; dfn[x]=++num; } if (!stk[top][1]){ g[fa[x]]=min(g[fa[x]],min(f[x],x)); if (g[fa[x]]<f[fa[x]]) swap(f[fa[x]],g[fa[x]]); son[x]=num; top--; continue; } for (i=stk[top][1];i;i=e[i].next){ y=e[i].to; stk[top][1]=e[i].next; stk[++top][0]=y;stk[top][1]=q[y]; break; } } for (i=1;i<=n;i++) dp[i].time=min(i,g[i]),dp[i].val=i; sort(dp+1,dp+n+1,cmp); j=0; ans=1; for (i=1;i<=n;i++){ change(1,1,n,dfn[i],dfn[i]); while (j+1<=n&&dp[j+1].time<=i){ j++; change(1,1,n,dfn[dp[j].val]+1,son[dp[j].val]); sum=(sum+getpk(1,1,n,dfn[dp[j].val]+1,son[dp[j].val]))%mo; } sum=(sum+gets(1,1,n,dfn[i]))%mo; changepk(1,1,n,dfn[i]); ans=ans*sum%mo; } printf("%lld\n",ans); return 0; }