1. 程式人生 > 其它 >[省選集訓2022] 模擬賽19

[省選集訓2022] 模擬賽19

矩陣樹

題目描述

給定一個 \(n\) 個點的無向完全圖,其中邊 \((i,j)\) 的個數是 \(a(i,j)\);有 \(k\) 個要求,第 \(i\) 個要求是點集 \(S\) 的匯出子圖要連通,問滿足條件的生成樹個數,答案對 \(998244353\) 取模。

\(n\leq 500,k\leq 2000\)

解法

如果沒有限制就是裸的矩陣樹定理,這其實我們往矩陣樹的方向思考。

首先觀察 \(S\) 的匯出子圖連通有什麼性質,我們可以將限制轉化到邊上,那麼就相當於有恰好 \(|S|-1\) 條邊(忽略 \(|S|=0\) 的情況),滿足其的兩個端點都在 \(S\) 中。並且有一個關鍵的 \(\tt observation\)

:無論合法還是不合法的情況,這樣的邊數最多有 \(|S|-1\) 個。

那麼說明本題的判據跟最值有一定關聯了,那麼對於全部的 \(k\) 個限制,設 \(w(i,j)\) 表示邊 \((i,j)\) 的兩個端點都出現在了多少個 \(S\) 中,那麼只有最大生成樹才可能成為答案

可以用 \(\tt bitset\)\(O(\frac{n^2k}{w})\) 的時間求出每條邊的邊權,最大生成樹計數是經典問題。由於每種邊權的數量固定,對於每種邊權的每個連通塊,我們單獨跑矩陣樹定理,限制好矩陣大小時間複雜度就是 \(O(n^3)\) 的。

#include <cstdio>
#include <bitset>
#include <iostream>
#include <algorithm>
using namespace std;
const int M = 505;
const int N = 2005;
const int MOD = 998244353;
#define int long long
int read()
{
	int x=0,f=1;char c;
	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
	return x*f;
}
int n,k,sum,ans,a[M][M],fa[M],fn[M],vis[M],id[M];
bitset<N> g[M];char s[N];
struct edge{int u,v,c;};vector<edge> e[N],G[M];
struct matrix
{
	int n,a[M][M];
	void clear()
	{
		for(int i=1;i<=n;i++)
			for(int j=1;j<=n;j++)
				a[i][j]=0;
		n=0;
	}
	void add(int u,int v,int c)
	{
        a[u][u]=(a[u][u]+c)%MOD;
		a[v][v]=(a[v][v]+c)%MOD;
        a[u][v]=(a[u][v]+MOD-c)%MOD;
		a[v][u]=(a[v][u]+MOD-c)%MOD;
	}
	int qkpow(int a,int b)
	{
		int r=1;
		while(b>0)
		{
			if(b&1) r=r*a%MOD;
			a=a*a%MOD;
			b>>=1;
		}
		return r;
	}
	int gauss()
	{
		int ans=1;
		for(int i=2;i<=n;i++)
	    {
	        for(int j=i+1;j<=n;j++)
	            if(!a[i][i] && a[j][i])
	            {
	                ans=MOD-ans;
	                swap(a[i],a[j]);
	                break;
	            }
	        ans=ans*a[i][i]%MOD;
	        int inv=qkpow(a[i][i],MOD-2);
	        for(int j=i+1;j<=n;j++)
	        {
	            int tmp=a[j][i]*inv%MOD;
	            for(int k=i;k<=n;k++)
	                a[j][k]=(a[j][k]-a[i][k]*tmp
					%MOD+MOD)%MOD;
	        }
	    }
		return ans;
	}
}z;
int find(int x)
{
	if(x==fa[x]) return x;
	return fa[x]=find(fa[x]);
}
int zxy(int x)
{
	if(x==fn[x]) return x;
	return fn[x]=zxy(fn[x]);
}
void kruskal()
{
	ans=1;
	for(int i=1;i<=n;i++) fa[i]=fn[i]=i;
	for(int i=k;i>=0;i--) if(e[i].size())
	{
		for(edge x:e[i])
			fn[zxy(x.u)]=fn[zxy(x.v)];
		for(edge x:e[i])
		{
			int u=find(x.u),v=find(x.v);
			if(u^v) G[zxy(u)].push_back(x);
		}
		for(int R=1;R<=n;R++) if(G[R].size())
		{
			for(int j=1;j<=n;j++) vis[j]=id[j]=0;
			int m=0;
			for(edge x:G[R])
			{
				int u=find(x.u),v=find(x.v);
				if(!vis[u]) vis[u]=1,id[u]=++m;
				if(!vis[v]) vis[v]=1,id[v]=++m;
				z.add(id[u],id[v],x.c);
			}
			z.n=m;
			ans=ans*z.gauss()%MOD;
			z.clear();
			G[R].clear();
		}
		for(edge x:e[i])
		{
			int u=find(x.u),v=find(x.v);
			if(u^v) sum-=i,fa[u]=v;
		}
	}
	printf("%lld\n",(sum==0)?ans:0); 
}
signed main()
{
	freopen("treecnt.in","r",stdin);
	freopen("treecnt.out","w",stdout);
	n=read();k=read();
	for(int i=1;i<=n;i++)
		for(int j=i+1;j<=n;j++)
			a[i][j]=read();
	for(int i=1;i<=k;i++)
	{
		scanf("%s",s+1);
		int fl=0;
		for(int j=1;j<=n;j++) if(s[j]=='1')
			g[j][i]=1,fl=1,sum++;
		sum-=fl;
	}
	for(int i=1;i<=n;i++)
		for(int j=i+1;j<=n;j++)
		{
			int w=(g[i]&g[j]).count();
			e[w].push_back({i,j,a[i][j]});
		}
	kruskal();
}