[Code+#4] 組合數問題2
阿新 • • 發佈:2018-11-09
一開始入隊的數字肯定是\(C^{n/2}_n\),然後將它上下左右能入堆的入堆,取出堆首元素後以此類推
要注意同一個組合數不能重複進堆,所以需要判重,但是如果根據當前擴充套件的元素是否在答案序列中來判斷是否應該進堆,比如這樣:
void bfs(int n,int k) { const int dx[4]= {1,-1,0,0},dy[4]= {0,0,1,-1}; q.push(Comb(n,n/2)); while(!q.empty()) { Comb u=q.top();q.pop(); ans.insert(make_pair(u.n,u.m)); if(ans.size()==k) return; for(int i=0; i<4; i++) { int a=u.n+dx[i],b=u.m+dy[i]; if(a>=0 && a<=n && b>=0 && b<=a && !ans.count(make_pair(a,b))) {//這裡用是否在答案序列中來判斷是否進堆 q.push(Comb(a,b)); } } } }
這樣還是會有很多重複元素進堆,還是會超時
因為如果當前擴充套件到了\(C_n^m\) ,而\(C_n^m\) 不在答案序列中,但是它有可能已經在堆中了!只是還沒有彈出!所以\(C_n^m\)就會被再次壓入堆中!
輸出中間結果可以發現,這樣的話同一個元素遠不止會進堆兩次,造成了大量的重複
所以要根據"是否已經進過堆"來判重.
再開一個vis
來記錄就行了
#include <set> #include <cmath> #include <queue> #include <cstdio> #include <iostream> using namespace std; typedef pair<int,int> pii; const int MAXN=1e6+10,mod=1e9+7; int f[MAXN],inv[MAXN],g[MAXN]; double ln_fact[MAXN]; struct Comb { int n,m; Comb(int a=0,int b=0):n(a),m(b) {} bool operator < (const Comb& rhs)const { return ln_fact[n]-ln_fact[m]-ln_fact[n-m] < ln_fact[rhs.n]-ln_fact[rhs.m]-ln_fact[rhs.n-rhs.m]; } }; priority_queue<Comb>q; set<pii>ans,vis; inline int pow(int base,int index) { int ans=1; while(index) { if(index&1) ans=(long long)ans*base%mod; base=(long long)base*base%mod; index>>=1; } return ans; } inline void init(int maxn) { ln_fact[1]=0;f[0]=1;inv[1]=1;g[0]=1; for(int i=2; i<=maxn; i++) ln_fact[i]=ln_fact[i-1]+log(i); for(int i=1; i<=maxn; i++) f[i]=(long long)f[i-1]*i%mod; for(int i=2; i<=maxn; i++) inv[i]=(long long)(mod-mod/i)*inv[mod%i]%mod; for(int i=1; i<=maxn; i++) g[i]=(long long)g[i-1]*inv[i]%mod; } void bfs(int n,int k) { const int dx[4]= {1,-1,0,0},dy[4]= {0,0,1,-1}; q.push(Comb(n,n/2)); while(!q.empty()) { Comb u=q.top();q.pop(); ans.insert(make_pair(u.n,u.m)); if(ans.size()==k) return; for(int i=0; i<4; i++) { int a=u.n+dx[i],b=u.m+dy[i]; if(a>=0 && a<=n && b>=0 && b<=a && !vis.count(make_pair(a,b))) { vis.insert(make_pair(a,b)); q.push(Comb(a,b)); } } } } int main() { int n,k; scanf("%d %d",&n,&k); init(n); bfs(n,k); long long sum=0; for(set<pii>::iterator it=ans.begin(); it!=ans.end(); it++) { int a=(*it).first,b=(*it).second; sum=(sum+(long long)f[a]*g[b]%mod*g[a-b]%mod)%mod; } cout<<sum; return 0; }
壓行技術大有提升qaq