計數難題6:luoguP4935 口袋裡的紙飛機
阿新 • • 發佈:2018-11-04
計數難題6:luoguP4935 口袋裡的紙飛機
標籤(空格分隔): 計數難題題選
題目大意:
連結:戳我!
隨機生成一個大小為\(n\)的數列\(\{a_i\}\),每個數的範圍都在\([1,R]\)之間。
對於每種數列,可以生成一個\(n*n\)的網格,其中格子\((i,j)\)的數為\(a_i*a_j\% P\) 。
對於一個數列,定義其價值為形成的網格中不同的數的個數。
現在你需要求出所有數列的價值之和,答案對\(10^9 + 7\)取模。
資料範圍:\(n\leq 500 , P\leq 5000 , R\leq 10^9\) ,保證\(P\)為大於\(3\)的質數。
題解
顯然列舉每一個數\(x\)
直接算在多少個排列中出現並不好算,我們算它的反面:在多少個排列中不出現。
注意到\(P\)是質數,即有原根。
所以\(a*b \%P = x\)這樣的\((a,b)\)一定是一些互不相交的二元組。
現在數就被分成了三類:
- 若\(a*b\% P=x\),則\((a,b)\)為一個限制二元組。
- 若\(a*a\% P = x\),則\(a\)不能選。
- 其它的數則可以任意選。
顯然我們只需要處理這些限制二元組的選擇,設\(f_{i,j}\)表示前\(i\)個二元組選了\(j\)個數的方案數。
轉移直接插入\(t\)個元素:
\[f_{i,j} (v_1^t + v_2^t) \binom{j + t}{t} \to f_{i+1 , j+t} \]
這樣子的暴力是\(O(P^2 n^2)\)的。
注意到每一個二元組元素\(a,b\)的選擇個數\(v_1\)、\(v_2\)只有可能是\(\lfloor\frac{R}{n}\rfloor\)或\(\lfloor\frac{R}{n}\rfloor+1\)。
所以本質不同的二元組只有\(3\)種。
設\(f[op][i][j]\)表示使用了\(i\)個二元組\(op\),一共選了\(j\)個數的方案數。
暴力轉移,最後用\(f[0/1/2]\)合併答案,最後複雜度是\(O(Pn^2)\)的。
看上去已經沒有什麼可以優化的了。
所以就暴力分塊吧。
對於每一種二元組,我們預處理使用\(1,2,...\sqrt{P}\)
然後利用\(f_{\sqrt{n}}\),我們又可以處理出使用\(\sqrt{P},2\sqrt{P}...(\sqrt{P})^2\)個時的方案數\(g\)。
轉移全部都是插入元素即可。
那麼對於任意一個使用個數\(t\),
我們就可以通過一個\(f\)和一個\(g\)在\(O(n^2)\)的時間內通過卷積算出其所有方案數。
然後又一個結論:本質不同的\(x\)只有\(\sqrt{P}\)個。
所以記憶化後,對於每一種\(x\)暴力算出答案即可,注意特判\(0\)。
複雜度\(O(\sqrt{P}n^2)\)。
實現程式碼
#include<bits/stdc++.h>
#define IL inline
#define _
#define ll long long
using namespace std ;
IL int gi(){
int data = 0 , m = 1; char ch = 0;
while(ch!='-' && (ch<'0'||ch>'9')) ch = getchar();
if(ch == '-'){m = 0 ; ch = getchar() ; }
while(ch >= '0' && ch <= '9'){data = (data<<1) + (data<<3) + (ch^48) ; ch = getchar(); }
return (m) ? data : -data ;
}
#define mod 1000000007
IL int Pow(int ts , int js) {
int al = 1 ;
while(js) {
if(js & 1) al = 1ll * al * ts % mod ;
ts = 1ll * ts * ts % mod ;
js >>= 1 ;
}
return al ;
}
IL void add(int &x , int y){x += y ; if(x >= mod) x-= mod ;}
int Fac[5005],inv[5005],IFac[5005],n,m,R,P,Base ;
int val[5005],ban[5005],Bac,Bac2,Ans,ALL,pw1[5005],pw2[5005],cnt[5005][5] ;
int f[3][505][5005] , g[3][505][5005] , dp[5][5005] , ret[5][5005] , Result[505][505] ;
int oo ;
IL void Numb() {
Fac[0] = Fac[1] = inv[0] = inv[1] = IFac[0] = IFac[1] = 1 ;
for(int i = 2; i <= 5000; i ++) {
inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod ;
Fac[i] = 1ll * i * Fac[i - 1] % mod ;
IFac[i] = 1ll * inv[i] * IFac[i - 1] % mod ;
}
return ;
}
IL int Comb(int N , int M) {
if(M > N) return 0 ;
return 1ll * Fac[N] * IFac[M] % mod * IFac[N - M] % mod ;
}
IL void Solve(int id , int v1 , int v2) {
pw1[0] = pw2[0] = 1 ;
for(int i = 1; i <= n; i ++) pw1[i] = 1ll * pw1[i-1] * v1 % mod , pw2[i] = 1ll * pw2[i-1] * v2 % mod ;
f[id][0][0] = 1 ;
for(int i = 0; i < Bac; i ++)
for(int j = 0; j <= n; j ++)
if(f[id][i][j])
for(int t = 0; t + j <= n; t ++) {
if(t){
add(f[id][i+1][t + j] , 1ll * f[id][i][j] * Comb(j + t , t) % mod * pw1[t] % mod) ;
add(f[id][i+1][t + j] , 1ll * f[id][i][j] * Comb(j + t , t) % mod * pw2[t] % mod) ;
}
else add(f[id][i+1][j] , f[id][i][j]) ;
}
g[id][0][0] = 1 ;
for(int i = 0; i < Bac2; i ++)
for(int j = 0; j <= n; j ++)
if(g[id][i][j])
for(int t = 0; t + j <= n; t ++)
add(g[id][i+1][t + j] , 1ll * g[id][i][j] * Comb(j + t , t) % mod * f[id][Bac][t] % mod) ;
return ;
}
IL void Calc(int id , int s) {
int bs = 0 ;
while(Bac * (bs + 1) <= s) ++ bs ;
int rest = s - Bac * bs ;
for(int j = 0; j <= n; j ++) dp[0][j] = g[id][bs][j] , dp[1][j] = 0 ;
for(int j = 0; j <= n; j ++)
for(int t = 0; t + j <= n; t ++)
add(dp[1][j + t] , 1ll * dp[0][j] * f[id][rest][t] % mod * Comb(j + t , t) % mod) ;
return ;
}
IL void Query(int Id , int s0 , int s1 , int s2) {
int isum[3] = {s0,s1,s2} ;
for(int id = 0; id < 3; id ++) {
Calc(id , isum[id]) ;
for(int j = 0; j <= n; j ++) ret[id + 1][j] = dp[1][j] ;
}
for(int i = 0; i <= 3; i ++)
for(int j = 0; j <= n; j ++) dp[i][j] = 0 ;
for(int j = 0; j <= n; j ++) dp[0][j] = 0 ; dp[0][0] = 1 ;
for(int i = 0; i < 3; i ++)
for(int j = 0; j <= n; j ++)
if(dp[i][j])
for(int t = 0; j + t <= n; t ++)
add(dp[i+1][j + t] , 1ll * dp[i][j] * ret[i+1][t] % mod * Comb(j + t , t) % mod) ;
for(int j = 0; j <= n; j ++) Result[Id][j] = dp[3][j] ;
return ;
}
struct Hash{
int a0,a1,a2 ;
bool operator < (const Hash &B) const {
return (a2 ^ B.a2) ? a2 < B.a2 : ((a0 ^ B.a0) ? a0 < B.a0 : a1 < B.a1) ;
}
};map<Hash,int>ID ;
int main() {
n = gi() ; P = gi() ; R = gi() ;
val[0] = R / P ;
Numb() ;
Base = R / P ;
Bac = sqrt(P) ; Bac2 = (P + Bac - 1) / Bac ;
for(int i = 1; i < P; i ++)
if(R - Base * P >= i) val[i] = Base + 1 ; else val[i] = Base ;
for(int x1 = 1 ; x1 < P; x1 ++)
for(int x2 = 1; x2 <= x1; x2 ++) {
int v = 1ll * x1 * x2 % P ;
if(x1 == x2) ban[v] += val[x1] ;
else {
ban[v] += val[x1] + val[x2] ;
int s1 = val[x1] , s2 = val[x2] ;
if(s1 > s2) swap(s1 , s2) ;
if(s1 == Base && s2 == Base) cnt[v][0] ++ ;
else if(s1 == Base && s2 == Base + 1) cnt[v][1] ++ ;
else if(s1 == Base + 1 && s2 == Base + 1) cnt[v][2] ++ ;
}
}
Solve(0 , Base , Base) ;
Solve(1 , Base , Base + 1) ;
Solve(2 , Base + 1 , Base + 1) ;
for(int v = 1; v < P; v ++) {
Hash al ; al.a0 = cnt[v][0] ; al.a1 = cnt[v][1] ; al.a2 = cnt[v][2] ;
if(ID.find(al) == ID.end()) ID[al] = ++ oo ;
}
for(map<Hash,int>::iterator it = ID.begin(); it != ID.end(); it ++) {
Hash al = it->first ;
int id = it->second ;
Query(id , al.a0 , al.a1 , al.a2) ;
}
ALL = Pow(R , n) ;
for(int x = 1; x < P; x ++) {
Hash al ; al.a0 = cnt[x][0] ; al.a1 = cnt[x][1] ; al.a2 = cnt[x][2] ;
int id = ID[al] ;
int AL = 0 ;
for(int j = 0; j <= n; j ++) add(AL , 1ll * Result[id][j] * Pow(R - ban[x] , n - j) % mod * Comb(n , j) % mod) ;
add(Ans , (ALL - AL + mod) % mod) ;
}
Ans = (Ans + (ALL - Pow((R-val[0] + mod) % mod , n) + mod) % mod) % mod ;
cout << Ans << endl ;
return 0 ;
}