淺談生成樹計數問題,以SPOJ HIGH, BZOJ 4894, BZOJ 1016為例
給定一個包含n個節點,m條邊的無向圖,問圖中的生成樹的種類數有多少。
點我,喵=w=?
這就是一個最基本的生成樹問題,由此我們可以引出生成樹計數的矩陣樹定理。
矩陣樹定理:
一個無向圖G的生成樹的個數為其基爾霍夫矩陣的任意n-1階主子式的行列式的絕對值。
G的度數矩陣:
G的鄰接矩陣:
G的基爾霍夫矩陣
那麼我們只需要計算出行列式的值即可。
先用搞死小圓 小圓:我招誰惹誰了 ,將基爾矩陣化成上三角矩陣,那麼其行列式的值為
那麼我們這題就做完了!
#define others
#ifdef poj
#include <iostream>
#include <cstring>
#include <cmath>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <string>
#endif // poj
#ifdef others
#include <bits/stdc++.h>
#endif // others
//#define file
#define all(x) x.begin(), x.end()
using namespace std;
const double eps = 1e-8;
const double pi = acos(-1.0);
int dcmp(double x) {
if(fabs(x)<=eps) return 0;
return (x>0)?1:-1;
};
typedef long long LL;
void file() {
freopen("a.in", "r", stdin);
// freopen("1.txt", "w", stdout);
}
namespace Solver {
const int maxn = 15 ;
const int mod = 31011;
int n, m;
double C[maxn][maxn];
double gauss(int n, double a[maxn][maxn]) {
n--;
for(int i = 0 ; i < n ; i++){
int r = i;
for(int j = i + 1 ; j < n; j++)
if(fabs(a[j][i]) > fabs(a[r][i]))
r = j;
if(dcmp(a[r][i]) == 0) return 0;
if(r != i){
for(int j = 0 ; j < n ; j++)
swap(a[i][j] , a[r][j]);
}
for(int j = n; j >= i; j--)
for(int k = i + 1; k < n; k++)
a[k][j] -= a[k][i]/a[i][i] * a[i][j];
}
double ans = 1;
for(int i = 0; i < n; i++)
ans *= a[i][i];
return fabs(ans);
}
int solve() {
int t;
scanf("%d", &t);
while(t--) {
scanf("%d%d", &n, &m);
memset(C, 0, sizeof C);
for(int i = 1; i <= m; i++) {
int u, v;
scanf("%d%d", &u, &v);
u--, v--;
C[u][u] ++, C[v][v] ++;
C[u][v] --, C[v][u] --;
}
printf("%.0f\n", gauss(n, C));
}
}
};
int main() {
// file();
Solver::solve();
return 0;
}
上述程式碼使用了浮點數除法,事實上如果題目要求的方案數對素數取模,我們可以用整數逆元來代替。
需要特別注意的是,我們浮點搞死小圓中找到最大絕對值的係數來消圓本來是為了優化精度,如果在整數的模意義下,這種行為就沒有意義了,此時我們只需要找到第i行及以下任意一個係數非零的元來消即可。
我們可以從這個例子入手:BZOJ4894
給出一個有向圖的鄰接矩陣,求圖中的生成樹方案數mod1e9+7。
除去剛才我們說的需要注意的地方,這題還有一個要點是有向圖。
對於有向圖的基爾霍夫矩陣,其度數矩陣D只記錄出度,並且我們計算的時候,只能計算去除根所在的行和列得到的
#define others
#ifdef poj
#include <iostream>
#include <cstring>
#include <cmath>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <string>
#endif // poj
#ifdef others
#include <bits/stdc++.h>
#endif // others
//#define file
#define all(x) x.begin(), x.end()
using namespace std;
const double eps = 1e-8;
const double pi = acos(-1.0);
int dcmp(double x) {
if(fabs(x)<=eps) return 0;
return (x>0)?1:-1;
};
typedef long long LL;
void file() {
freopen("a.in", "r", stdin);
// freopen("1.txt", "w", stdout);
}
namespace Solver {
const int maxn = 330;
const int mod = 1e9+7;
int n, m;
LL Pow(LL a, LL b) {
LL res = 1;
while(b) {
if(b & 1) res *= a, res %= mod;
b >>= 1;
a *= a; a %= mod;
}
return res;
}
LL C[maxn][maxn], D[maxn][maxn];
LL gauss(int n, LL a[maxn][maxn]) {
for(int i = 1 ; i < n ; i++){
int r = i;
for(int j = i ; j < n; j++)
if(a[j][i]!=0) {
r = j;
break;
}
if(r != i){
for(int j = 1 ; j < n ; j++)
swap(a[i][j] , a[r][j]);
}
LL inv = Pow(a[i][i], mod-2);
for(int j = n - 1; j >= i; j--)
for(int k = i + 1; k < n; k++) {
a[k][j] -= ((a[k][i]*inv)%mod * a[i][j])%mod;
if(a[k][j] <0) a[k][j] += mod;
}
}
LL ans = 1;
for(int i = 1; i < n; i++)
ans *= a[i][i], ans %= mod;;
return ans;
}
int solve() {
scanf("%d", &n);
for(int i = 0; i < n; i++)
for(int j = 0; j < n; j++) {
char c; scanf(" %c", &c);
if(c == '1')
C[i][j] = mod-1, D[j][j]++;
}
for(int i = 0; i < n; i++)
for(int j = 0; j < n; j++) {
C[i][j] += D[i][j];
C[i][j] %= mod;
}
printf("%lld\n", gauss(n, C));
}
};
int main() {
// file();
Solver::solve();
return 0;
}
給出n個點,m條帶權邊,問原圖中最小生成樹的方案數是多少。
我們需要知道一個性質,最小生成樹每個權值能起的作用是一定的。
換句話說,即使不同的最小生成樹方案,對於某個權值即使所用的邊的數目是相同的,且這些邊連線了相同的連通塊。
做法1:
先跑一次最小生成樹,獲得每個邊權在最小生成樹中用了幾次。
由於每個權值的作用是獨立的,且相同權值的邊數不超過10,我們先列舉每個權值對答案的貢獻,在列舉當前權值的時候,把當前權值從最小生成樹中摳掉,然後二進位制枚舉出當前權值的每條邊用或不用,看能否組成最小生成樹。乘法原理統計答案即可。
需要注意邊權很大,需要離散化。
namespace solver {
const int maxn = 1111;
const int mod = 31011;
namespace DSU {
int fa[maxn], sz[maxn];
stack<pair<int*, int> > stk;
void init() {
for(int i = 0; i < maxn; i++) fa[i] = i, sz[i] = 1;
}
int find(int x) {
return x == fa[x]? fa[x] : find(fa[x]);
};
void merge(int x, int y, int on) {
int u = find(x), v = find(y);
if(sz[u] > sz[v]) swap(u, v);
if(on) stk.push({&fa[u], fa[u]}), stk.push({&sz[v], sz[u]});
fa[u] = v;
sz[v] += sz[u];
}
void goback() {
while(!stk.empty()) (*stk.top().first) = stk.top().second, stk.pop();
}
}
int n, m;
struct A {
int u, v, w;
bool operator < (const A & b) const {
return w < b.w;
}
};
int bit_count(int x) {
return x == 0?0:x%2+bit_count(x/2);
}
vector<A> G[maxn];
int cnt[maxn];
int solve() {
scanf("%d%d", &n, &m);
vector<A> V, used;
map<int, int> mid;
int id = 0;
int res = 1;
for(int i = 1; i <= m; i++) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
if(mid.count(w) == 0)
mid[w] = ++id;
G[mid[w]].push_back({u, v, w});
V.push_back({u, v, w});
}
sort(all(V));
DSU::init();
int tt = 0;
for(int i = 0; i < V.size(); i++) {
int f1 = DSU::find(V[i].u), f2 = DSU::find(V[i].v);
if(f1 != f2) {
used.push_back(V[i]);
DSU::merge(f1, f2, 0);
cnt[mid[V[i].w]]++;
tt++;
}
}
if(tt != n - 1) return 0 * puts("0");
for(int i = 0; i < maxn; i++) {
if(cnt[i] == 0) continue;
DSU::init();
int tmp = 0;
for(int j = 0; j < used.size(); j++) {
if(mid[used[j].w] == i) continue;
A &buf = used[j];
if(DSU::find(buf.u) != DSU::find(buf.v))
DSU::merge(buf.u, buf.v, 0), tmp++;
}
int len = G[i].size();
int ptmp = tmp, ans = 0;
for(int j = 0; j < (1 << len); j++) {
if(bit_count(j) == cnt[i]) {
for(int k = 0; k < len; k++)
if(j & (1 << k)) {
A &buf = G[i][k];
if(DSU::find(buf.u) != DSU::find(buf.v))
DSU::merge(buf.u, buf.v, 1), tmp++;
}
if(tmp == n - 1)
ans++;
tmp = ptmp;
DSU::goback();
}
}
res *= ans;
res %= mod;
}
printf("%d\n", res);
}
}
做法2:
依舊是列舉邊權,統計每個邊權的貢獻,不過我們不再是二進位制列舉,而是把當前權值的邊單獨拿出來,剩下的各個連通塊用並查集縮點之後建新圖,這樣我們就可以用矩陣樹定理了!
算完之後依舊是乘法原理統計答案。
需要注意:31011不是質數!並且這題方案數很少,浮點數的精度是夠用的。
#define others
#ifdef poj
#include <iostream>
#include <cstring>
#include <cmath>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <string>
#endif // poj
#ifdef others
#include <bits/stdc++.h>
#endif // others
//#define file
#define all(x) x.begin(), x.end()
using namespace std;
const double eps = 1e-8;
const double pi = acos(-1.0);
int dcmp(double x) {
if(fabs(x)<=eps) return 0;
return (x>0)?1:-1;
};
typedef long long LL;
void file() {
freopen("a.in", "r", stdin);
// freopen("1.txt", "w", stdout);
}
namespace Solver {
const int maxn = 1111;
const int mod = 31011;
int n;
double C[111][111];
double gauss(int n, double a[111][111]) {
n--;
for(int i = 0 ; i < n ; i++){
int r = i;
for(int j = i + 1 ; j < n; j++)
if(fabs(a[j][i]) > fabs(a[r][i]))
r = j;
if(r != i){
for(int j = 0 ; j <= n ; j++)
swap(a[i][j] , a[r][j]);
}
for(int j = n; j >= i; j--)
for(int k = i + 1; k < n; k++)
a[k][j] -= a[k][i]/a[i][i] * a[i][j];
}
double ans = 1;
for(int i = 0; i < n; i++)
ans *= a[i][i];
return fabs(ans);
}
namespace DSU {
int fa[maxn], sz[maxn];
stack<pair<int*, int> > stk;
void init() {
for(int i = 0; i < maxn; i++) fa[i] = i, sz[i] = 1;
}
int find(int x) {
return x == fa[x]? fa[x] : find(fa[x]);
};
void merge(int x, int y, int on) {
int u = find(x), v = find(y);
if(sz[u] > sz[v]) swap(u, v);
if(on) stk.push({&fa[u], fa[u]}), stk.push({&sz[v], sz[u]});
fa[u] = v;
sz[v] += sz[u];
}
void goback() {
while(!stk.empty()) (*stk.top().first) = stk.top().second, stk.pop();
}
}
int m;
struct A {
int u, v, w;
bool operator < (const A & b) const {
return w < b.w;
}
};
int bit_count(int x) {
return x == 0?0:x%2+bit_count(x/2);
}
vector<A> G[maxn];
int cnt[maxn];
vector<int> G_id;
int get(int x) {
return lower_bound(all(G_id), x) - G_id.begin();
}
int solve() {
scanf("%d%d", &n, &m);
vector<A> V, used;
map<int, int> mid;
int id = 0;
int res = 1;
for(int i = 1; i <= m; i++) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
if(mid.count(w) == 0)
mid[w] = ++id;
G[mid[w]].push_back({u, v, w});
V.push_back({u, v, w});
}
sort(all(V));
DSU::init();
int tt = 0;
for(int i = 0; i < V.size(); i++) {
int f1 = DSU::find(V[i].u), f2 = DSU::find(V[i].v);
if(f1 != f2) {
used.push_back(V[i]);
DSU::merge(f1, f2, 0);
cnt[mid[V[i].w]]++;
tt++;
}
}
if(tt != n - 1) return 0 * puts("0");
for(int i = 0; i < maxn; i++) {
if(cnt[i] == 0) continue;
DSU::init();
int tmp = 0;
for(int j = 0; j < used.size(); j++) {
if(mid[used[j].w] == i) continue;
A &buf = used[j];
if(DSU::find(buf.u) != DSU::find(buf.v))
DSU::merge(buf.u, buf.v, 0), tmp++;
}
int len = G[i].size();
G_id.clear();
for(int p = 1; p <= n; p++) G_id.push_back(DSU::find(p));
sort(all(G_id)); G_id.erase(unique(all(G_id)), G_id.end());
memset(C, 0, sizeof C);
for(int j = 0; j < len; j++) {
A &buf = G[i][j];
int u = get(DSU::find(buf.u)), v = get(DSU::find(buf.v));
C[u][u] ++, C[v][v]++;
C[u][v] --, C[v][u] --;
}
LL ans = 0;
ans = gauss(G_id.size(), C)+0.0000001;
res *= ans;
res %= mod;
}
printf("%d\n", res);
}
};
//g++ "C:\Users\meopass\Desktop\code\solution.cpp" -o solution && solution
int main() {
// file();
Solver::solve();
return 0;
}
呼…暫時就到這裡了。