完全圖的最小生成樹計數
阿新 • • 發佈:2020-09-10
題意
給定一個長度為 \(n\) 的陣列 \(a_1,a_2,\dots ,a_n\),有一幅完全圖,滿足 \((u,v)\) 的邊權為 \(a_u \text{xor}\ a_v\) 。求邊權和最小的生成樹,你需要輸出邊權和以及方案數對 \(1e9+7\) 取模的值(邊權和不要取模)。
\(1\leq n \leq 10^5,0\leq a_i <2^{30}\)
題目連結:https://vjudge.net/problem/51Nod-1601
分析
求邊權和直接按照異或最小生成樹的模板求。
求方案數時,當兩個聯通塊相連,如果相連的最小邊權的邊存在多條,那麼按照乘法原理,應該乘到答案中。同時,由於在字典樹中,點權相同的點位於同一個點,這些點之間也要連邊。可以轉化為完全圖的生成樹個數,根據 \(\text{prufer}\)
程式碼
#include <bits/stdc++.h> #define pb push_back using namespace std; typedef long long ll; const int mod=1e9+7; const int N=1e5+5; const int maxn=2e6+5; int trie[maxn][2],a[N],cnt; int id[maxn],num[maxn],n; ll ans,cot; vector<int>value[N]; ll power(ll x,ll y) { ll res=1; x%=mod; while(y) { if(y&1) res=res*x%mod; x=x*x%mod; y>>=1; } return res; } void add(int x,int k) { int rt=1; for(int i=29;i>=0;i--) { int t=((x>>i)&1); if(trie[rt][t]==0) trie[rt][t]=++cnt; rt=trie[rt][t]; } id[rt]=k; value[k].pb(x); num[rt]+=(upper_bound(a+1,a+1+n,x)-lower_bound(a+1,a+1+n,x)); //cout<<"->"<<num[rt]<<endl; } int matching(int x,int rt,int d) { int res=(1<<d); for(int i=d-1;i>=0;i--) { int t=((x>>i)&1); if(trie[rt][t]>0) rt=trie[rt][t]; else { rt=trie[rt][1-t]; res|=(1<<i); } } return res; } void solve(int rt,int d) { if(trie[rt][0]>0) solve(trie[rt][0],d-1); if(trie[rt][1]>0) solve(trie[rt][1],d-1); if(trie[rt][0]>0&&trie[rt][1]>0) { int min_xor=(1<<30); int x=id[trie[rt][0]],y=id[trie[rt][1]]; ll w=0; ll u=1,v=1; if(num[trie[rt][0]]>1) u=power(num[trie[rt][0]],num[trie[rt][0]]-2); if(num[trie[rt][1]]>1) v=power(num[trie[rt][1]],num[trie[rt][1]]-2); //cout<<"d="<<d<<" u="<<u<<" v="<<v<<endl; if(value[x].size()<value[y].size()) { for(int i=0;i<value[x].size();i++) { int tmp=value[x][i]; int val=matching(tmp,trie[rt][1],d-1); int tn=upper_bound(a+1,a+1+n,tmp)-lower_bound(a+1,a+1+n,tmp); int ct=upper_bound(a+1,a+1+n,(tmp^val))-lower_bound(a+1,a+1+n,(tmp^val)); ll res=1LL*tn*ct%mod*u%mod*v%mod; if(val<min_xor) min_xor=val,w=res; else if(val==min_xor) w=(w+res)%mod; value[y].pb(tmp); } id[rt]=y; } else { for(int i=0;i<value[y].size();i++) { int tmp=value[y][i]; int val=matching(tmp,trie[rt][0],d-1); int tn=upper_bound(a+1,a+1+n,tmp)-lower_bound(a+1,a+1+n,tmp); int ct=upper_bound(a+1,a+1+n,(tmp^val))-lower_bound(a+1,a+1+n,(tmp^val)); ll res=1LL*tn*ct%mod*u%mod*v%mod; if(val<min_xor) min_xor=val,w=res; else if(val==min_xor) w=(w+res)%mod; value[x].pb(tmp); } id[rt]=x; } ans+=min_xor; cot=cot*w%mod; } else { if(trie[rt][0]>0||trie[rt][1]>0) { id[rt]=id[trie[rt][0]+trie[rt][1]]; num[rt]=num[trie[rt][0]]+num[trie[rt][1]]; } } } int main() { scanf("%d",&n); for(int i=1;i<=n;i++) scanf("%d",&a[i]); sort(a+1,a+1+n); cnt=1; add(a[1],1); for(int i=2;i<=n;i++) { if(a[i]!=a[i-1]) add(a[i],i); } ans=0,cot=1; solve(1,30); if(num[1]>1) cot=cot*power(num[1],num[1]-2)%mod; printf("%lld\n%lld\n",ans,cot); return 0; }