1. 程式人生 > 實用技巧 >完全圖的最小生成樹計數

完全圖的最小生成樹計數

題意

給定一個長度為 \(n\) 的陣列 \(a_1,a_2,\dots ,a_n\),有一幅完全圖,滿足 \((u,v)\) 的邊權為 \(a_u \text{xor}\ a_v\) 。求邊權和最小的生成樹,你需要輸出邊權和以及方案數對 \(1e9+7\) 取模的值(邊權和不要取模)。

\(1\leq n \leq 10^5,0\leq a_i <2^{30}\)

題目連結:https://vjudge.net/problem/51Nod-1601

分析

求邊權和直接按照異或最小生成樹的模板求。

求方案數時,當兩個聯通塊相連,如果相連的最小邊權的邊存在多條,那麼按照乘法原理,應該乘到答案中。同時,由於在字典樹中,點權相同的點位於同一個點,這些點之間也要連邊。可以轉化為完全圖的生成樹個數,根據 \(\text{prufer}\)

序列,答案為 \(n^{n-2}\) 個。注意,當點是一個重複的點,如果沒有計算,要上傳到其父親節點,直到計算為止。

程式碼

#include <bits/stdc++.h>
#define pb push_back
using namespace std;
typedef long long ll;
const int mod=1e9+7;
const int N=1e5+5;
const int maxn=2e6+5;
int trie[maxn][2],a[N],cnt;
int id[maxn],num[maxn],n;
ll ans,cot;
vector<int>value[N];
ll power(ll x,ll y)
{
    ll res=1;
    x%=mod;
    while(y)
    {
        if(y&1) res=res*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return res;
}
void add(int x,int k)
{
    int rt=1;
    for(int i=29;i>=0;i--)
    {
        int t=((x>>i)&1);
        if(trie[rt][t]==0)
            trie[rt][t]=++cnt;
        rt=trie[rt][t];
    }
    id[rt]=k;
    value[k].pb(x);
    num[rt]+=(upper_bound(a+1,a+1+n,x)-lower_bound(a+1,a+1+n,x));
    //cout<<"->"<<num[rt]<<endl;
}
int matching(int x,int rt,int d)
{
    int res=(1<<d);
    for(int i=d-1;i>=0;i--)
    {
        int t=((x>>i)&1);
        if(trie[rt][t]>0)
            rt=trie[rt][t];
        else
        {
            rt=trie[rt][1-t];
            res|=(1<<i);
        }
    }
    return res;
}
void solve(int rt,int d)
{
    if(trie[rt][0]>0) solve(trie[rt][0],d-1);
    if(trie[rt][1]>0) solve(trie[rt][1],d-1);
    if(trie[rt][0]>0&&trie[rt][1]>0)
    {
        int min_xor=(1<<30);
        int x=id[trie[rt][0]],y=id[trie[rt][1]];
        ll w=0;
        ll u=1,v=1;
        if(num[trie[rt][0]]>1) u=power(num[trie[rt][0]],num[trie[rt][0]]-2);
        if(num[trie[rt][1]]>1) v=power(num[trie[rt][1]],num[trie[rt][1]]-2);
        //cout<<"d="<<d<<" u="<<u<<" v="<<v<<endl;
        if(value[x].size()<value[y].size())
        {
            for(int i=0;i<value[x].size();i++)
            {
                int tmp=value[x][i];
                int val=matching(tmp,trie[rt][1],d-1);
                int tn=upper_bound(a+1,a+1+n,tmp)-lower_bound(a+1,a+1+n,tmp);
                int ct=upper_bound(a+1,a+1+n,(tmp^val))-lower_bound(a+1,a+1+n,(tmp^val));
                ll res=1LL*tn*ct%mod*u%mod*v%mod;
                if(val<min_xor)
                    min_xor=val,w=res;
                else if(val==min_xor)
                    w=(w+res)%mod;
                value[y].pb(tmp);
            }
            id[rt]=y;
        }
        else
        {
            for(int i=0;i<value[y].size();i++)
            {
                int tmp=value[y][i];
                int val=matching(tmp,trie[rt][0],d-1);
                int tn=upper_bound(a+1,a+1+n,tmp)-lower_bound(a+1,a+1+n,tmp);
                int ct=upper_bound(a+1,a+1+n,(tmp^val))-lower_bound(a+1,a+1+n,(tmp^val));
                ll res=1LL*tn*ct%mod*u%mod*v%mod;
                if(val<min_xor)
                    min_xor=val,w=res;
                else if(val==min_xor)
                    w=(w+res)%mod;
                value[x].pb(tmp);
            }
            id[rt]=x;
        }
        ans+=min_xor;
        cot=cot*w%mod;
    }
    else
    {
        if(trie[rt][0]>0||trie[rt][1]>0)
        {
            id[rt]=id[trie[rt][0]+trie[rt][1]];
            num[rt]=num[trie[rt][0]]+num[trie[rt][1]];
        }
    }
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
        scanf("%d",&a[i]);
    sort(a+1,a+1+n);
    cnt=1;
    add(a[1],1);
    for(int i=2;i<=n;i++)
    {
        if(a[i]!=a[i-1])
            add(a[i],i);
    }
    ans=0,cot=1;
    solve(1,30);
    if(num[1]>1) cot=cot*power(num[1],num[1]-2)%mod;
    printf("%lld\n%lld\n",ans,cot);
    return 0;
}