1. 程式人生 > 其它 >[NOIP2021] 數列

[NOIP2021] 數列

洛谷題面

感覺這道題純動態規劃的邊界等問題非常麻煩,所以這裡採用記憶化搜尋。

題目大意

給出 \(n,m,k\)\(val_0\cdots val_m\),定義一個值 \(\in [0,m]\) 的序列 \(a\),其權值為 \(\prod\limits_{i=1}^{n} val_{a_i}\)

我們稱 \(S\) 滿足條件當且僅當 \(S=\sum\limits_{i=1}^{n} 2^{a_i}\) 的二進位制表示中,\(1\) 的個數小於等於 \(k\)。此時,也稱序列 \(a\) 為合法序列。

求所有合法序列 \(a\) 的權值和 \(\mod 998244353\) 的結果。

題目分析

\(dfs(bit,now,x,y)\) 表示:

\(S\) 從低到高二進位制的 \(bit\) 位中,用了序列 \(a\) 的前 \(now\) 個數,此時 \(S\) 二進位制下有 \(x\)\(1\),上一位(第 \(bit+1\) 位)進位為 \(y\)

\(mem[biw][now][x][y]\) 則儲存答案。

於是,我們有:

\[mem[bit][now][x][y]=\sum\limits_{i=0}^{n-now}{mem[bit][now+i][x+(y+i)\%2][\left\lfloor\frac{y+i}{2}\right\rfloor])\times sum[bit][i]\times C_{now+i}^{i}} \]

其中 \(C_{i}^{j}\)

表示組合數,\(sum[i][j]\) 表示:

for(register int i=0;i<=m;i++)
{
	sum[i][0]=1;
		
	for(register int j=1;j<=n;j++)
	{
		sum[i][j]=sum[i][j-1]*val[i]%mod;
	}
}

可以看到,\(sum[i][j]\) 主要作用類似於字首和,目的是簡化計算。


邊界部分:

當前轉移到 \(dfs(bit,now,x,y)\)

  • \(now=n\)

\(x+getcnt(y)>k\) 時,返回 \(0\)。表示不需要繼續轉移了。

否則返回 \(1\)

  • \(bit>m\) 則直接返回。

  • \(mem[bit][now][x][y]\) 有數則直接返回該數。

程式碼

//2021/11/30

//2021/12/1

//2021/12/2

#define _CRT_SECURE_NO_WARNINGS

#include <iostream>

#include <cstdio>

#include <climits>//need "INT_MAX","INT_MIN"

#include <cstring>

#define int long long

#define enter() putchar(10)

#define debug(c,que) cerr<<#c<<" = "<<c<<que

#define cek(c) puts(c)

#define blow(arr,st,ed,w) for(register int i=(st);i<=(ed);i++)cout<<arr[i]<<w;

#define speed_up() cin.tie(0),cout.tie(0)

#define endl "\n"

#define Input_Int(n,a) for(register int i=1;i<=n;i++)scanf("%d",a+i);

#define Input_Long(n,a) for(register long long i=1;i<=n;i++)scanf("%lld",a+i);

namespace Newstd
{
	inline int read()
	{
		int x=0,k=1;
		char ch=getchar();
		while(ch<'0' || ch>'9')
		{
			if(ch=='-')
			{
				k=-1;
			}
			ch=getchar();
		}
		while(ch>='0' && ch<='9')
		{
			x=(x<<1)+(x<<3)+ch-'0';
			ch=getchar();
		}
		return x*k;
	}
	inline void write(int x)
	{
		if(x<0)
		{
			putchar('-');
			x=-x;
		}
		if(x>9)
		{
			write(x/10);
		}
		putchar(x%10+'0');
	}
}

using namespace Newstd;

using namespace std;

const int mod=998244353;

const int MA_1=105;

const int MA_2=35;

int val[MA_1];

int C[MA_1][MA_1],sum[MA_1][MA_1];

int mem[MA_1][MA_2][MA_2][MA_2]; 

int n,m,k;

inline void init()
{
	C[0][0]=1;
	
	for(register int i=1;i<=n;i++)
	{
		C[i][0]=1;
		
		for(register int j=1;j<=i;j++)
		{
			C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod; 
		}
	}
}

inline int lowbit(int x)
{
	return x&-x;
}

inline int getcnt(int x)
{
	int ans(0);
	
	while(x!=0)
	{
		x-=lowbit(x);
		
		ans++;
	}
	
	return ans;
}

//dfs(k,now,x,y)
//S從低到高二進位制的 bit 位中,用了數列 a 的前 now 項,且此時 S 中共有 x 個二進位制位為 1,第 now+1 位進了 y 過去 
inline int dfs(int bit,int now,int x,int y)
{
	if(now==n)
	{
		if(x+getcnt(y)>k)
		{
			return 0;
		}
		
		return 1;
	}
	
	if(bit>m)
	{
		return 0;
	}
	
	if(mem[bit][now][x][y]!=-1)
	{
		return mem[bit][now][x][y];
	}
	
	int ans(0);
	
	for(register int i=0;i<=n-now;i++)
	{
		ans=(ans+dfs(bit+1,now+i,x+(y+i)%2,(y+i)/2)*sum[bit][i]%mod*C[now+i][i]%mod)%mod; 
	}
	
	return mem[bit][now][x][y]=ans;
}

#undef int

int main(void)
{
	#define int long long
	
	memset(mem,-1,sizeof(mem));
	
	n=read(),m=read(),k=read();
	
	init();
	
	for(register int i=0;i<=m;i++)
	{
		val[i]=read();
	}
	
	for(register int i=0;i<=m;i++)
	{
		sum[i][0]=1;
		
		for(register int j=1;j<=n;j++)
		{
			sum[i][j]=sum[i][j-1]*val[i]%mod;
		}
	}
	
	printf("%lld\n",dfs(0,0,0,0));
	
	return 0;
}