NOIP2017提高A組模擬9.7】JZOJ 計數題
阿新 • • 發佈:2020-10-21
【NOIP2017提高A組模擬9.7】JZOJ 計數題
題目
Description
Input
Output
Sample Input
5
2 2 3 4 5
Sample Output
8
6
Data Constraint
題解
題意
給出\(a[i]\),有一完全圖,\(i\)與\(j\)之間的邊的值為\(a[i] \oplus a[j]\)(\(\oplus\)為異或的意思)
求最小生成樹及方案數
題解
科普一個東西,\(n\)個點的完全圖的生成樹個數是\(n^{n-2}\)
這個東西叫做凱萊定理,大家可以自行了解一下
100\(\%\)
看到異或,而且要最小,且\(a[i]\)
想到可以按照最高位往下分治,分成當前這位是0和1的兩堆,然後為了取值最小,那麼這兩堆只能連一條
那麼就找到這兩堆裡面異或值最小的,這是\(trie\)應用的經典問題
然後分治一位一位往下
最後把所有最小值加一起,方案數乘起來即可
Code
#include<cmath> #include<cstdio> #include<algorithm> #define mod 1000000007 using namespace std; long long n,mx,num,ans,ans1,tot,a[1000001],er[31],c1[1000001],c2[1000001]; struct node { long long left,right,size; }trie[400005]; long long read() { long long res=0;char ch=getchar(); while (ch<'0'||ch>'9') ch=getchar(); while (ch>='0'&&ch<='9') res=(res<<1)+(res<<3)+(ch-'0'),ch=getchar(); return res; } void insert(long long x) { long long now=1; ++trie[now].size; for (long long i=mx;i>=0;--i) { if (x&er[i]) { if (trie[now].left==0) trie[now].left=++num,trie[num].left=trie[num].right=trie[num].size=0; now=trie[now].left; ++trie[now].size; } else { if (trie[now].right==0) trie[now].right=++num,trie[num].left=trie[num].right=trie[num].size=0; now=trie[now].right; ++trie[now].size; } } } long long calc(long long x) { long long now=1,s=0; for (long long i=mx;i>=0;--i) { if (x&er[i]) { if (trie[trie[now].left].size>0) now=trie[now].left; else s+=er[i],now=trie[now].right; } else { if (trie[trie[now].right].size>0) now=trie[now].right; else s+=er[i],now=trie[now].left; } } tot=trie[now].size; return s; } long long ksm(long long x,long long y) { long long res=1; while (y) { if (y&1) res=res*x%mod; x=x*x%mod; y>>=1; } return res; } long long dg(long long l,long long r,long long d) { if (r<=l) return 1; if (d<0) return ksm(r-l+1,r-l-1); long long t1=0,t2=0; for (long long i=l;i<=r;++i) { if (a[i]&er[d]) c1[++t1]=a[i]; else c2[++t2]=a[i]; } for (long long i=1;i<=t1;++i) a[l+i-1]=c1[i]; for (long long i=1;i<=t2;++i) a[l+t1+i-1]=c2[i]; long long s1=dg(l,l+t1-1,d-1),s2=dg(l+t1,r,d-1); long long s3=(s1*s2)%mod,s4=2147483647,s5=0; if (t1==0||t2==0) return s3; num=1; trie[1].left=trie[1].right=trie[1].size=0; for (long long i=1;i<=t1;++i) insert(a[l+i-1]); for (long long i=1;i<=t2;++i) { long long sum=calc(a[l+t1+i-1]); if (sum<s4) s4=sum,s5=tot; else if (sum==s4) s5=(s5+tot)%mod; } ans+=s4; return (s3*s5)%mod; } int main() { freopen("jst.in","r",stdin); freopen("jst.out","w",stdout); n=read(); for (long long i=1;i<=n;++i) a[i]=read(),mx=max(mx,a[i]); mx=log2(mx); er[0]=1; for (long long i=1;i<=31;++i) er[i]=er[i-1]*2%mod; num=1; ans1=dg(1,n,mx); printf("%lld\n%lld\n",ans,ans1); fclose(stdin); fclose(stdout); return 0; }