1. 程式人生 > 實用技巧 >[AGC034F]RNG and XOR(FWT)

[AGC034F]RNG and XOR(FWT)

題面

https://atcoder.jp/contests/agc034/tasks/agc034_f

題解

前置知識

首先設得到i的期望步數為E[i]。容易看出E[0]=0以及對i>0,\(E[i]=\sum_{j=0}^{2^n-1}p[j]E[i \bigoplus j]+1\)

將1移到左邊,可以得到\(E[i]-1=\sum_{j=0}^{2^n-1}p[j]E[i \bigoplus j]\)。①

發現這個式子很像一個異或卷積,它可以描述為三個陣列之間的關係:

\[{\{}E[0],E[1],…,E[2^n-1]{\}} \bigoplus {\{}p[0],p[1],…,p[2^n-1]{\}}={\{}?,E[1]-1,E[2]-1,…,E[2^n-1]-1{\}} \]

其中\(\bigoplus\)表示兩個數列之間的異或卷積。?的存在是因為E[0]並不滿足①式。不過沒關係,我們可以算出?是什麼。中間的p陣列滿足一個性質——它們的和是1。所以,左邊所有數的和一定等於右邊所有數的和。所以\(?=2^n-1\)

之後,我們將中間陣列的第一項-=1,就可以完美地消掉右邊的E:

\[{\{}E[0],E[1],…,E[2^n-1]{\}} \bigoplus {\{}p[0]-1,p[1],…,p[2^n-1]{\}}={\{}2^n-1,-1,-1,…,-1{\}} \]

看上去我們已經做完了,因為中間和右邊的陣列都已知,求左邊還不容易嗎?只需要把右邊進行FWT,再與中間FWT後的結果的逆元逐項相乘,再UFWT就是左邊了。

問題就出在“逆元”上。

0是沒有逆元的,而我們仔細思考,發現中間FWT以後,是會出現0的!出現0的就是第一項,這一項在FWT以後會變成\(p[0]-1+p[1]+…+p[2^n-1]\),也就是0。其他項可以證明不是0。核對右邊的第一項,FWT以後變成\(2^n-1-1-1…-1\)也是0。感覺我們少了個條件,做不下去了。

並不是!其實E中也有一項已知的,就是\(E[0]=0\)。它可是有巨大的用途的。根據我們的推導,\({\{}E[0],E[1],…,E[2^n-1]{\}}\)FWT以後的結果是\({\{}q[0],q[1],…,q[2^n-1]{\}}\),這個q的1到\(2^n-1\)都已知了,只有q[0]未知。左邊是E[0]已知,其他未知。根據異或卷積的公式,一定存在下列關係

\[\frac{\frac{\frac{\frac{q[0]+q^{(1)}[2^0]}{2}+q^{(2)}[2^1]}{2}+q^{(3)}[2^2]}{…}+q^{(n)}[2^{n-1}]}{2}=E[0]=0 \]

其中\(q^{(i)}[j]\)表示FWT中進行第i輪迴圈時,q[j]的值。由於\(q^{(1)}[2^0],q^{(2)}[2^1],…,q^{(n)}[2^{n-1}]\)這些值都與q[0]線性無關,所以簡化為

\[\frac{q[0]}{2^n}+M=0 \]

這個M極其繁瑣,怎麼求?

強行代q[0]=0,算一個UFWT,算出來\(E[0]=k\)。這代表了\(\frac{0}{2^n}+M=k\),M自然求出。

至此,q[0]求出,本題結束。總時間複雜度\(O(n2^n)\)

程式碼

#include<bits/stdc++.h>

using namespace std;

#define ll long long
#define rg register
#define In inline

const int N = 262144;
const ll mod = 998244353;
const ll iv2 = 499122177;

In ll read(){
	ll s = 0,ww = 1;
	char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-')ww = -1;ch = getchar();}
	while('0' <= ch && ch <= '9'){s = 10 * s + ch - '0';ch = getchar();}
	return s * ww;
}

In void write(ll x){
	if(x < 0)putchar('-'),x = -x;
	if(x > 9)write(x / 10);
	putchar('0' + x % 10);
}

namespace ModCalc{
	In void Inc(ll &x,ll y){
		x += y;if(x >= mod)x -= mod;
	}
	In void Dec(ll &x,ll y){
		x -= y;if(x < 0)x += mod;
	}
	In ll Add(ll x,ll y){
		Inc(x,y);return x;
	}
	In ll Sub(ll x,ll y){
		Dec(x,y);return x;
	}
}
using namespace ModCalc;

In ll power(ll a,ll n){
	ll s = 1,x = a;
	while(n){
		if(n & 1)s = s * x % mod;
		x = x * x % mod;
		n >>= 1;
	}
	return s;
}

ll n,deg;
ll p[N+5],q[N+5],temp[N+5];

In void calc(ll &x,ll &y,int opt){
	if(opt == 1){
		ll X = Add(x,y),Y = Sub(x,y);
		x = X,y = Y;
	}
	else{
		ll X = Add(x,y) * iv2 % mod,Y = Sub(x,y) * iv2 % mod; 
		x = X,y = Y;
	}
}

void FWT(ll a[],ll deg,int opt){
	for(rg int n = 2;n <= deg;n <<= 1){
		int m = n >> 1;
		for(rg int i = 0;i < deg;i += n){
			for(rg int j = 0;j < m;j++)calc(a[i+j],a[i+j+m],opt);
		}
	}
}

int main(){
	n = read();
	deg = 1ll << n;
	ll s = 0;
	for(rg int i = 0;i < deg;i++)p[i] = read(),s += p[i];
	ll iv = power(s,mod - 2);
	for(rg int i = 0;i < deg;i++)p[i] = p[i] * iv % mod;
	Dec(p[0],1);
	FWT(p,deg,1);
	for(rg int i = 0;i < deg;i++)q[i] = (i == 0 ? deg - 1 : -1);
	FWT(q,deg,1);
	for(rg int i = 1;i < deg;i++)q[i] = q[i] * power(p[i],mod - 2) % mod;
	memcpy(temp,q,sizeof(temp));
	FWT(temp,deg,-1);
	Dec(q[0],temp[0] * deg % mod);
	FWT(q,deg,-1);
	for(rg int i = 0;i < deg;i++)write(q[i]),putchar('\n');
	return 0;
}