熟練剖分(tree) 樹形DP
阿新 • • 發佈:2020-09-04
熟練剖分(tree) 樹形DP
題目描述
分析
我們設\(f[i][j]\)為以\(i\)為根節點的子樹中最壞時間複雜度小於等於\(j\)的概率
設\(g[i][j]\)為當前掃到的以\(i\)為父親節點的所有兒子最壞時間複雜度小於等於\(j\)的概率之和
因為每遍歷到一個新的節點,原來的\(g\)陣列中的值就要全部更新,因此我們壓掉第一維
下面我們考慮轉移
對於當前列舉到的某一個節點,我們用三重迴圈分別掃一邊
第一重迴圈代表當前哪一個節點充當重兒子,第二重迴圈列舉所有兒子,第三充迴圈列舉最壞時間複雜度\(k\)
如果第二重迴圈中列舉的兒子恰好是重兒子的話,那麼父親節點的最壞時間複雜度為\(k\)
第一種情況就是重兒子的時間複雜度恰好為\(k\)的概率乘上其它兒子時間複雜度小於等於\(k\)的概率
第二種情況就是其它兒子的時間複雜度恰好為\(k\)的概率乘上重兒子的時間複雜度小於等於\(k\)的概率
不要忘了減去重複的情況
如果第二重迴圈中列舉的兒子不是重兒子的話,那麼父親節點的最壞時間複雜度為\(k\)的情況可以由兩種情況轉移過來
第一種情況就是重兒子的時間複雜度恰好為\(k-1\)的概率乘上其它兒子時間複雜度小於等於\(k\)的概率
第二種情況就是其它兒子的時間複雜度恰好為\(k\)的概率乘上重兒子的時間複雜度小於等於\(k-1\)的概率
也不要忘了減去重複的情況
程式碼
#include<cstdio> #include<cstring> #include<vector> inline int read(){ int x=0,fh=1; char ch=getchar(); while(ch<'0' || ch>'9'){ if(ch=='-') fh=-1; ch=getchar(); } while(ch>='0' && ch<='9'){ x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x*fh; } const int maxn=3e3+5; const int mod=1e9+7; int fa[maxn],head[maxn],tot=1,n,rt; struct asd{ int to,next; }b[maxn<<1]; void ad(int aa,int bb){ b[tot].to=bb; b[tot].next=head[aa]; head[aa]=tot++; } int ksm(int ds,int zs){ int ans=1; while(zs){ if(zs&1) ans=1LL*ans*ds%mod; ds=1LL*ds*ds%mod; zs>>=1; } return ans; } int son[maxn],siz[maxn]; long long f[maxn][maxn],g[maxn],h[maxn]; void dfs(int now){ siz[now]=1; for(int i=head[now];i!=-1;i=b[i].next){ int u=b[i].to; if(u==fa[now]) continue; dfs(u); siz[now]+=siz[u]; } int p=ksm(son[now],mod-2); for(int i=head[now];i!=-1;i=b[i].next){ if(b[i].to==fa[now]) continue; for(int j=0;j<=n;j++) g[j]=1; //初始化g陣列 int zez=b[i].to; //列舉重兒子 for(int j=head[now];j!=-1;j=b[j].next){ if(b[j].to==fa[now]) continue; int qez=b[j].to; //列舉其它兒子 for(int k=0;k<=siz[qez]+1;k++){ //列舉最大時間複雜度 long long qt=g[k],xz=f[qez][k]; if(k) qt-=g[k-1],xz-=f[qez][k-1]; if(zez==qez){ h[k]=(qt*f[qez][k]%mod+xz*g[k]%mod-xz*qt%mod+mod)%mod; } else { xz=f[qez][k-1]; if(k>1) xz-=f[qez][k-2]; h[k]=(qt*f[qez][k-1]%mod+xz*g[k]%mod-xz*qt%mod+mod)%mod; } } g[0]=h[0],h[0]=0; for(int k=1;k<=siz[qez]+1;k++){ g[k]=(g[k-1]+h[k])%mod; h[k]=0; } //h陣列臨時儲存狀態 } for(int j=siz[now];j>=1;j--){ g[j]=(g[j]-g[j-1]+mod)%mod; //將字首和陣列還原成正常陣列 } for(int j=0;j<=siz[now];j++){ f[now][j]=(f[now][j]+g[j]*p%mod)%mod; } } if(son[now]==0) f[now][0]=1; for(int i=1;i<=siz[now]+1;i++){ f[now][i]=(f[now][i]+f[now][i-1])%mod; } } int main(){ memset(head,-1,sizeof(head)); n=read(); int aa; for(int i=1;i<=n;i++){ son[i]=read(); for(int j=1;j<=son[i];j++){ aa=read(),fa[aa]=i; ad(i,aa),ad(aa,i); } } rt=1; while(fa[rt]) rt=fa[rt]; dfs(rt); long long ans=0; for(int i=1;i<=n;i++){ ans=(ans+i*(f[rt][i]-f[rt][i-1]+mod)%mod)%mod; } printf("%lld\n",ans); return 0; }