1. 程式人生 > >洛谷 P4208 [JSOI2008]最小生成樹計數 矩陣樹定理

洛谷 P4208 [JSOI2008]最小生成樹計數 矩陣樹定理

題目描述

現在給出了一個簡單無向加權圖。你不滿足於求出這個圖的最小生成樹,而希望知道這個圖中有多少個不同的最小生成樹。(如果兩顆最小生成樹中至少有一條邊不同,則這兩個最小生成樹就是不同的)。由於不同的最小生成樹可能很多,所以你只需要輸出方案數對31011的模就可以了。

輸入輸出格式

輸入格式:
第一行包含兩個數,n和m,其中1<=n<=100; 1<=m<=1000; 表示該無向圖的節點數和邊數。每個節點用1~n的整數編號。

接下來的m行,每行包含兩個整數:a, b, c,表示節點a, b之間的邊的權值為c,其中1<=c<=1,000,000,000。

資料保證不會出現自回邊和重邊。注意:具有相同權值的邊不會超過10條。

輸出格式:
輸出不同的最小生成樹有多少個。你只需要輸出數量對31011的模就可以了。

輸入輸出樣例

輸入樣例#1:
4 6
1 2 1
1 3 1
1 4 1
2 3 2
2 4 1
3 4 1
輸出樣例#1:
8
說明

說明 1<=n<=100; 1<=m<=1000;1ci109

分析:
一個圖不同的最小生成樹有兩個性質。
第一是所有最小生成樹中相同權值的邊使用了相同多次。我們考慮我們已經建好了一棵最小生成樹,對於一條不在這棵樹上的邊(u,v),保證在樹上uv的路徑上的權值都小於等於這條邊的權值,而且只有替換相同權值的邊,新樹才會是最小生成樹。
第二是同一種權值的邊連完後,連通塊完全一樣。也可以理解為相同權值的邊無論怎樣先後順序如何,連線後的連通塊完全一樣。這個是很顯然的。
假如我們假如了權值小於等於

w的邊,形成若干連通塊。此時加入邊權為w的邊,將會連線一些連通塊,把這些連通塊看做點,連線相當於形成一棵樹,使用矩陣樹就可以。還有一種特殊情況,就是同一種邊權的邊連線形成的不是一個連通圖。舉個例子,比如說第一條邊連線了12,第二條邊連線34,此時如果直接建圖做矩陣樹det就是0。那麼我們可以強行把他建成連通圖,把每個連通塊連成一個鏈,因為鏈上的邊一定會被選,所以相當於每個連通塊的樹的個數的乘積。

程式碼:

#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm> #define LL long long const int maxn=107; const int maxe=1007; const LL mod=31011; using namespace std; int n,m,cnt; LL a[maxn][maxn]; int b[maxn],p[maxn],f[maxn]; struct edge{ int x,y,w; }g[maxe]; bool cmp(edge x,edge y) { return x.w<y.w; } int find(int x,int *p) { int y=x,root; while (p[x]) x=p[x]; root=x; x=y; while (p[x]) { y=p[x]; p[x]=root; x=y; } return root; } void uni(int x,int y,int *p) { int u=find(x,p); int v=find(y,p); if (u==v) return; p[u]=v; } LL det() { int n=cnt-1; LL ans=1; for (int i=1;i<=n;i++) { for (int j=i+1;j<=n;j++) { while (a[j][i]) { LL rate=a[i][i]/a[j][i]; for (int k=i;k<=n;k++) { a[i][k]=(a[i][k]-rate*a[j][k]%mod+mod)%mod; swap(a[i][k],a[j][k]); } ans=mod-ans; } } ans=(ans*a[i][i])%mod; } return ans; } int main() { scanf("%d%d",&n,&m); for (int i=1;i<=m;i++) scanf("%d%d%d",&g[i].x,&g[i].y,&g[i].w); sort(g+1,g+m+1,cmp); LL ans=1,num=0; for (int i=1,last;i<=m;i=last+1) { last=i; memset(b,0,sizeof(b)); memset(a,0,sizeof(a)); memset(f,0,sizeof(f)); cnt=0; while (g[i].w==g[last].w) { int x=g[last].x; int y=g[last].y; if (find(x,p)!=find(y,p)) { int u=find(x,p),v=find(y,p); if (b[u]==0) b[u]=++cnt; if (b[v]==0) b[v]=++cnt; a[b[u]][b[v]]=(a[b[u]][b[v]]-1+mod)%mod; a[b[v]][b[u]]=(a[b[v]][b[u]]-1+mod)%mod; a[b[u]][b[u]]++; a[b[v]][b[v]]++; uni(b[u],b[v],f); } last++; } last--; for (int j=2;j<=cnt;j++) { if (find(j,f)!=find(j-1,f)) { uni(j-1,j,f); a[j-1][j]=(a[j-1][j]-1+mod)%mod; a[j][j-1]=(a[j][j-1]-1+mod)%mod; a[j-1][j-1]++; a[j][j]++; } } ans=(ans*det())%mod; for (int j=i;j<=last;j++) { if (find(g[j].x,p)!=find(g[j].y,p)) { num++; uni(g[j].x,g[j].y,p); } } } if (num<n-1) printf("0"); else printf("%lld",ans); }