[NOIP2021] 數列
阿新 • • 發佈:2021-12-02
感覺這道題純動態規劃的邊界等問題非常麻煩,所以這裡採用記憶化搜尋。
題目大意
給出 \(n,m,k\) 及 \(val_0\cdots val_m\),定義一個值 \(\in [0,m]\) 的序列 \(a\),其權值為 \(\prod\limits_{i=1}^{n} val_{a_i}\)
我們稱 \(S\) 滿足條件當且僅當 \(S=\sum\limits_{i=1}^{n} 2^{a_i}\) 的二進位制表示中,\(1\) 的個數小於等於 \(k\)。此時,也稱序列 \(a\) 為合法序列。
求所有合法序列 \(a\) 的權值和 \(\mod 998244353\) 的結果。
題目分析
令 \(dfs(bit,now,x,y)\) 表示:
\(S\) 從低到高二進位制的 \(bit\) 位中,用了序列 \(a\) 的前 \(now\) 個數,此時 \(S\) 二進位制下有 \(x\) 個 \(1\),上一位(第 \(bit+1\) 位)進位為 \(y\)。
\(mem[biw][now][x][y]\) 則儲存答案。
於是,我們有:
\[mem[bit][now][x][y]=\sum\limits_{i=0}^{n-now}{mem[bit][now+i][x+(y+i)\%2][\left\lfloor\frac{y+i}{2}\right\rfloor])\times sum[bit][i]\times C_{now+i}^{i}} \]其中 \(C_{i}^{j}\)
for(register int i=0;i<=m;i++)
{
sum[i][0]=1;
for(register int j=1;j<=n;j++)
{
sum[i][j]=sum[i][j-1]*val[i]%mod;
}
}
可以看到,\(sum[i][j]\) 主要作用類似於字首和,目的是簡化計算。
邊界部分:
當前轉移到 \(dfs(bit,now,x,y)\)。
- 若 \(now=n\):
當 \(x+getcnt(y)>k\) 時,返回 \(0\)。表示不需要繼續轉移了。
否則返回 \(1\)
-
若 \(bit>m\) 則直接返回。
-
若 \(mem[bit][now][x][y]\) 有數則直接返回該數。
程式碼
//2021/11/30
//2021/12/1
//2021/12/2
#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <cstdio>
#include <climits>//need "INT_MAX","INT_MIN"
#include <cstring>
#define int long long
#define enter() putchar(10)
#define debug(c,que) cerr<<#c<<" = "<<c<<que
#define cek(c) puts(c)
#define blow(arr,st,ed,w) for(register int i=(st);i<=(ed);i++)cout<<arr[i]<<w;
#define speed_up() cin.tie(0),cout.tie(0)
#define endl "\n"
#define Input_Int(n,a) for(register int i=1;i<=n;i++)scanf("%d",a+i);
#define Input_Long(n,a) for(register long long i=1;i<=n;i++)scanf("%lld",a+i);
namespace Newstd
{
inline int read()
{
int x=0,k=1;
char ch=getchar();
while(ch<'0' || ch>'9')
{
if(ch=='-')
{
k=-1;
}
ch=getchar();
}
while(ch>='0' && ch<='9')
{
x=(x<<1)+(x<<3)+ch-'0';
ch=getchar();
}
return x*k;
}
inline void write(int x)
{
if(x<0)
{
putchar('-');
x=-x;
}
if(x>9)
{
write(x/10);
}
putchar(x%10+'0');
}
}
using namespace Newstd;
using namespace std;
const int mod=998244353;
const int MA_1=105;
const int MA_2=35;
int val[MA_1];
int C[MA_1][MA_1],sum[MA_1][MA_1];
int mem[MA_1][MA_2][MA_2][MA_2];
int n,m,k;
inline void init()
{
C[0][0]=1;
for(register int i=1;i<=n;i++)
{
C[i][0]=1;
for(register int j=1;j<=i;j++)
{
C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;
}
}
}
inline int lowbit(int x)
{
return x&-x;
}
inline int getcnt(int x)
{
int ans(0);
while(x!=0)
{
x-=lowbit(x);
ans++;
}
return ans;
}
//dfs(k,now,x,y)
//S從低到高二進位制的 bit 位中,用了數列 a 的前 now 項,且此時 S 中共有 x 個二進位制位為 1,第 now+1 位進了 y 過去
inline int dfs(int bit,int now,int x,int y)
{
if(now==n)
{
if(x+getcnt(y)>k)
{
return 0;
}
return 1;
}
if(bit>m)
{
return 0;
}
if(mem[bit][now][x][y]!=-1)
{
return mem[bit][now][x][y];
}
int ans(0);
for(register int i=0;i<=n-now;i++)
{
ans=(ans+dfs(bit+1,now+i,x+(y+i)%2,(y+i)/2)*sum[bit][i]%mod*C[now+i][i]%mod)%mod;
}
return mem[bit][now][x][y]=ans;
}
#undef int
int main(void)
{
#define int long long
memset(mem,-1,sizeof(mem));
n=read(),m=read(),k=read();
init();
for(register int i=0;i<=m;i++)
{
val[i]=read();
}
for(register int i=0;i<=m;i++)
{
sum[i][0]=1;
for(register int j=1;j<=n;j++)
{
sum[i][j]=sum[i][j-1]*val[i]%mod;
}
}
printf("%lld\n",dfs(0,0,0,0));
return 0;
}