1. 程式人生 > >AGC002 D Stamp Rally 整體二分 並查集 整體二分學習筆記

AGC002 D Stamp Rally 整體二分 並查集 整體二分學習筆記

題目連結

題意:
給你一張無向連通圖,有n個點m條邊,邊的編號按照輸入順序來排。有q次詢問,每次詢問給出兩個點x和y,詢問從x和y分別出發,一共經過了z個點經過的所有邊的最大編號最小是多少。n,m,q<=1e5

題解:
這題是我做的第一道整體二分的題,所以用這道題為例做一下學習筆記。

首先對於這道題,我們看到最大值最小,不難想到一種單次詢問用二分+並查集的做法,但是多次詢問複雜度爆炸。但是我們仍然沿用二分答案這個想法,我們這裡要用到一種整體二分的方法。

我按照我的理解寫一下整體二分。可能會有諸多不嚴謹或者錯誤的地方,望諸位大神斧正。
整體二分是一種離線演算法,它的思想是把所有詢問離線下來,然後二分答案,每次對於當前二分的答案,我們把所有詢問都帶進去檢驗,看當前二分的答案對於每一組是否可行。我們記錄哪些是當前可行的,哪些是當前答案不可行的,然後再分別遞迴到 [

l , m i d ] [l,mid] [ m
i d + 1 , r ] [mid+1,r]
檢驗可行性。

對於這道題,我們發現並查集的話在二分的答案變化後可能需要撤回,於是不能路徑壓縮,為了保證複雜度,我們採用啟發式合併,這樣保證樹高是 l

o g n logn 的。我們對於當前二分的值,把編號在 [ l , m i d ] [l,mid] 之間的邊合併,並且為了撤回,記錄下是從誰合併到了誰,然後更新 s i z e size 。接下來對於所有遞迴到這個區間的詢問一一進行判斷,看在每一個詢問下一步遞迴應該遞迴到 [ l , m i d ] [l,mid] 還是 [ m i d + 1 , r ] [mid+1,r] 。判斷完了之後要先撤回當前操作,因為你會先遞迴到左區間,左區間的答案會小於等於當前 m i d mid ,所以現在需要暫時撤回。而我們遞迴到一個 l = r l=r 的區間時就意味著遞迴到這段區間的詢問的答案就是 l l (或者說是 r r ),這時由於我們先遞迴左區間,上面一層又撤回了,因為我們要在遞迴右區間時只會加當前區間的這些邊,於是要保證之前編號更小的邊已經加過了,那麼我們就在遞迴到底的時候不會撤回地合併。

以上就是這題的做法和我理解的整體二分的思想了,下面是程式碼。

程式碼:

#include <bits/stdc++.h>
using namespace std;

int n,m,qq,ans[100010],f[100010],sz[100010];
struct node
{
	int x,y;
}a[100010];
struct qwq
{
	int x,y,z,id;
}q[100010],ji[100010];
stack<node> sta;
inline int getr(int x)
{
	if(x==f[x])
	return x;
	else
	return getr(f[x]);
}
inline void solve(int l,int r,int x,int y)
{
	if(l==r)
	{
		for(int i=x;i<=y;++i)
		ans[q[i].id]=l;
		int fx=getr(a[l].x),fy=getr(a[l].y);
		if(sz[fx]>sz[fy])
		swap(fx,fy);
		if(fx!=fy)
		{
			f[fx]=fy;
			sz[fy]+=sz[fx];
		}
		return;
	}
	int mid=(l+r)>>1;
	for(int i=l;i<=mid;++i)
	{
		int fx=getr(a[i].x),fy=getr(a[i].y);
		if(sz[fx]>sz[fy])
		swap(fx,fy);
		if(fx!=fy)
		{
			f[fx]=fy;
			sz[fy]+=sz[fx];
			sta.push((node){fx,fy});
		}
	}
	int cnt1=x-1,cnt2=0;
	for(int i=x;i<=y;++i)
	{
		int fx=getr(q[i].x),fy=getr(q[i].y),size;
		if(fx==fy)
		size=sz[fx];
		else
		size=sz[fx]+sz[fy];
		if(size>=q[i].z)
		q[++cnt1]=q[i];
		else
		ji[++cnt2]=q[i];
	}
	for(int i=1;i<=cnt2;++i)
	q[cnt1+i]=ji[i];
	while(!sta.empty())
	{
		node qwqqq=sta.top();
		sta.pop();
		f[qwqqq.x]=qwqqq.x;
		sz[qwqqq.y]-=sz[qwqqq.x];
	}
	solve(l,mid,x,cnt1);
	solve(mid+1,r,cnt1+1,y);
}
int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=m;++i)
	scanf("%d%d",&a[i].x,&a[i].y);
	for(int i=1;i<=n;++i)
	{
		f[i]=i;
		sz[i]=1;
	}
	scanf("%d",&qq);
	for(int i=1;i<=qq;++i)
	{
		scanf("%d%d%d",&q[i].x,&q[i].y,&q[i].z);
		q[i].id=i;
	}
	solve(1,m,1,qq);
	for(int i=1;i<=qq;++i)
	printf("%d\n",ans[i]);
	return 0;
}