1. 程式人生 > 實用技巧 >codeforces-102501J Counting Trees題解

codeforces-102501J Counting Trees題解

題意:給你一個二叉樹的中序遍歷$(n<=1000000)$,節點值為$1$~$1000000$的整數,可重複。問有多少種樹的形態滿足每個子樹的根節點的權值小於等於該子樹所有節點的權值。模$1e9+7$。

比較顯然的一點是如果n個數各不相同,那麼答案是1。每次我們找出當前區間的最小的數,把他作為根節點,然後將左右兩個區間遞迴即可。

如果n個數都相同,那麼答案就是第n個斯特林數。

斯特林數的遞推式為$C_{n+1}=C_0C_n+C_1C_{n-1}+……+C_nC_0$

正好對應了列舉第幾個數作為根節點的方案數。

經過仔細分析,假設當前區間為$(l,r)$,區間$(l,r)$的答案是$ans(l,r)$,最小值為$x$,$x$出現的位置為$x_1,x_2,x_3……,x_cnt$,可以發現,$ans(l,r)={ans}(l,x_1-1)*{ans}(x_1+1,x_2-1)*……*ans (x_{cnt-1}+1,x_{cnt}-1)*ans(x_cnt+1,r)*C_{cnt}$

所以,我們對區間1~n進行dfs即可。由於每個數最多會造成對兩個區間的dfs,所以預處理斯特林數,構建st表以O(1)求區間最小值之後複雜度為$O(n)$。

具體實現方法為用$vector[x]$儲存所有$x$出現的位置。由於我們的區間是從左到右進行dfs的,所以在計算第 $i$個$x$出現的位置時前$i-1$個$x$已經處理完了,我們設定一個$pos[x]$表示當前該處理到第幾個$x$就可以了。

  1 #include <bits/stdc++.h>
  2 #include<map>
  3 #define N 2000005
  4 using namespace std;
5 int n,A[N/2],zz,B[N/2]; 6 const int p=1e9+7; 7 map<int,int> ma; 8 int st[N/2][22],xp[N/2]; 9 int pos[N/2]; 10 long long jc[N],ni[N]; 11 vector<int> q1[N/2]; 12 long long ksm(long long x,long long z) 13 { 14 long long ans=1; 15 while(z) 16 { 17 if(z&1) ans=ans*x%p;
18 x=x*x%p; 19 z>>=1; 20 } 21 return ans; 22 } 23 long long get_C(int x,int y) 24 { 25 if(y>x)return 0; 26 return jc[x]*ni[y]%p*ni[x-y]%p; 27 } 28 long long get_E(int x) 29 { 30 return get_C(x*2,x)*ni[x+1]%p*jc[x]%p; 31 } 32 int get_mn(int L,int R) 33 { 34 int l=xp[R-L+1]; 35 return min(st[L][l],st[R-(1<<l)+1][l]); 36 } 37 long long dfs(int L,int R) 38 { 39 if(L>=R)return 1; 40 int x=get_mn(L,R); 41 long long ans=1; 42 int len=q1[x].size(); 43 int cnt=0,La=L,nw=pos[x]; 44 while(nw<len&&q1[x][nw]<L) nw++; 45 46 while(nw<len&&q1[x][nw]<=R) 47 { 48 cnt++; 49 ans=ans*dfs(La,q1[x][nw]-1)%p; 50 La=q1[x][nw]+1; 51 nw++; 52 } 53 nw--; 54 if(q1[x][nw]<=R&&q1[x][nw]>=L) 55 { 56 ans=ans*dfs(q1[x][nw]+1,R)%p; 57 } 58 pos[x]=nw+1; 59 return ans*get_E(cnt)%p; 60 } 61 int main(){ 62 // freopen("test.in","r",stdin); 63 // freopen("1.out","w",stdout); 64 scanf("%d",&n); 65 if(!n) 66 { 67 printf("1\n"); 68 return 0; 69 } 70 jc[0]=1,ni[0]=1; 71 for(int i=1;i<=n*2;i++) jc[i]=jc[i-1]*i%p; 72 ni[n*2]=ksm(jc[n*2],p-2); 73 for(int i=n*2-1;i;i--) ni[i]=ni[i+1]*(i+1)%p; 74 for(int i=1;i<=n;i++) 75 { 76 scanf("%d",&A[i]); 77 if(!ma[A[i]]) 78 { 79 zz++; 80 ma[A[i]]=zz; 81 B[zz]=A[i]; 82 } 83 } 84 sort(B+1,B+zz+1); 85 xp[1]=0; 86 for(int i=2;i<=n;i++) xp[i]=xp[i>>1]+1; 87 for(int i=1;i<=zz;i++) ma[B[i]]=i; 88 for(int i=1;i<=n;i++) A[i]=ma[A[i]]; 89 for(int i=1;i<=n;i++) st[i][0]=A[i],q1[A[i]].push_back(i); 90 for(int i=1;i<=20;i++) 91 { 92 for(int j=1;j<=n;j++) 93 { 94 if((j+(1<<i)-1)>n)break; 95 st[j][i]=min(st[j][i-1],st[j+(1<<(i-1))][i-1]); 96 } 97 } 98 printf("%lld\n",dfs(1,n)); 99 return 0; 100 } 101 /* 102 103 10 104 1 2 3 1 3 2 3 1 2 2 105 6 106 3 1 6 2 4 5 107 */
View Code