1. 程式人生 > 其它 >CF1613F - Tree Coloring 題解

CF1613F - Tree Coloring 題解

NTT

大家好,這裡是一個不會 NTT 的菜雞在 xjbbb(怎麼說話呢,罵兩個老師高興啊?)。NTT 的板子都是網上剽的(需要注意的是,NTT 需要將度數變成 \(2\) 的整次冪,但是 INTT 之後一定要 resize 回 \(\deg a + \deg b - 1\),不然可能會指數級增長)。

樸素的方法

前面的部分感覺並不難想。考慮正難則反,計算至少有一個點 \(x\) 使得 \(a_x = a_{fa_x} + 1\)。考慮容斥,欽定某些邊滿足,形成若干條直鏈(一個點必然不會有兩條通向兒子的邊滿足),每條直鏈分配的 \(a\) 值事連續的區間。那麼方案數就是 \(c!\),其中 \(c\)

事直鏈數量,因為這些完整的區間的排列順序唯一決定了它們分配的 \(a\) 值。而顯然 \(c = n - i\),其中 \(i\) 事選的邊的數量。

於是現在就是要求對於每個 \(i\),選出 \(i\) 條邊,滿足每個節點最多有一條通向兒子的邊被選,的方案數。容易發現,這其實就是 \(b_x = |son_x|\) 中選 \(i\) 的權值積的和。這貌似是分治 NTT 經典問題(?)。對每個 \(b_j\),其選或不選的生成函式為 \(1 + b_jx\),最後答案就是 \(\prod(1 + b_jx)\) 的各項係數。這玩意可以 cdq 分治 + NTT,複雜度 2log。(當然也可以任意順序啟發式合併 + NTT)

code
constexpr int N = 1e6 + 10;

fc_init(N);

int n;
int deg[N];

struct poly : vi {
	using vi::vi;
	static constexpr int g = 3;
	void NTT(int f) {
		poly &a = *this;
		int lim = 0;
		while((1 << lim) < a.size()) ++lim;
		a.resize(1 << lim, 0);
		static int R[N];
		REP(i, 0, a.size() - 1) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (lim - 1));
		REP(i, 0, a.size() - 1) if(i < R[i]) ::swap(a[i], a[R[i]]);
		for(int i = 1; i < a.size(); i <<= 1) {
			int gn = qpow(g, (mod - 1) / (i << 1));
			for(int j = 0; j < a.size(); j += i << 1){
				int G = 1;
				for(int k = 0; k < i; ++k, G = (ll)G * gn % mod) {
					int x = a[j + k], y = (ll)G * a[j + k + i] % mod;
					a[j + k] = add(x, y), a[j + k + i] = add(x, -y);
				}
			}
		}
		if(f == 1) return;
		int nv = inv(a.size()); reverse(a.begin() + 1, a.end());
		REP(i, 0, a.size() - 1) a[i] = (ll)a[i] * nv % mod;
	}
	friend poly operator*(poly x, poly y) {
		int sz = x.size() + y.size() - 1;
		x.resize(sz, 0), y.resize(sz, 0);
		x.NTT(1), y.NTT(1);
		REP(i, 0, x.size() - 1) x[i] = (ll)x[i] * y[i] % mod;
		x.NTT(-1);
		x.resize(sz);
		return x;
	}
	void prt() {
		for(int x : *this) cout << x << " "; puts("!");
	}
};

poly cdq(int l = 1, int r = n) {
	if(l == r) return {1, deg[l]};
	int mid = l + r >> 1;
	return cdq(l, mid) * cdq(mid + 1, r);
}

void mian() {
	n = read();
	memset(deg, -1, sizeof(deg)); deg[1] = 0;
	REP(i, 1, n - 1) { int x = read(), y = read(); ++deg[x], ++deg[y]; }
	poly p = cdq();
	int ans = 0;
	REP(i, 0, n - 1) addto(ans, (i & 1 ? -1ll : 1ll) * fc[n - i] * p[i] % mod);
	prt(ans), pc('\n');
}

1log 方法

來自 EI 的方法。

注意到一件事情:\(\sum b_j = \mathrm O(n)\)。將所有 \(b_j\) 放到桶裡去,對每個桶事 \((1 + jx)^{b_j}\),可以直接用二項式定理線性展開。然後再一路暴力 NTT 的話,複雜度顯然是 \(\mathrm O\!\left(\sum\limits_{i=1}^n\sum\limits_{j = 1}^iC_j\log n\right)=\mathrm O\!\left(\sum\limits_{i = 1}^n (n-i+1)C_i\log n\right)\),其中 \(C_i\) 事桶 \(i\) 的大小。事實上我們可以重新指定順序,讓 \(n-i+1\) 這個看上去事 \(\mathrm O(n)\) 的東西發揮作用。我們發現必然有 \(\sum\limits_{i=1}^niC_i=\mathrm O(n)\),如果我們使得 \(n-i\) 變成 \(i\) 的話,複雜度就是 \(\mathrm O(n\log n)\) 了,而這隻需要倒過來暴力 NTT 即可。

code
constexpr int N = 1e6 + 10;

fc_init(N);

int n;
int deg[N];

struct poly : vi {
	using vi::vi;
	static constexpr int g = 3;
	void NTT(int f) {
		poly &a = *this;
		int lim = 0;
		while((1 << lim) < a.size()) ++lim;
		if((1 << lim) > N) exit(114514);
		a.resize(1 << lim, 0);
		static int R[N];
		REP(i, 0, a.size() - 1) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (lim - 1));
		REP(i, 0, a.size() - 1) if(i < R[i]) ::swap(a[i], a[R[i]]);
		for(int i = 1; i < a.size(); i <<= 1) {
			int gn = qpow(g, (mod - 1) / (i << 1));
			for(int j = 0; j < a.size(); j += i << 1){
				int G = 1;
				for(int k = 0; k < i; ++k, G = (ll)G * gn % mod) {
					int x = a[j + k], y = (ll)G * a[j + k + i] % mod;
					a[j + k] = add(x, y), a[j + k + i] = add(x, -y);
				}
			}
		}
		if(f == 1) return;
		int nv = inv(a.size()); reverse(a.begin() + 1, a.end());
		REP(i, 0, a.size() - 1) a[i] = (ll)a[i] * nv % mod;
	}
	friend poly operator*(poly x, poly y) {
		int sz = x.size() + y.size() - 1;
		x.resize(sz, 0), y.resize(sz, 0);
		x.NTT(1), y.NTT(1);
		REP(i, 0, x.size() - 1) x[i] = (ll)x[i] * y[i] % mod;
		x.NTT(-1);
		x.resize(sz);
		return x;
	}
	void prt() {
		for(int x : *this) cout << x << " "; puts("!");
	}
};

int cnt[N];

void mian() {
	n = read();
	memset(deg, -1, sizeof(deg)); deg[1] = 0;
	REP(i, 1, n - 1) { int x = read(), y = read(); ++deg[x], ++deg[y]; }
	poly p = {1};
	REP(i, 1, n) ++cnt[deg[i]];
	PER(i, n, 1) {
		poly q(cnt[i] + 1);
		int now = 1;
		REP(j, 0, cnt[i]) q[j] = (ll)now * comb(cnt[i], j) % mod, now = (ll)now * i % mod;
		p = p * q;
	}
	int ans = 0;
	REP(i, 0, min(n - 1, int(p.size()) - 1)) addto(ans, (i & 1 ? -1ll : 1ll) * fc[n - i] * p[i] % mod);
	prt(ans), pc('\n');
}

帶根號的方法

跟據 \(\sum\limits_{i = 1}^n iC_i = \mathrm O(n)\) 還可以得到一個結論:\(C_i\) 有值的桶只有 \(\mathrm O(\sqrt n)\) 個,根號分治分類討論即可證明(和固定自然想到根號分治)。

眾所周知,在要相乘的多項式比較少的時候,可以都 NTT,一起乘起來,最後只要一遍 INTT。但是代價是每個多項式的 NTT 規模要是所有多項式的和,一般在多項式比較多的時候就不划算了。

這題只有 \(\mathrm O(\sqrt n)\) 個多項式,比較少,可以考慮這個 trick。那麼每個多項式都要做規模為 \(n\) 的 NTT,而每個多項式都是 \((1+jx)^{b_j}\) 的形式,它的點值是好求的,就求出二項式點值然後快速冪即可。這樣複雜度是 \(\mathrm O(n\sqrt n\log n)\),雖然跟暴力 NTT 複雜度一樣,但是常數不要小太多!所以卡一卡實際上是可以過去的。

以及你會發現這玩意跑不滿。NTT 規模實際上可以到 \(n - C_0\),這看起來沒用,但你會發現這對 \(1\sim\sqrt n\) 都取滿了的情況很友好,於是決定分析一波這玩意的最劣情況。設他有 \(x\) 個位置有值,此時要使得 \(C_0\) 最小,最優的是 \(2\sim x\) 各放一個,剩下放 \(1\),那麼 \(C_0=\mathrm O\!\left(x^2\right)\),那麼複雜度就是 \(\mathrm O\!\left((nx-x^3)\log n\right)\)。對 \(nx-x^3\) 求導,可知 \(x=\sqrt{\dfrac n3}\) 的時候最劣,算一下發現雀食優了不少。

code
constexpr int N = 1e6 + 10;

fc_init(N);

int n;
int deg[N];

struct poly : vi {
	using vi::vi;
	static constexpr int g = 3;
	void NTT(int f) {
		poly &a = *this;
		int lim = 0;
		while((1 << lim) < a.size()) ++lim;
		if((1 << lim) > N) exit(114514);
		a.resize(1 << lim, 0);
		static int R[N];
		REP(i, 0, a.size() - 1) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (lim - 1));
		REP(i, 0, a.size() - 1) if(i < R[i]) ::swap(a[i], a[R[i]]);
		for(int i = 1; i < a.size(); i <<= 1) {
			int gn = qpow(g, (mod - 1) / (i << 1));
			for(int j = 0; j < a.size(); j += i << 1){
				int G = 1;
				for(int k = 0; k < i; ++k, G = (ll)G * gn % mod) {
					int x = a[j + k], y = (ll)G * a[j + k + i] % mod;
					a[j + k] = add(x, y), a[j + k + i] = add(x, -y);
				}
			}
		}
		if(f == 1) return;
		int nv = inv(a.size()); reverse(a.begin() + 1, a.end());
		REP(i, 0, a.size() - 1) a[i] = (ll)a[i] * nv % mod;
	}
	friend poly operator*(poly x, poly y) {
		int sz = x.size() + y.size() - 1;
		x.resize(sz, 0), y.resize(sz, 0);
		x.NTT(1), y.NTT(1);
		REP(i, 0, x.size() - 1) x[i] = (ll)x[i] * y[i] % mod;
		x.NTT(-1);
		x.resize(sz);
		return x;
	}
	void prt() {
		for(int x : *this) cout << x << " "; puts("!");
	}
};

int cnt[N];

void mian() {
	n = read();
	memset(deg, -1, sizeof(deg)); deg[1] = 0;
	REP(i, 1, n - 1) { int x = read(), y = read(); ++deg[x], ++deg[y]; }
	REP(i, 1, n) ++cnt[deg[i]];
	int lim = 0;
	while((1 << lim) < n - cnt[0] + 1) ++lim;
	poly p(1 << lim, 1);
	PER(i, n, 1) if(cnt[i]) {
		int G = qpow(3, (mod - 1) >> lim), gn = 1;
		REP(j, 0, (1 << lim) - 1) p[j] = (ll)p[j] * qpow(((ll)i * gn + 1) % mod, cnt[i]) % mod, gn = (ll)gn * G % mod;
	}
	p.NTT(-1);
	int ans = 0;
	REP(i, 0, min(n - 1, int(p.size()) - 1)) addto(ans, (i & 1 ? -1ll : 1ll) * fc[n - i] * p[i] % mod);
	prt(ans), pc('\n');
}
珍愛生命,遠離抄襲!