1. 程式人生 > >2018.11.17 hdu5829Rikka with Subset(ntt)

2018.11.17 hdu5829Rikka with Subset(ntt)

傳送門
n t t ntt 基礎題。
考慮計算每一個數在排名為 k k 時被統計了多少次來更新答案。
這樣的話,設 a

n s k ans_k 表示所有數的值乘上排名為 k k 的子集數的總和。
a
n s k = i = k
n
a i ( i 1 k 1 ) 2 n i ans_k=\sum_{i=k}^na_i\binom{i-1}{k-1}2^{n-i}

=> a n s k = 1 ( k 1 ) ! i = k n a i ( i 1 ) ! ( i k ) ! 2 n i ans_k=\frac1{(k-1)!}\sum_{i=k}^na_i\frac{(i-1)!}{(i-k)!}2^{n-i}
=> a n s k = 1 ( k 1 ) ! 2 k i = 0 n k a i + k ( i + k 1 ) ! 2 n i i ! ans_k=\frac1{(k-1)!2^k}\sum_{i=0}^{n-k}a_{i+k}(i+k-1)!\frac{2^{n-i}}{i!}
然後令 x i = 2 n i i ! , y i = a i ( i 1 ) ! x_i=\frac{2^{n-i}}{i!},y_i=a_i(i-1)!
那麼將 y y 陣列翻轉再平移一下就可以卷積了。
程式碼:

#include<bits/stdc++.h>
#define ri register int
using namespace std;
inline int read(){
    int ans=0;
    char ch=getchar();
    while(!isdigit(ch))ch=getchar();
    while(isdigit(ch))ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
    return ans;
}
typedef long long ll;
const int N=4e5+5,mod=998244353;
int tim,lim,n,a[N],b[N],fac[N],ifac[N],pos[N],pow2[N],num[N];
inline int ksm(int a,int p){int ret=1;for(;p;p>>=1,a=(ll)a*a%mod)if(p&1)ret=(ll)ret*a%mod;return ret;}
inline void prepare(){
    ifac[0]=ifac[1]=fac[0]=fac[1]=1,pow2[0]=1;
    for(ri i=1;i<=1e5;++i)pow2[i]=(ll)pow2[i-1]*2%mod;
    for(ri i=2;i<=1e5;++i)ifac[i]=(ll)ifac[mod%i]*(mod-mod/i)%mod,fac[i]=(ll)fac[i-1]*i%mod;
    for(ri i=2;i<=1e5;++i)ifac[i]=(ll)ifac[i-1]*ifac[i]%mod;
}
inline void init(){
    lim=1,tim=0;
    while(lim<=n*2)lim<<=1,++tim;
    for(ri i=0;i<lim;++i)pos[i]=(pos[i>>1]>>1)|((i&1)<<(tim-1));
}
inline void ntt(int *a,int type){
    for(ri i=0;i<lim;++i)if(i<pos[i])swap(a[i],a[pos[i]]);
    int typ=type==1?3:(mod+1)/3,mult=(mod-1)>>1;
    for(ri mid=1,wn;mid<lim;mid<<=1,mult>>=1){
        wn=ksm(typ,mult);
        for(ri j=0,len=mid<<1;j<lim;j+=len){
            for(ri k=0,w=1;k<mid;++k,w=(ll)w*wn%mod){
                int a0=a[j+k],a1=(ll)a[j+k+mid]*w%mod;
                a[j+k]=(a0+a1)%mod,a[j+k+mid]=(a0-a1+mod)%mod;
            }
        }
    }
    if(type==-1){
        int inv=ksm(lim,mod-2);
        for(ri i=0;i<lim;++i)a[i]=(ll)a[i]*inv%mod;
    }
}
int main(){
    prepare();
    for(int tt=read();tt;--tt,puts("")){
        n=read(),init();
        for(ri i=1;i<=n;++i)num[i]=read();
        sort(num+1,num+n+1),reverse(num+1,num+n+1);
        for(ri i=0;i<n;++i)a[i]=(ll)ifac[i]*pow2[n-i]%mod;
        for(ri i=1;i<=n;++i)b[i]=(ll)fac[i-1]*num[i]%mod;
        reverse(b+1,b+n+1);
        for(ri i=0;i<n;++i)b[i]=b[i+1];
        b[n]=0;
        ntt(a,1),ntt(b,1);
        for(ri i=0;i<lim;++i)a[i]=(ll)a[i]*b[i]%mod;
        ntt(a,-1);
        for(ri inv=(mod+1)/2,ans=0,last=0,i=1;i<=n;++i){
            ans=((ll)inv*ifac[i-1]%mod*a[n-i]%mod+last)%mod;
            printf("%d ",last=ans);
            inv=(ll)inv*(mod+1)/2%mod;
        }
        memset(num,0,sizeof(num)),memset(a,0,sizeof(a)),memset(b,0,sizeof(b));
    }
    return 0;
}