1. 程式人生 > 實用技巧 >P4827「國家集訓隊」 Crash 的文明世界

P4827「國家集訓隊」 Crash 的文明世界

「國家集訓隊」 Crash 的文明世界

提供一種不需要腦子的方法。

其實是看洛谷討論版看出來的(

(但是全網也就這一篇這個方法的題解了)

首先這是一個關於樹上路徑的問題,我們可以無腦上點分治。

考慮當以 \(root\) 為根時,如何計算經過 \(root\) 的路徑對某一個點的貢獻。

若現在我們要找經過 \(root\) 的路徑中長度為 \(d\) 且路徑的一端為 \(u\)

則這一部分的貢獻為 \(v_{d}cnt_{d-h_u}\),其中 \(v_d=d^k\)\(h_u\) 表示點 \(u\) 的深度,\(cnt_i\) 表示深度為 \(i\) 的節點個數。

當然這裡會有一種不合法的情況,就是找到的路徑兩端點在 \(root\)

的同一棵子樹中。這可以用點分治慣用的容斥解決。

\(root\) 為根時,路徑對點 \(u\) 的貢獻為(事實上對深度為 \(h_u\) 的節點貢獻是相同的)

\[\sum_{d=h_u}^{maxdeep+h_u}v_dcnt_{d-h_u}\\ \]

為了處理起來更加方便,我們增加一些無用的部分

於是有

\[\sum_{d=0}^{2\times maxdeep}v_dcnt_{d-h_u}\\ \]

\(n=2\times maxdeep\)

\[\sum_{d=0}^{n}v_dcnt_{d-h_u}\\ \]

按照套路,將 \(cnt\) 陣列翻轉一下

\[\sum_{d=0}^{n}v_dcnt_{n-d+h_u}\\ \]

\[Ans_{n+h_u}=\sum_{d=0}^{n}v_dcnt_{n-d+h_u}\\ \]

這是一個卷積的形式,直接 \(\texttt{FFT/NTT}\) 即可。

所以總時間複雜度為 \(O(n\log_2n\log_2k)\)

(所以為啥不把這題的 k 開到和 n 同級呢)

下面講講常數優化:

  • 預處理原根、單位根必不可少。
  • 能不取模儘量別取模。
  • 由於這也是在分治的過程中進行 \(\texttt{FFT}\) 的計算,所以當規模較小時暴力會更快。

另外值得注意的是,由於本題的模數不是一個 \(\texttt{NTT}\) 模數,而中間過程中的結果最大可能為 \(10006^2>998244353\)

,所以我們可能得換一個 \(\texttt{NTT}\) 模數?(我也不知道需不需要,求大佬指點。事實上 998244353 也過了

這樣的話結果就一定不會有問題。

這個題就這樣非常套路地被我們解決了。

貼一個很醜的程式碼:

#include<bits/stdc++.h>
using namespace std;
const int maxn=2e5+5;
const int p=1e4+7;
const int P=1004535809;
int n,k;
struct edge{
	int to,nex;
}e[maxn<<1];
int head[maxn],tot;
int siz[maxn],dp[maxn],vis[maxn],rt;
int w[maxn],cnt[maxn],ans[maxn];
int f[maxn],g[maxn],rev[maxn],len=1;
void add(int a,int b){
	e[++tot]=(edge){b,head[a]};
	head[a]=tot;
}
int ksm(int a,int b,int p){
	int ans=1;
	while(b){
		if(b&1) ans=1ll*ans*a%p;
		b>>=1,a=1ll*a*a%p;
	}
	return ans;
}
vector<int> W[20];
void INIT(){
	for(int i=1,num=0;num<=17;++num,i<<=1){
		int w=ksm(3,(P-1)/(i<<1),P),tmp=1;
		for(int k=0;k<i;++k)
			W[num].emplace_back(tmp),tmp=1ll*tmp*w%P;
	}
}
void NTT(int *f){
	for(int i=0;i<len;++i)
		if(i<rev[i]) swap(f[i],f[rev[i]]);
	for(int i=1,num=0;i<len;i<<=1,++num){
		for(int j=0;j<len;j+=(i<<1)){
			for(int k=0;k<i;++k){
				int x=f[j|k],y=1ll*W[num][k]*f[i|j|k]%P;
				f[j|k]=x+y>P?x+y-P:x+y;
				f[i|j|k]=x-y<0?x-y+P:x-y;
			}
		}
	}
}
void init(int x){
	len=1;
	while(len<=x) len<<=1;
	f[0]=g[0]=0;
	for(int i=1;i<len;++i)
		rev[i]=rev[i>>1]>>1|((i&1)?len>>1:0);
	memset(f,0,sizeof (int)*len);
	memset(g,0,sizeof (int)*len);
}
void getroot(int u,int f,int sum){
	siz[u]=1,dp[u]=0;
	for(int i=head[u];i;i=e[i].nex){
		int v=e[i].to;
		if(v==f||vis[v]) continue;
		getroot(v,u,sum);
		siz[u]+=siz[v];
		dp[u]=max(siz[v],dp[u]);
	}
	dp[u]=max(dp[u],sum-siz[u]);
	if(dp[u]<dp[rt]) rt=u;
}
void clear(int u,int f,int dis,int &mx){
	mx=max(mx,dis);
	for(int i=head[u];i;i=e[i].nex){
		int v=e[i].to;
		if(v==f||vis[v]) continue;
		clear(v,u,dis+1,mx);
	}
}
void getdis(int u,int f,int dis){
	++cnt[dis];
	if(cnt[dis]>=p) cnt[dis]-=p;
	for(int i=head[u];i;i=e[i].nex){
		int v=e[i].to;
		if(v==f||vis[v]) continue;
		getdis(v,u,dis+1);
	}
}
int owo[251];
void mul(int *a,int *b,int n){
	if(n<=100){
		memset(owo,0,sizeof (int)*(2*n+1));
		for(int i=0;i<=n;++i)
			for(int j=0;j<=n;++j)
				owo[i+j]=owo[i+j]+1ll*a[i]*b[j]%P>P?owo[i+j]+1ll*a[i]*b[j]%P-P:owo[i+j]+1ll*a[i]*b[j]%P;
		for(int i=0;i<=2*n;++i) a[i]=owo[i];
		return ;
	}
	memcpy(f,a,sizeof (int)*(n+1));
	memcpy(g,b,sizeof (int)*(n+1));
	NTT(f),NTT(g);
	for(int i=0;i<len;++i) f[i]=1ll*f[i]*g[i]%P;
	NTT(f);
	reverse(f+1,f+len);
	int inv=ksm(len,P-2,P);
	for(int i=0;i<=2*n;++i) a[i]=1ll*f[i]*inv%P;
}
void dfs(int u,int f,int dis,int opt){
	ans[u]+=opt*cnt[dis];
	for(int i=head[u];i;i=e[i].nex){
		int v=e[i].to;
		if(v==f||vis[v]) continue;
		dfs(v,u,dis+1,opt);
	}
}
void calc(int u,int dis,int opt){
	int n=0;
	clear(u,0,dis,n);n*=2;
	memset(cnt,0,sizeof (int)*(n+1));
	getdis(u,0,dis);
	reverse(cnt,cnt+n+1);
	init(2*n);
	mul(cnt,w,n);
	for(int i=0;i<=n;++i) cnt[i]=cnt[i+n];
	dfs(u,0,dis,opt);
}
void solve(int u){
	vis[u]=1;
	calc(u,0,1);
	for(int i=head[u];i;i=e[i].nex){
		int v=e[i].to;
		if(vis[v]) continue;
		calc(v,1,-1);
		rt=0;
		getroot(v,0,siz[v]);
		solve(rt);
	}
}
int main(){
//	freopen("1.in","r",stdin);
//	freopen("1.out","w",stdout);
	ios::sync_with_stdio(0);
	cin.tie(0),cout.tie(0);
	cin>>n>>k;
	INIT();
	for(int i=1;i<n;++i){
		int a,b;cin>>a>>b;
		add(a,b),add(b,a);
	}
	for(int i=0;i<n;++i) w[i]=ksm(i,k,p);
	dp[0]=(1<<30);
	getroot(1,0,n);
	solve(rt);
	for(int i=1;i<=n;++i) cout<<ans[i]%p<<'\n';
}