CF1613F - Tree Coloring 題解
大家好,這裡是一個不會 NTT 的菜雞在 xjbbb(怎麼說話呢,罵兩個老師高興啊?)。NTT 的板子都是網上剽的(需要注意的是,NTT 需要將度數變成 \(2\) 的整次冪,但是 INTT 之後一定要 resize 回 \(\deg a + \deg b - 1\),不然可能會指數級增長)。
樸素的方法
前面的部分感覺並不難想。考慮正難則反,計算至少有一個點 \(x\) 使得 \(a_x = a_{fa_x} + 1\)。考慮容斥,欽定某些邊滿足,形成若干條直鏈(一個點必然不會有兩條通向兒子的邊滿足),每條直鏈分配的 \(a\) 值事連續的區間。那麼方案數就是 \(c!\),其中 \(c\)
於是現在就是要求對於每個 \(i\),選出 \(i\) 條邊,滿足每個節點最多有一條通向兒子的邊被選,的方案數。容易發現,這其實就是 \(b_x = |son_x|\) 中選 \(i\) 的權值積的和。這貌似是分治 NTT 經典問題(?)。對每個 \(b_j\),其選或不選的生成函式為 \(1 + b_jx\),最後答案就是 \(\prod(1 + b_jx)\) 的各項係數。這玩意可以 cdq 分治 + NTT,複雜度 2log。(當然也可以任意順序啟發式合併 + NTT)
code
constexpr int N = 1e6 + 10; fc_init(N); int n; int deg[N]; struct poly : vi { using vi::vi; static constexpr int g = 3; void NTT(int f) { poly &a = *this; int lim = 0; while((1 << lim) < a.size()) ++lim; a.resize(1 << lim, 0); static int R[N]; REP(i, 0, a.size() - 1) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (lim - 1)); REP(i, 0, a.size() - 1) if(i < R[i]) ::swap(a[i], a[R[i]]); for(int i = 1; i < a.size(); i <<= 1) { int gn = qpow(g, (mod - 1) / (i << 1)); for(int j = 0; j < a.size(); j += i << 1){ int G = 1; for(int k = 0; k < i; ++k, G = (ll)G * gn % mod) { int x = a[j + k], y = (ll)G * a[j + k + i] % mod; a[j + k] = add(x, y), a[j + k + i] = add(x, -y); } } } if(f == 1) return; int nv = inv(a.size()); reverse(a.begin() + 1, a.end()); REP(i, 0, a.size() - 1) a[i] = (ll)a[i] * nv % mod; } friend poly operator*(poly x, poly y) { int sz = x.size() + y.size() - 1; x.resize(sz, 0), y.resize(sz, 0); x.NTT(1), y.NTT(1); REP(i, 0, x.size() - 1) x[i] = (ll)x[i] * y[i] % mod; x.NTT(-1); x.resize(sz); return x; } void prt() { for(int x : *this) cout << x << " "; puts("!"); } }; poly cdq(int l = 1, int r = n) { if(l == r) return {1, deg[l]}; int mid = l + r >> 1; return cdq(l, mid) * cdq(mid + 1, r); } void mian() { n = read(); memset(deg, -1, sizeof(deg)); deg[1] = 0; REP(i, 1, n - 1) { int x = read(), y = read(); ++deg[x], ++deg[y]; } poly p = cdq(); int ans = 0; REP(i, 0, n - 1) addto(ans, (i & 1 ? -1ll : 1ll) * fc[n - i] * p[i] % mod); prt(ans), pc('\n'); }
1log 方法
來自 EI 的方法。
注意到一件事情:\(\sum b_j = \mathrm O(n)\)。將所有 \(b_j\) 放到桶裡去,對每個桶事 \((1 + jx)^{b_j}\),可以直接用二項式定理線性展開。然後再一路暴力 NTT 的話,複雜度顯然是 \(\mathrm O\!\left(\sum\limits_{i=1}^n\sum\limits_{j = 1}^iC_j\log n\right)=\mathrm O\!\left(\sum\limits_{i = 1}^n (n-i+1)C_i\log n\right)\),其中 \(C_i\) 事桶 \(i\) 的大小。事實上我們可以重新指定順序,讓 \(n-i+1\) 這個看上去事 \(\mathrm O(n)\) 的東西發揮作用。我們發現必然有 \(\sum\limits_{i=1}^niC_i=\mathrm O(n)\),如果我們使得 \(n-i\) 變成 \(i\) 的話,複雜度就是 \(\mathrm O(n\log n)\) 了,而這隻需要倒過來暴力 NTT 即可。
code
constexpr int N = 1e6 + 10;
fc_init(N);
int n;
int deg[N];
struct poly : vi {
using vi::vi;
static constexpr int g = 3;
void NTT(int f) {
poly &a = *this;
int lim = 0;
while((1 << lim) < a.size()) ++lim;
if((1 << lim) > N) exit(114514);
a.resize(1 << lim, 0);
static int R[N];
REP(i, 0, a.size() - 1) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (lim - 1));
REP(i, 0, a.size() - 1) if(i < R[i]) ::swap(a[i], a[R[i]]);
for(int i = 1; i < a.size(); i <<= 1) {
int gn = qpow(g, (mod - 1) / (i << 1));
for(int j = 0; j < a.size(); j += i << 1){
int G = 1;
for(int k = 0; k < i; ++k, G = (ll)G * gn % mod) {
int x = a[j + k], y = (ll)G * a[j + k + i] % mod;
a[j + k] = add(x, y), a[j + k + i] = add(x, -y);
}
}
}
if(f == 1) return;
int nv = inv(a.size()); reverse(a.begin() + 1, a.end());
REP(i, 0, a.size() - 1) a[i] = (ll)a[i] * nv % mod;
}
friend poly operator*(poly x, poly y) {
int sz = x.size() + y.size() - 1;
x.resize(sz, 0), y.resize(sz, 0);
x.NTT(1), y.NTT(1);
REP(i, 0, x.size() - 1) x[i] = (ll)x[i] * y[i] % mod;
x.NTT(-1);
x.resize(sz);
return x;
}
void prt() {
for(int x : *this) cout << x << " "; puts("!");
}
};
int cnt[N];
void mian() {
n = read();
memset(deg, -1, sizeof(deg)); deg[1] = 0;
REP(i, 1, n - 1) { int x = read(), y = read(); ++deg[x], ++deg[y]; }
poly p = {1};
REP(i, 1, n) ++cnt[deg[i]];
PER(i, n, 1) {
poly q(cnt[i] + 1);
int now = 1;
REP(j, 0, cnt[i]) q[j] = (ll)now * comb(cnt[i], j) % mod, now = (ll)now * i % mod;
p = p * q;
}
int ans = 0;
REP(i, 0, min(n - 1, int(p.size()) - 1)) addto(ans, (i & 1 ? -1ll : 1ll) * fc[n - i] * p[i] % mod);
prt(ans), pc('\n');
}
帶根號的方法
跟據 \(\sum\limits_{i = 1}^n iC_i = \mathrm O(n)\) 還可以得到一個結論:\(C_i\) 有值的桶只有 \(\mathrm O(\sqrt n)\) 個,根號分治分類討論即可證明(和固定自然想到根號分治)。
眾所周知,在要相乘的多項式比較少的時候,可以都 NTT,一起乘起來,最後只要一遍 INTT。但是代價是每個多項式的 NTT 規模要是所有多項式的和,一般在多項式比較多的時候就不划算了。
這題只有 \(\mathrm O(\sqrt n)\) 個多項式,比較少,可以考慮這個 trick。那麼每個多項式都要做規模為 \(n\) 的 NTT,而每個多項式都是 \((1+jx)^{b_j}\) 的形式,它的點值是好求的,就求出二項式點值然後快速冪即可。這樣複雜度是 \(\mathrm O(n\sqrt n\log n)\),雖然跟暴力 NTT 複雜度一樣,但是常數不要小太多!所以卡一卡實際上是可以過去的。
以及你會發現這玩意跑不滿。NTT 規模實際上可以到 \(n - C_0\),這看起來沒用,但你會發現這對 \(1\sim\sqrt n\) 都取滿了的情況很友好,於是決定分析一波這玩意的最劣情況。設他有 \(x\) 個位置有值,此時要使得 \(C_0\) 最小,最優的是 \(2\sim x\) 各放一個,剩下放 \(1\),那麼 \(C_0=\mathrm O\!\left(x^2\right)\),那麼複雜度就是 \(\mathrm O\!\left((nx-x^3)\log n\right)\)。對 \(nx-x^3\) 求導,可知 \(x=\sqrt{\dfrac n3}\) 的時候最劣,算一下發現雀食優了不少。
code
constexpr int N = 1e6 + 10;
fc_init(N);
int n;
int deg[N];
struct poly : vi {
using vi::vi;
static constexpr int g = 3;
void NTT(int f) {
poly &a = *this;
int lim = 0;
while((1 << lim) < a.size()) ++lim;
if((1 << lim) > N) exit(114514);
a.resize(1 << lim, 0);
static int R[N];
REP(i, 0, a.size() - 1) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (lim - 1));
REP(i, 0, a.size() - 1) if(i < R[i]) ::swap(a[i], a[R[i]]);
for(int i = 1; i < a.size(); i <<= 1) {
int gn = qpow(g, (mod - 1) / (i << 1));
for(int j = 0; j < a.size(); j += i << 1){
int G = 1;
for(int k = 0; k < i; ++k, G = (ll)G * gn % mod) {
int x = a[j + k], y = (ll)G * a[j + k + i] % mod;
a[j + k] = add(x, y), a[j + k + i] = add(x, -y);
}
}
}
if(f == 1) return;
int nv = inv(a.size()); reverse(a.begin() + 1, a.end());
REP(i, 0, a.size() - 1) a[i] = (ll)a[i] * nv % mod;
}
friend poly operator*(poly x, poly y) {
int sz = x.size() + y.size() - 1;
x.resize(sz, 0), y.resize(sz, 0);
x.NTT(1), y.NTT(1);
REP(i, 0, x.size() - 1) x[i] = (ll)x[i] * y[i] % mod;
x.NTT(-1);
x.resize(sz);
return x;
}
void prt() {
for(int x : *this) cout << x << " "; puts("!");
}
};
int cnt[N];
void mian() {
n = read();
memset(deg, -1, sizeof(deg)); deg[1] = 0;
REP(i, 1, n - 1) { int x = read(), y = read(); ++deg[x], ++deg[y]; }
REP(i, 1, n) ++cnt[deg[i]];
int lim = 0;
while((1 << lim) < n - cnt[0] + 1) ++lim;
poly p(1 << lim, 1);
PER(i, n, 1) if(cnt[i]) {
int G = qpow(3, (mod - 1) >> lim), gn = 1;
REP(j, 0, (1 << lim) - 1) p[j] = (ll)p[j] * qpow(((ll)i * gn + 1) % mod, cnt[i]) % mod, gn = (ll)gn * G % mod;
}
p.NTT(-1);
int ans = 0;
REP(i, 0, min(n - 1, int(p.size()) - 1)) addto(ans, (i & 1 ? -1ll : 1ll) * fc[n - i] * p[i] % mod);
prt(ans), pc('\n');
}