題解 [校內模擬賽]排列
阿新 • • 發佈:2021-07-25
給定一棵樹,求 \(1 \to n\) 的排列數量使得對於每條邊都滿足若 \(u, v\) 相連則 \(p_u, p_v\) 相連。
賽場上想著“模擬賽都沒有怎麼做出過 B,這次一定要做出來”的 必死 決心,一直做,最後做倒是做出來了,因為 陣列開小 + 雜湊被卡 \(100 \to 40\) 了 /dk
首先手玩幾個樣例可以發現,原問題等價於把每個點重新編號,使得新樹和原樹的形狀一模一樣的方案數。
那麼容易得到兩棵子樹當且僅當根是兄弟且子樹形狀一樣的時候可以交換。那麼只需要在 dfs 的時候找出有哪幾個子樹可以互換,把它們全排列一番就可以了。
判斷子樹相同?暴力可以有 \(40\),但是衝著正解的我當然是要上樹雜湊啦。
一種常用的樹雜湊方式是這樣的。
對於每個點判斷哪些兒子的 \(h\) 值一樣就好了。
等等,還沒完!
對於這樣的一副圖,可以發現把 \(1\) 提起來和把 \(5\) 提起來是完全一樣的,而把 \(5\) 提起來就相當於讓 \(5\) 代替了 \(1\) 的位置,然後下面的位置都換了換。所以還得把上面算出來的答案再乘 \(2\)。
如果你對無根樹的雜湊“敏感度”很高的話,那麼會知道這其實就是重心,但是我對無根樹沒有一點敏感度,所以我選擇做一遍換根 dp 求出以哪些點為根的雜湊值相同。
最後注意了,為了防止雜湊被卡,請記得打亂你的 \(prime\) 陣列!,我沒打亂就 \(100 \to 60\) 了。
賽時求質數的時候用素數定理估計了一下篩的範圍大概是在 \(n \ln n\),但開陣列的時候沒乘 \(20\),於是 \(60 \to 40\)
程式碼照例點選展開看
#include <iostream> #include <utility> #include <algorithm> #include <map> #include <vector> #define int long long const int N = 100005, P = 998244353; int n, fac[N], prime[N*20], size[N], tot, ans = 1, x, y, mul; bool flag[N*20]; unsigned long long h1[N], f[N]; typedef std::map<unsigned long long, int> twt; std::vector<int> g[N]; twt cnt; void dfs(int u, int fa) { h1[u] = 1, size[u] = 1; twt t; for (int i = 0; i < (int)g[u].size(); i++) { int v = g[u][i]; if (v == fa) continue; dfs(v, u); size[u] += size[v]; t[h1[v]] ++; h1[u] += h1[v] * prime[size[v]]; } for (twt::iterator i = t.begin(); i != t.end(); i++) ans = ans * fac[i->second] % P; } void get(int u, int fa) { for (int i = 0; i < (int)g[u].size(); i++) { int v = g[u][i]; if (v == fa) continue; f[v] = h1[v] + (f[u] - h1[v] * prime[size[v]]) * prime[(n-size[v])]; // if (v == 5) std::cout << f[u] - h1[v] * prime[size[v]] << "$\n"; get(v, u); } } signed main() { std::cin >> n; fac[0] = 1; for (int i = 1; i <= n; i++) fac[i] = fac[i-1] * i % P; for (int i = 2; i <= 2000000; i++) { if (!flag[i]) prime[++tot] = i; for (int j = i+i; j <= 2000000; j += i) flag[j] = 1; } std::random_shuffle(prime+1, prime+tot+1); for (int i = 1; i < n; i++) { std::cin >> x >> y; g[x].push_back(y), g[y].push_back(x); } dfs(1, 0); // std::cout << h1[1] << '\n'; f[1] = h1[1]; get(1, 0); // for (int i = 1; i <= n; i++) std::cout << f[i] << '\n'; // std::cout << h1[1] << "#\n"; // for (int i = 1; i <= n; i++) mul += f[i] == f[1]; std::cout << ans * mul % P; }