[洛谷P4491] HAOI2018 染色
問題描述
為了報答小 C 的蘋果, 小 G 打算送給熱愛美術的小 C 一塊畫布, 這塊畫布可 以抽象為一個長度為 \(N\) 的序列, 每個位置都可以被染成 \(M\) 種顏色中的某一種.
然而小 C 只關心序列的 \(N\) 個位置中出現次數恰好為 \(S\) 的顏色種數, 如果恰 好出現了 \(S\) 次的顏色有 \(K\) 種, 則小 C 會產生 \(W_k\) 的愉悅度.
小 C 希望知道對於所有可能的染色方案, 他能獲得的愉悅度的和對 \(1004535809\) 取模的結果是多少。
輸入格式
從標準輸入讀入資料. 第一行三個整數 \(N, M, S\).
接下來一行 \(M + 1\) 個整數, 第 \(i\)
輸出格式
輸出到標準輸出中. 輸出一個整數表示答案.
樣例輸入
8 8 3
3999 8477 9694 8454 3308 8961 3018 2255 4910
樣例輸出
524070430
資料範圍
\(N\le 10^7, M\le 10^5, S\le 150\)
解析
考慮容斥。設 \(f_i\) 表示至少有 \(i\) 種顏色出現了 \(S\) 次的方案數。實際上欽定出現了 \(S\) 次的顏色後就是一個可重排列,沒有被欽定的顏色任意選擇。即:
\[f_i={m\choose i} \frac{n!}{(S!)^i(n-iS)!} (n-iS)^{m-i} \]
設 \(g_i\) 表示恰好有 \(i\) 種顏色出現了 \(S\) 次的方案數。不難得到:
\[\begin{aligned}g_i&=\sum_{j=i}^m (-1)^{j-i}{j\choose i} f_j\\ &=\sum_{j=i}^m \frac{(-1)^{j-i}}{(j-i)!}\times j!f_j\end{aligned} \]
差卷積一下即可。
程式碼
#include <iostream> #include <cstdio> #define N 10000002 #define M 500002 #define int long long using namespace std; const int mod=1004535809; const int G=3; int n,m,n1=1,m1,lim,s,i,w[M],f[M],g[M],r[M],fac[N],inv[N]; int read() { char c=getchar(); int w=0; while(c<'0'||c>'9') c=getchar(); while(c<='9'&&c>='0'){ w=w*10+c-'0'; c=getchar(); } return w; } int poww(int a,int b) { int ans=1,base=a; while(b){ if(b&1) ans=ans*base%mod; base=base*base%mod; b>>=1; } return ans; } int C(int n,int m) { return fac[n]*inv[m]%mod*inv[n-m]%mod; } void NTT(int *a,int n,int inv) { for(int i=0;i<n;i++){ if(i<r[i]) swap(a[i],a[r[i]]); } for(int l=2;l<=n;l<<=1){ int mid=l/2; int cur=poww(G,(mod-1)/l); if(inv==-1) cur=poww(cur,mod-2); for(int i=0;i<n;i+=l){ int omg=1; for(int j=0;j<mid;j++,omg=omg*cur%mod){ int tmp=omg*a[i+j+mid]%mod; a[i+j+mid]=(a[i+j]-tmp+mod)%mod; a[i+j]=(a[i+j]+tmp)%mod; } } } if(inv==-1){ for(int i=0;i<n;i++) a[i]=a[i]*poww(n,mod-2)%mod; } } signed main() { n=read();m=read();s=read(); for(i=0;i<=m;i++) w[i]=read(); for(i=fac[0]=1;i<=max(n,m);i++) fac[i]=fac[i-1]*i%mod; inv[max(n,m)]=poww(fac[max(n,m)],mod-2); for(i=max(n,m)-1;i>=0;i--) inv[i]=inv[i+1]*(i+1)%mod; m1=min(m,n/s); for(i=0;i<=m1;i++) f[i]=fac[i]*C(m,i)%mod*fac[n]%mod*poww(inv[s],i)%mod*inv[n-s*i]%mod*poww(m-i,n-i*s)%mod; for(i=0;i<=m1;i++){ g[i]=inv[m1-i]; if((m1-i)%2!=0) g[i]=mod-g[i]; } while(n1<=2*m1) n1<<=1,lim++; for(i=0;i<n1;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(lim-1)); NTT(f,n1,1);NTT(g,n1,1); for(i=0;i<n1;i++) f[i]=f[i]*g[i]%mod; NTT(f,n1,-1); int ans=0; for(i=0;i<=m1;i++) ans=(ans+w[i]*f[m1+i]%mod*inv[i]%mod)%mod; printf("%lld\n",ans); return 0; }