[AGC034F]RNG and XOR(FWT)
題面
https://atcoder.jp/contests/agc034/tasks/agc034_f
題解
前置知識
首先設得到i的期望步數為E[i]。容易看出E[0]=0以及對i>0,\(E[i]=\sum_{j=0}^{2^n-1}p[j]E[i \bigoplus j]+1\)。
將1移到左邊,可以得到\(E[i]-1=\sum_{j=0}^{2^n-1}p[j]E[i \bigoplus j]\)。①
發現這個式子很像一個異或卷積,它可以描述為三個陣列之間的關係:
\[{\{}E[0],E[1],…,E[2^n-1]{\}} \bigoplus {\{}p[0],p[1],…,p[2^n-1]{\}}={\{}?,E[1]-1,E[2]-1,…,E[2^n-1]-1{\}} \]其中\(\bigoplus\)表示兩個數列之間的異或卷積。?的存在是因為E[0]並不滿足①式。不過沒關係,我們可以算出?是什麼。中間的p陣列滿足一個性質——它們的和是1。所以,左邊所有數的和一定等於右邊所有數的和。所以\(?=2^n-1\)。
之後,我們將中間陣列的第一項-=1,就可以完美地消掉右邊的E:
\[{\{}E[0],E[1],…,E[2^n-1]{\}} \bigoplus {\{}p[0]-1,p[1],…,p[2^n-1]{\}}={\{}2^n-1,-1,-1,…,-1{\}} \]看上去我們已經做完了,因為中間和右邊的陣列都已知,求左邊還不容易嗎?只需要把右邊進行FWT,再與中間FWT後的結果的逆元逐項相乘,再UFWT就是左邊了。
問題就出在“逆元”上。
0是沒有逆元的,而我們仔細思考,發現中間FWT以後,是會出現0的!出現0的就是第一項,這一項在FWT以後會變成\(p[0]-1+p[1]+…+p[2^n-1]\),也就是0。其他項可以證明不是0。核對右邊的第一項,FWT以後變成\(2^n-1-1-1…-1\)也是0。感覺我們少了個條件,做不下去了。
並不是!其實E中也有一項已知的,就是\(E[0]=0\)。它可是有巨大的用途的。根據我們的推導,\({\{}E[0],E[1],…,E[2^n-1]{\}}\)FWT以後的結果是\({\{}q[0],q[1],…,q[2^n-1]{\}}\),這個q的1到\(2^n-1\)都已知了,只有q[0]未知。左邊是E[0]已知,其他未知。根據異或卷積的公式,一定存在下列關係
其中\(q^{(i)}[j]\)表示FWT中進行第i輪迴圈時,q[j]的值。由於\(q^{(1)}[2^0],q^{(2)}[2^1],…,q^{(n)}[2^{n-1}]\)這些值都與q[0]線性無關,所以簡化為
\[\frac{q[0]}{2^n}+M=0 \]這個M極其繁瑣,怎麼求?
強行代q[0]=0,算一個UFWT,算出來\(E[0]=k\)。這代表了\(\frac{0}{2^n}+M=k\),M自然求出。
至此,q[0]求出,本題結束。總時間複雜度\(O(n2^n)\)。
程式碼
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define rg register
#define In inline
const int N = 262144;
const ll mod = 998244353;
const ll iv2 = 499122177;
In ll read(){
ll s = 0,ww = 1;
char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-')ww = -1;ch = getchar();}
while('0' <= ch && ch <= '9'){s = 10 * s + ch - '0';ch = getchar();}
return s * ww;
}
In void write(ll x){
if(x < 0)putchar('-'),x = -x;
if(x > 9)write(x / 10);
putchar('0' + x % 10);
}
namespace ModCalc{
In void Inc(ll &x,ll y){
x += y;if(x >= mod)x -= mod;
}
In void Dec(ll &x,ll y){
x -= y;if(x < 0)x += mod;
}
In ll Add(ll x,ll y){
Inc(x,y);return x;
}
In ll Sub(ll x,ll y){
Dec(x,y);return x;
}
}
using namespace ModCalc;
In ll power(ll a,ll n){
ll s = 1,x = a;
while(n){
if(n & 1)s = s * x % mod;
x = x * x % mod;
n >>= 1;
}
return s;
}
ll n,deg;
ll p[N+5],q[N+5],temp[N+5];
In void calc(ll &x,ll &y,int opt){
if(opt == 1){
ll X = Add(x,y),Y = Sub(x,y);
x = X,y = Y;
}
else{
ll X = Add(x,y) * iv2 % mod,Y = Sub(x,y) * iv2 % mod;
x = X,y = Y;
}
}
void FWT(ll a[],ll deg,int opt){
for(rg int n = 2;n <= deg;n <<= 1){
int m = n >> 1;
for(rg int i = 0;i < deg;i += n){
for(rg int j = 0;j < m;j++)calc(a[i+j],a[i+j+m],opt);
}
}
}
int main(){
n = read();
deg = 1ll << n;
ll s = 0;
for(rg int i = 0;i < deg;i++)p[i] = read(),s += p[i];
ll iv = power(s,mod - 2);
for(rg int i = 0;i < deg;i++)p[i] = p[i] * iv % mod;
Dec(p[0],1);
FWT(p,deg,1);
for(rg int i = 0;i < deg;i++)q[i] = (i == 0 ? deg - 1 : -1);
FWT(q,deg,1);
for(rg int i = 1;i < deg;i++)q[i] = q[i] * power(p[i],mod - 2) % mod;
memcpy(temp,q,sizeof(temp));
FWT(temp,deg,-1);
Dec(q[0],temp[0] * deg % mod);
FWT(q,deg,-1);
for(rg int i = 0;i < deg;i++)write(q[i]),putchar('\n');
return 0;
}