uoj #176. 新年的繁榮 字典樹+Boruvka演算法
阿新 • • 發佈:2018-12-30
題意
有n個點,第i個點的權值為a[i],在第i個點和第j個點之間連邊的代價為a[i] and a[j]。問這個圖的最大生成樹。
分析
這題我們可以用最小生成樹的Boruvka演算法。具體來說就是一開始對於每個點,找到和該點相連的邊權最大的邊,然後把這條邊加入生成樹中。然後對於每個連通塊,又找到連線該連通塊的邊權最大的邊加入生成樹。這樣做每次連通塊的複雜度至少減少一半,所以最多做不超過logn次。
現在的問題在於我們如何對於每個點i,找到一個j使得a[i] and a[j]最大。
這個我們可以用字典樹。一開始先把每個樹都插進去。然後查詢的時候,若當前位是1,則走1的子樹;若當前位是0,則發現不管走1還是走0都是一樣的。於是我們可以在把數插入完後,從下到上把每個節點的1子樹複製一遍扔到0子樹裡面,這樣的話就可以每次只走0子樹了。
這樣做的話不難發現最壞情況每個點都會被祖先遍歷一次,於是複雜度是 。
我們可以在字典樹每個節點維護其子樹中連通塊編號的最大值和最小值,這樣就能判斷子樹中是否存在不同的連通塊。
那麼總的時間複雜度就是。
程式碼
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define mp(x,y) make_pair(x,y)
#define MAX(x,y) x=max(x,y)
#define MIN(x,y) x=min(x,y)
using namespace std;
typedef long long LL;
typedef pair<int,int> pi;
const int N=100005;
int n,a[N],sz,f[N],rt,m,bin[20];
LL ans;
pi mx[N];
struct tree{int l,r,mn,mx;}t[N*10];
int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int find(int x)
{
if (f[x]==x) return x;
else return f[x]=find(f[x]);
}
int newnode()
{
int x=++sz;t[x].l=t[x].r=t[x].mx=0;t[x].mn=n;return x;
}
void ins(int &d,int dep,int v,int id)
{
if (!d) d=newnode();
MIN(t[d].mn,id);MAX(t[d].mx,id);
if (dep<0) return;
if (v&bin[dep]) ins(t[d].r,dep-1,v,id);
else ins(t[d].l,dep-1,v,id);
}
pi query(int d,int dep,int v,int id)
{
if (dep<0) return mp(0,id==t[d].mn?t[d].mx:t[d].mn);
pi ans;
if (v&bin[dep])
if (t[d].r&&(id!=t[t[d].r].mn||id!=t[t[d].r].mx)) ans=query(t[d].r,dep-1,v,id),ans.first+=bin[dep];
else ans=query(t[d].l,dep-1,v,id);
else ans=query(t[d].l,dep-1,v,id);
return ans;
}
int merge(int x,int y)
{
if (!y) return x;
if (!x) x=newnode();
MIN(t[x].mn,t[y].mn);
MAX(t[x].mx,t[y].mx);
t[x].l=merge(t[x].l,t[y].l);
t[x].r=merge(t[x].r,t[y].r);
return x;
}
void dfs(int d,int dep)
{
if (dep<0) return;
if (t[d].l) dfs(t[d].l,dep-1);
if (t[d].r) dfs(t[d].r,dep-1);
t[d].l=merge(t[d].l,t[d].r);
}
void build()
{
sz=0;rt=newnode();
for (int i=1;i<=n;i++) ins(rt,m-1,a[i],find(i));
dfs(rt,m-1);
}
int main()
{
n=read();m=read();int now=n;
bin[0]=1;
for (int i=1;i<=m;i++) bin[i]=bin[i-1]*2;
for (int i=1;i<=n;i++) a[i]=read(),f[i]=i;
while (now>1)
{
build();
for (int i=1;i<=n;i++) mx[i]=mp(-1,0);
for (int i=1;i<=n;i++)
{
pi u=query(rt,m-1,a[i],find(i));
if (u.first>mx[find(i)].first) mx[find(i)]=u;
}
for (int i=1;i<=n;i++)
if (f[i]==i)
{
int x=find(mx[i].second);
if (x!=i) f[i]=x,now--,ans+=mx[i].first;
}
}
printf("%lld",ans);
return 0;
}