1. 程式人生 > 實用技巧 >題解 HDU 5279 plays Minecraft

題解 HDU 5279 plays Minecraft

題目傳送門

題目大意

給出\(n\)以及\(a_{1,2,...,n}\),表示有\(n\)個完全圖,第\(i\)個完全圖大小為\(a_i\),這些完全圖之間第\(i\)個完全圖的點\(a_i\)\(i\bmod n+1\)的點\(1\)相連。問有多少種方法可以刪掉某些邊,使得整個圖變成一個森林。

思路

話說因為是英文懶得讀題,直接看題解裡面的題目大意,結果一直理解不了,後來才發現他意思寫錯了。。。所以說一個正確的題目大意有多重要(霧

有一個人盡皆知的知識,就是\(n\)個點的樹有\(n^{n-2}\)個。下面會用到這個東西。

我們設\(f_i\)表示\(i\)個點的完全圖刪掉一些邊變成森林的方案數。可以得到:

\[f_n=\sum_{i=0}^{n-1}\binom{n-1}{i}f_i(n-i)^{n-i-2} \]

\[=(n-1)!\sum_{i=0}^{n-1}\frac{f_i}{i!}\frac{(n-i)^{n-i-2}}{(n-i-1)!} \]

這個式子的意思就是我們可以先固定一個點,然後從\(n-1\)選出\(i\)個點單獨成森林,然後剩下的點組成一棵樹。

然後我們發現這個式子我們其實可以使用分治\(\text {NTT}\)預處理\(\Theta(n\log^2 n)\)之內求出來。

我們考慮如何統計答案。我們發現其實可以使用容斥原理求到:

\[\text {ans}=2^n\prod_{i=1}^{n} f_{a_i}-\prod_{i=1}^{n} (\sum_{j=2}^{a_i}\binom{a_i-2}{j-2}j^{j-2}f_{a_i-j}) \]

前面一個的意思就是統計所有的答案,完全圖之間管它連不連,後面一個就是整個圖構成一個大環(這其實並不準確,但是意思到位就行)。既然要構成大環,肯定一個完全圖內首尾要相連,就肯定是一棵樹(要刪邊成森林)。

然後我們發現後面那個式子可以\(\text {NTT}\)預處理,於是我們就解決了這個問題。

總時間複雜度為\(\Theta(\max\{a_i\}+Tn)\)

\(\text {Code}\)

#include <bits/stdc++.h>
using namespace std;

#define Int register int
#define mod 998244353
#define ll long long
#define MAXN 600005

int quick_pow (int a,int b,int c){
	int res = 1;for (;b;b >>= 1,a = 1ll * a * a % c) if (b & 1) res = 1ll * res * a % c;
	return res;
}

int rev[MAXN];

void NTT (int *a,int len,int type){
#define G 3
#define Gi 332748118
	int limit = 1,l = 0;
	while (limit < len) limit <<= 1,l ++;
	for (Int i = 0;i < limit;++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << l - 1);
	for (Int i = 0;i < limit;++ i) if (i < rev[i]) swap (a[i],a[rev[i]]);
	for (Int i = 1;i < limit;i <<= 1){
		int Wn = quick_pow (type == 1 ? G : Gi,(mod - 1) / (i << 1),mod);
		for (Int j = 0;j < limit;j += i << 1)
			for (Int k = 0,w = 1;k < i;++ k,w = 1ll * w * Wn % mod){
				int x = a[j + k],y = 1ll * w * a[i + j + k] % mod;
				a[j + k] = (x + y) % mod,a[i + j + k] = (x + mod - y) % mod;
			}
	}
	if (type == 1) return ;
	for (Int i = 0,Inv = quick_pow (limit,mod - 2,mod);i < limit;++ i) a[i] = 1ll * a[i] * Inv % mod;
#undef G
#undef Gi
}

void multi (int *a,int *b,int *c,int len1,int len2){
	int limit = 1;
	while (limit < len1 + len2 - 1) limit <<= 1;
	for (Int i = len1;i < limit;++ i) a[i] = 0;
	for (Int i = len2;i < limit;++ i) b[i] = 0;
	NTT (a,limit,1),NTT (b,limit,1);
	for (Int i = 0;i < limit;++ i) c[i] = 1ll * a[i] * b[i] % mod;
	NTT (c,limit,-1);
}

int f[MAXN],g[MAXN],A[MAXN],B[MAXN],C[MAXN],fac[MAXN],ifac[MAXN];

void cdq (int l,int r){
	if (l == r){
		if (l == 0) f[l] = 1;
		else f[l] = 1ll * f[l] * fac[l - 1] % mod;
		return ;
	}
	int mid = (l + r) >> 1;cdq (l,mid);
	for (Int i = l;i <= mid;++ i) A[i - l] = 1ll * f[i] * ifac[i] % mod;
	B[0] = 1;for (Int i = 2;i <= r - l;++ i) B[i - 1] = 1ll * quick_pow (i,i - 2,mod) * ifac[i - 1] % mod;
	multi (A,B,C,mid - l + 1,r - l);
	for (Int i = mid + 1;i <= r;++ i) f[i] = (f[i] + C[i - l - 1]) % mod;
	cdq (mid + 1,r);
}

template <typename T> inline void read (T &t){t = 0;char c = getchar();int f = 1;while (c < '0' || c > '9'){if (c == '-') f = -f;c = getchar();}while (c >= '0' && c <= '9'){t = (t << 3) + (t << 1) + c - '0';c = getchar();} t *= f;}
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
template <typename T> inline void write (T x){if (x < 0){x = -x;putchar ('-');}if (x > 9) write (x / 10);putchar (x % 10 + '0');}

signed main(){
#define Maxn 1<<17
#define Maxm 100005
	fac[0] = 1;for (Int i = 1;i <= Maxn;++ i) fac[i] = 1ll * fac[i - 1] * i % mod;
	ifac[Maxn] = quick_pow (fac[Maxn],mod - 2,mod);for (Int i = Maxn;i;-- i) ifac[i - 1] = 1ll * ifac[i] * i % mod;
	cdq (0,Maxm);
	for (Int i = 2;i <= Maxm;++ i) A[i - 2] = 1ll * quick_pow (i,i - 2,mod) * ifac[i - 2] % mod;
	for (Int i = 0;i <= Maxm - 2;++ i) B[i] = 1ll * f[i] * ifac[i] % mod;
	multi (A,B,C,Maxm - 1,Maxm - 1);
	g[1] = 1;for (Int i = 2;i <= Maxm;++ i) g[i] = 1ll * fac[i - 2] * C[i - 2] % mod;
	int T;read (T);
	while (T --){
		int n;read (n);int ans = quick_pow (2,n,mod),sum = 1;
		for (Int i = 1,a;i <= n;++ i) read (a),ans = 1ll * ans * f[a] % mod,sum = 1ll * sum * g[a] % mod;
		write ((ans + mod - sum) % mod),putchar ('\n');
	}
	return 0;
}