AGC002 D Stamp Rally 整體二分 並查集 整體二分學習筆記
題意:
給你一張無向連通圖,有n個點m條邊,邊的編號按照輸入順序來排。有q次詢問,每次詢問給出兩個點x和y,詢問從x和y分別出發,一共經過了z個點經過的所有邊的最大編號最小是多少。n,m,q<=1e5
題解:
這題是我做的第一道整體二分的題,所以用這道題為例做一下學習筆記。
首先對於這道題,我們看到最大值最小,不難想到一種單次詢問用二分+並查集的做法,但是多次詢問複雜度爆炸。但是我們仍然沿用二分答案這個想法,我們這裡要用到一種整體二分的方法。
我按照我的理解寫一下整體二分。可能會有諸多不嚴謹或者錯誤的地方,望諸位大神斧正。
整體二分是一種離線演算法,它的思想是把所有詢問離線下來,然後二分答案,每次對於當前二分的答案,我們把所有詢問都帶進去檢驗,看當前二分的答案對於每一組是否可行。我們記錄哪些是當前可行的,哪些是當前答案不可行的,然後再分別遞迴到
和
檢驗可行性。
對於這道題,我們發現並查集的話在二分的答案變化後可能需要撤回,於是不能路徑壓縮,為了保證複雜度,我們採用啟發式合併,這樣保證樹高是 的。我們對於當前二分的值,把編號在 之間的邊合併,並且為了撤回,記錄下是從誰合併到了誰,然後更新 。接下來對於所有遞迴到這個區間的詢問一一進行判斷,看在每一個詢問下一步遞迴應該遞迴到 還是 。判斷完了之後要先撤回當前操作,因為你會先遞迴到左區間,左區間的答案會小於等於當前 ,所以現在需要暫時撤回。而我們遞迴到一個 的區間時就意味著遞迴到這段區間的詢問的答案就是 (或者說是 ),這時由於我們先遞迴左區間,上面一層又撤回了,因為我們要在遞迴右區間時只會加當前區間的這些邊,於是要保證之前編號更小的邊已經加過了,那麼我們就在遞迴到底的時候不會撤回地合併。
以上就是這題的做法和我理解的整體二分的思想了,下面是程式碼。
程式碼:
#include <bits/stdc++.h>
using namespace std;
int n,m,qq,ans[100010],f[100010],sz[100010];
struct node
{
int x,y;
}a[100010];
struct qwq
{
int x,y,z,id;
}q[100010],ji[100010];
stack<node> sta;
inline int getr(int x)
{
if(x==f[x])
return x;
else
return getr(f[x]);
}
inline void solve(int l,int r,int x,int y)
{
if(l==r)
{
for(int i=x;i<=y;++i)
ans[q[i].id]=l;
int fx=getr(a[l].x),fy=getr(a[l].y);
if(sz[fx]>sz[fy])
swap(fx,fy);
if(fx!=fy)
{
f[fx]=fy;
sz[fy]+=sz[fx];
}
return;
}
int mid=(l+r)>>1;
for(int i=l;i<=mid;++i)
{
int fx=getr(a[i].x),fy=getr(a[i].y);
if(sz[fx]>sz[fy])
swap(fx,fy);
if(fx!=fy)
{
f[fx]=fy;
sz[fy]+=sz[fx];
sta.push((node){fx,fy});
}
}
int cnt1=x-1,cnt2=0;
for(int i=x;i<=y;++i)
{
int fx=getr(q[i].x),fy=getr(q[i].y),size;
if(fx==fy)
size=sz[fx];
else
size=sz[fx]+sz[fy];
if(size>=q[i].z)
q[++cnt1]=q[i];
else
ji[++cnt2]=q[i];
}
for(int i=1;i<=cnt2;++i)
q[cnt1+i]=ji[i];
while(!sta.empty())
{
node qwqqq=sta.top();
sta.pop();
f[qwqqq.x]=qwqqq.x;
sz[qwqqq.y]-=sz[qwqqq.x];
}
solve(l,mid,x,cnt1);
solve(mid+1,r,cnt1+1,y);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=m;++i)
scanf("%d%d",&a[i].x,&a[i].y);
for(int i=1;i<=n;++i)
{
f[i]=i;
sz[i]=1;
}
scanf("%d",&qq);
for(int i=1;i<=qq;++i)
{
scanf("%d%d%d",&q[i].x,&q[i].y,&q[i].z);
q[i].id=i;
}
solve(1,m,1,qq);
for(int i=1;i<=qq;++i)
printf("%d\n",ans[i]);
return 0;
}