玲瓏杯 1160
阿新 • • 發佈:2019-02-19
題意,在n本書中要拿k本書的倍數的方案,每本書都不同,一本都不拿也算一種方案
開始以為是直接求C(n,0)+C(n,k)+C(n,2k)…
求不出來 orz
看了題解後 問了yql大佬
先是可以得到一個遞推式
F[i][j]:表示前i本書,拿j本的方案
F[i][j]=F[i-1][j]+F[i-1][j-1]
因為j比較大,我們可以用滾動陣列,j=j% k
然後可以得到
答案就是f[n][0],我們只用求第一行就行了。
設,A為圖中第一個矩陣,A矩陣是k*k,如果樸素求第一行的話,時間複雜度為k*k*logn,超時gg… 然後我們發現這個可以用NTT來加速
注意到A是迴圈矩陣
(什麼是迴圈矩陣?類似於 1a3a2a2a1a3a3a2a1
如果A=
則A*A的第一行為(a1*a1+a2*a3+a3*a2 , a1*a2+a2*a1+a3*a3, a1*a3+a2*a2+a3*a1)
這個就是
卷積就可以用NTT了~用NTT的總時間複雜度為O(k*logn*logk),當k為3e4時,為1e7,但k為3e5時就會超時。因為當k>3e4時,k只能為2的冪。一般去長度為k的迴圈卷積,肯定做的是>2k的FFT,來保證不會出錯,但是如果k是2的次冪,就可以直接做長度為k的FFT,就可以直接變成點值之後快速冪。(yql教的:>)當k為2的冪次,時間大概是O(k*logn)
<從這個題中學到了很多,感謝yql~>
<基本上是yql的程式碼….>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <vector>
using namespace std;
const int maxn = 600005;
#define mod 998244353
#define ll long long
int A[maxn],B[maxn],Ans[maxn],X[maxn];
ll n;int k;
int gg=3;
int fexp(int x,int p){int ans=1;for(;p;p>>=1,x=1LL*x*x%mod)if(p&1)ans=1LL*ans*x%mod;return ans;}
void NTT(int *a,int f,int k){
for(int i=0,j=0;i<k;i++){
if(i>j)swap(a[i],a[j]);
for(int l=k>>1;(j^=l)<l;l>>=1);
}
for(int i=1;i<k;i<<=1)
{
int w=fexp(gg,(f*(mod-1)/(i<<1)+mod-1)%(mod-1));
for(int j=0;j<k;j+=i<<1){int e=1;
for(int k=0;k<i;k++,e=1LL*e*w%mod){int x,y;
x=a[j+k];y=1LL*a[j+k+i]*e%mod;
a[j+k]=(x+y)%mod;a[j+k+i]=(x-y+mod)%mod;
}
}
}
if(f==-1){
int _inv=fexp(k,mod-2);
for(int i=0;i<k;i++)a[i]=1LL*a[i]*_inv%mod;
}
}
void Work(){
if((k&(-k))==k)
{
NTT(X,1,k);
NTT(Ans,1,k);
for(;n;n>>=1)
{
if(n&1) for(int i=0;i<k;i++) Ans[i]=1LL*Ans[i]*X[i]%mod;
for(int i=0;i<k;i++) X[i]=1LL*X[i]*X[i]%mod;
}
NTT(Ans,-1,k);
}
else {
int t;
for(t=1;t<=(k*2);t<<=1);
for(;n;n>>=1)
{
if(n&1){
for(int i=0;i<t;i++) A[i]=B[i]=0;
for(int i=0;i<k;i++) A[i]=Ans[i],B[i]=X[i];
NTT(A,1,t),NTT(B,1,t);
for(int i=0;i<t;i++) A[i]=1LL*A[i]*B[i]%mod;
NTT(A,-1,t);
for(int i=0;i<t;i++) Ans[i]=0;
for(int i=0;i<t;i++) Ans[i%k]=(Ans[i%k]+A[i])%mod;
}
for(int i=0;i<t;i++) A[i]=B[i]=0;
for(int i=0;i<k;i++) A[i]=X[i];
NTT(A,1,t);
for(int i=0;i<t;i++) A[i]=1LL*A[i]*A[i]%mod;
NTT(A,-1,t);
for(int i=0;i<k;i++) X[i]=0;
for(int i=0;i<t;i++) X[i%k]=(X[i%k]+A[i])%mod;
}
}
printf("%d\n",Ans[0]);
}
void init()
{
Ans[0]=1;X[0]++,X[k-1]++;
}
int main()
{
scanf("%lld %d",&n,&k);
init();
Work();
return 0;
}