CF623E Transforming Sequence
我一開始沒看到模數
看到這題,\(n\le 10^{18}\) ,\(k\le 10^4\) 就很迷惑,不是 \(n>k\) 就無解的嗎??
然而事實就是這樣。。。如果像我一樣手寫快讀的注意第一個數要開 long long
讀。
看懂題目後題意迅速轉化成了:選 \(n\) 次數,每次選一個元素 \(\in [1,k]\) 的集合,要求至少一個元素與之前選所有元素的不同,求方案數。
接下去文章分為兩部分:分割線之前都是我自己踩的雷,分割線之後是正解
輕鬆搞出一個 \(nk^2\) 的dp,設 \(dp(n,k)\) 表示取了 \(n\)
就是列舉之前選了多少種元素,然後再到 \(n\) 種元素裡找多出的 \(k-i\) 種分配位置,而且之前選的 \(i\) 個元素可以選或者不選。
顯然那個dp可以卷積於是變成了 \(k^2\log k\) ,但是有一個細節,\(l\) 的上限是 \(j-1\)。
看了半天,想著後面那個 \(k\log k\) 大概率消不掉。這個dp每次轉移 \(1\) 太浪費了吧。誒對了,說不定可以倍增FFT。
一眨眼 \(2\) 小時過去了。。。woc怎麼倍增啊???
倍增FFT必須還要有一個 \(dp(a+b)\) 與 \(dp(a),dp(b)\) 之間的轉移啊,這個沒法轉移啊。
自閉了,去看了眼題解。狀態原來不應該這麼開!或許有很多初學者會和我犯同樣的錯誤,所以上面那一大段被寫了下來。
考慮dp,設 \(dp(n,k)\) 表示前 \(n\) 輪操作中選了 \(k\) 種不同的數的總方案數,但是是欽定k種的前提下,也就是說哪k種已經定了。
所以統計答案變成了:\(ans=\sum_{i=n}^{k}\binom{k}{i}dp(n,i)\)
忽然想起之前看粉兔的部落格一直不理解為啥EGF要除階乘最後再乘回來,大概原因就是這個了吧。
千萬注意那個上界是 \(j-1\) 啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊,不然會像我一樣調一晚上的啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊
轉移方程應該很顯然吧?在現在的 \(j\) 中值裡列舉 \(l\) 個給之前的,每一個都可以選或者不選。
現在這個dp很好合並了,幾乎是一眼可以看出下面那個式子
\[dp(a+b,i)=\sum_{j=0}^{i-1}dp(a,j)dp(b,i-j)\binom{i}{j}2^{bj} \]在最終的 \(i\) 個裡分配 \(j\) 個值給前 \(a\) 個,後面 \(b\) 個直接考慮剩下的 \(i-j\) 個值即可。而且每一次轉移都可以選擇是否選 \(a\) 箇中的每一個值,所以轉移一次就乘 \(2^j\) 。這個dp其實可以看做轉移的合併,那麼轉移 \(b\) 次就要乘 \(2^{bj}\)
然後就很好倍增FFT了吧!
\[\begin{cases} dp(n+1,i)=\sum_{j=0}^{i-1}dp(n,j)\binom{i}{j}2^j\\ dp(2n,i)=\sum_{j=0}^{i-1}dp(n,j)dp(n,i-j)\binom{i}{j}2^{nj} \end{cases} \]化成可以卷積的式子:
\[\begin{cases} dp(n+1,i)=i!\sum_{j=0}^{i-1}dp(n,j)\dfrac{2^j}{j!}\dfrac{1}{(i-j)!}\\ dp(2n,i)=i!\sum_{j=0}^{i-1}dp(n,j)\dfrac{2^{nj}}{j!}dp(n,i-j)\dfrac{1}{(i-j)!} \end{cases} \]三個坑點:
-
上界要減一!!!
-
這個出題人不講武德,好端端的NTT題,結果mod=1e9+7,我沒閃,被偷襲了。
-
MTT精度要好,建議預處理單位根,比不預處理快了一倍(因為不預處理會被卡精度,然後WA,所以要開
long double
)
//Orz cyn2006
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return f?x:-x;
}
const int N=30005;
const int M=N<<2;
const int mod=1e9+7;
inline int qpow(int n,int k){int res=1;for(;k;k>>=1,n=1ll*n*n%mod)if(k&1)res=1ll*n*res%mod;return res;}
int n,k,f[M],ans;
namespace poly{
const db pi=acos(-1.0);
int rev[M],lg,lim;
int fac[M],ifc[M];
void initmath(const int&n){
fac[0]=1;for(int i=1;i<=n;++i)fac[i]=1ll*i*fac[i-1]%mod;
ifc[n]=qpow(fac[n],mod-2);for(int i=n-1;i>=0;--i)ifc[i]=1ll*ifc[i+1]*(i+1)%mod;
}
struct cp{
db x,y;
cp(){x=y=0;}
cp(db x_,db y_){x=x_,y=y_;}
cp operator + (const cp&t)const{return cp(x+t.x,y+t.y);}
cp operator - (const cp&t)const{return cp(x-t.x,y-t.y);}
cp operator * (const cp&t)const{return cp(x*t.x-y*t.y,x*t.y+y*t.x);}
}w[M];
void init_poly(const int&n){
for(lim=1,lg=0;lim<=n;lim<<=1,++lg);
for(int i=0;i<lim;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1)),w[i]=cp(cos(2.*pi*i/lim),sin(2.*pi*i/lim));
}
void FFT(cp*a,int op){
for(int i=0;i<lim;++i)if(i>rev[i])swap(a[i],a[rev[i]]);
for(int i=1,t=lim>>1;i<lim;i<<=1,t>>=1){
for(int j=0;j<lim;j+=i<<1){
for(int k=0;k<i;++k){
const cp X=a[j+k],Y=w[t*k]*a[i+j+k];
a[j+k]=X+Y,a[i+j+k]=X-Y;
}
}
}
if(!op)for(int i=0;i<lim;++i)a[i].x/=lim;
}
void MTT(int*f,int*g,int*ans){
static cp A[M],B[M],C[M],D[M],E[M],F[M],G[M];
for(int i=0;i<lim;++i)
A[i]=cp(f[i]&65535,0),B[i]=cp(f[i]>>16,0),
C[i]=cp(g[i]&65535,0),D[i]=cp(g[i]>>16,0);
FFT(A,1),FFT(B,1),FFT(C,1),FFT(D,1);
for(int i=0;i<lim;++i)
E[i]=A[i]*C[i],F[i]=A[i]*D[i]+B[i]*C[i],G[i]=B[i]*D[i],w[i].y*=-1;
FFT(E,0),FFT(F,0),FFT(G,0);
for(int i=0;i<lim;++i)
ans[i]=LL(G[i].x+0.5)%mod,
ans[i]=((65536ll*ans[i]%mod)+LL(F[i].x+0.5)%mod)%mod,
ans[i]=((65536ll*ans[i]%mod)+LL(E[i].x+0.5)%mod)%mod,
w[i].y*=-1;
}
#define clr(a,n) memset(a,0,sizeof(int)*(n))
#define cpy(a,b) memcpy(a,b,sizeof(int)*(n))
void shift(const int&n,const int&len){
static int g[M],h[M];
clr(g,lim),clr(h,lim);
for(int i=0,bas=qpow(2,len),j=1;i<n;++i,j=1ll*j*bas%mod)g[i]=1ll*f[i]*j%mod*ifc[i]%mod;
for(int i=1;i<=n;++i)h[i]=1ll*f[i]*ifc[i]%mod;
MTT(g,h,f);
for(int i=0;i<=n;++i)f[i]=1ll*f[i]*fac[i]%mod;
clr(f+n+1,lim-n);
}
void setbit(const int&n){
static int g[M],h[M];
clr(g,lim),clr(h,lim);
for(int i=0,j=1;i<n;++i,(j<<=1)%=mod)g[i]=1ll*f[i]*j%mod*ifc[i]%mod;
for(int i=1;i<=n;++i)h[i]=ifc[i];
MTT(g,h,f);
for(int i=0;i<=n;++i)f[i]=1ll*f[i]*fac[i]%mod;
clr(f+n+1,lim-n);
}
}
using poly::fac;
using poly::ifc;
signed main(){
LL whatsthis;scanf("%lld%d",&whatsthis,&k);
if(whatsthis>k)return puts("0"),0;
n=whatsthis;
f[0]=0;rep(i,1,k)f[i]=1;
poly::init_poly(k<<1),poly::initmath(k);
for(int i=log2(n)-1,len=1;i>=0;--i){
poly::shift(k,len),len<<=1;
if(n>>i&1)poly::setbit(k),++len;
}
for(int i=n;i<=k;++i)ans=(ans+1ll*f[i]*fac[k]%mod*ifc[i]%mod*ifc[k-i]%mod)%mod;
printf("%d\n",ans);
return 0;
}