1. 程式人生 > 其它 >356. 次小生成樹

356. 次小生成樹

題目連結

356. 次小生成樹

給定一張 \(N\) 個點 \(M\) 條邊的無向圖,求無向圖的嚴格次小生成樹。

設最小生成樹的邊權之和為 \(sum\),嚴格次小生成樹就是指邊權之和大於 \(sum\) 的生成樹中最小的一個。

輸入格式

第一行包含兩個整數 \(N\)\(M\)

接下來 \(M\) 行,每行包含三個整數 \(x,y,z\),表示點 \(x\) 和點 \(y\) 之前存在一條邊,邊的權值為 \(z\)

輸出格式

包含一行,僅一個數,表示嚴格次小生成樹的邊權和。(資料保證必定存在嚴格次小生成樹)

資料範圍

\(N≤10^5,M≤3×10^5\)

輸入樣例:

5 6
1 2 1
1 3 2
2 4 3
3 5 4
3 4 3
4 5 6

輸出樣例:

11

解題思路

倍增,最小生成樹

先求出最小生成樹,最小生成樹中的邊稱為“樹邊”,對於每一條“非樹邊”,計算加上這條邊後的最少代價,即加上這條邊後會形成一個環,當權值最大的邊的權值與噹噹前“非樹邊”的權值不相等時,需要在環去掉這條權值最大的邊,否則去掉權值次大的邊。所以關鍵在於尋找最小生成樹中兩點之間權值最大和次大的邊,可類比LCA:

  • 狀態表示:\(dp[i][j][k]\) 表示 \(j\) 向上走 \(2^k\) 步的最大值(\(i=0\))/次大值(\(i=1\)

  • 狀態計算:

  • \(dp[0][y][i-1]==dp[0][f[y][i-1]][i-1]\),其中 \(f[i][j]\)

    表示 \(i\) 向上走 \(2^j\) 步到達的節點

    • \(dp[0][y][i]=dp[0][y][i-1]\)
    • \(dp[1][y][i]=max(dp[1][y][i-1],dp[1][f[y][i-1]][i-1])\)
  • \(dp[0][y][i-1]\neq dp[0][f[y][i-1]][i-1]\)

    • \(dp[0][y][i]=max(dp[0][y][i-1],dp[0][f[y][i-1]][i-1])\)
    • \(dp[1][y][i]=max(\{min(dp[0][y][i-1],dp[0][f[y][i-1]][i-1]),dp[1][y][i-1],dp[1][f[y][i-1]][i-1]\})\)

      分析:整體最大值為前後兩部分的最大值,當前後兩部分的最大值相等時,次大值為前後兩部分次大值的較大值;否則為前後兩部分最大值的較小值和前後兩部分的次大值的最大值

最後求最小生成樹兩點的最大和次大值可類比LCA

  • 時間複雜度:\(O(m(logn+logm))\)

程式碼

// Problem: 次小生成樹
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/description/358/
// Memory Limit: 512 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

// %%%Skyqwq
#include <bits/stdc++.h>
 
// #define int long long
#define help {cin.tie(NULL); cout.tie(NULL);}
#define pb push_back
#define fi first
#define se second
#define mkp make_pair
using namespace std;
 
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
 
template <typename T> bool chkMax(T &x, T y) { return (y > x) ? x = y, 1 : 0; }
template <typename T> bool chkMin(T &x, T y) { return (y < x) ? x = y, 1 : 0; }
 
template <typename T> void inline read(T &x) {
    int f = 1; x = 0; char s = getchar();
    while (s < '0' || s > '9') { if (s == '-') f = -1; s = getchar(); }
    while (s <= '9' && s >= '0') x = x * 10 + (s ^ 48), s = getchar();
    x *= f;
}


const int N=3e5+5;
const LL inf=1e16;
int n,m,fa[N],d[N],f[N][20];
LL dp[2][N][20],val1,val2,res,ret;
bool v[N];
vector<PII> adj[N];
struct T
{
	int x,y,z;
	bool operator<(const T &t)
	{
		return z<t.z;
	}
}tr[N];
int find(int x)
{
	return x==fa[x]?x:fa[x]=find(fa[x]);
}
void kruskal()
{
    sort(tr+1,tr+1+m);
	for(int i=1;i<=m;i++)
	{
		int x=tr[i].x,y=tr[i].y,z=tr[i].z;
		int a=find(x),b=find(y);
		if(a==b)continue;
		res+=z;
		adj[x].pb({y,z});
		adj[y].pb({x,z});
		v[i]=true;
		fa[a]=b;
	}
}
void bfs()
{
	d[1]=0;
	queue<int> q;
	q.push(1);
	while(q.size())
	{
		int x=q.front();
		q.pop();
		int len=log2(d[x]);
		for(auto t:adj[x])
		{
			int y=t.fi,z=t.se;
			if(y==f[x][0])continue;
			q.push(y);
			f[y][0]=x;
			d[y]=d[x]+1;
			dp[0][y][0]=z,dp[1][y][0]=-inf;
			for(int i=1;i<=len;i++)
			{
				f[y][i]=f[f[y][i-1]][i-1];
				if(dp[0][y][i-1]==dp[0][f[y][i-1]][i-1])
				{
					dp[0][y][i]=dp[0][y][i-1];
					dp[1][y][i]=max(dp[1][y][i-1],dp[1][f[y][i-1]][i-1]);
				}
				else
				{
					dp[0][y][i]=max(dp[0][y][i-1],dp[0][f[y][i-1]][i-1]);
					dp[1][y][i]=max({min(dp[0][y][i-1],dp[0][f[y][i-1]][i-1]),dp[1][y][i-1],dp[1][f[y][i-1]][i-1]});
				}
			}
		}
	}
}
void update(int x)
{
	if(val1<x)val2=val1,val1=x;
	else if(val2<x&&x!=val1)val2=x;
}
void lca(int x,int y)
{
	val1=val2=-inf;
	if(d[x]>d[y])swap(x,y);
	while(d[x]<d[y])
	{
		int len=log2(d[y]-d[x]);
		update(dp[0][y][len]);
		update(dp[1][y][len]);
		y=f[y][len];
	}
	if(x==y)return ;
	for(int len=log2(d[x]);len>=0;len--)
		if(f[x][len]!=f[y][len])
		{
			update(dp[0][x][len]);
			update(dp[1][x][len]);
			update(dp[0][y][len]);
			update(dp[1][y][len]);
			x=f[x][len];
			y=f[y][len];
		}
	update(dp[0][x][0]);
	update(dp[1][x][0]);
	update(dp[0][y][0]);
	update(dp[1][y][0]);
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)fa[i]=i;
    for(int i=1;i<=m;i++)
    {
    	int x,y,z;
    	scanf("%d%d%d",&x,&y,&z);
    	tr[i]={x,y,z};
    }
    kruskal();
    bfs();
    ret=inf;
    for(int i=1;i<=m;i++)
    {
    	if(v[i])continue;
    	lca(tr[i].x,tr[i].y);
    	if(tr[i].z==val1)
    		ret=min(ret,res-val2+tr[i].z);
    	else
    		ret=min(ret,res-val1+tr[i].z);
    }
    printf("%lld",ret);
    return 0;
}