騰訊《王者榮耀》AI“王者絕悟”亮相 2021 世界人工智慧大會,戰勝 5 名職業選手
阿新 • • 發佈:2021-07-09
單位根的定義就不說了。
顯然有:
\[\omega_n^k=\cos {2\pi k \over n}+i\sin {2\pi k \over n} \]帶入這個可以直接證得:
\[\omega_{2n}^{2k}=\cos {2\pi \cdot 2k \over 2n}+i\sin {2\pi \cdot 2k \over2 n}=\omega_n^k \]我們用影象理性理解可得(就是繞著原點把向量轉了180度):
\[\omega_n^{k+{n\over 2}}=-\omega_n^k \]依然根據影象可得(就是繞著原點旋轉了360度):
\[\omega_n^{k+n}=\omega_n^k \]由尤拉公式得:
所以有:
\[\omega_n^k=e^{i\cdot{2\pi k \over n}}=(e^{i\cdot{2\pi \over n}})^k=(\omega_n^1)^k \]我們先帶入\(\omega_n^1,\omega_n^2,...,\omega_n^n\)到多項式\(A\)中,求出\(A(\omega_n^1),A(\omega_n^2),...,A(\omega_n^n)\)。
為了方便,假設長度\(n\)
把A下標奇偶分類:
\[A_1(x)=a_0+a_2 x+a_4 x^2+... \]\[A_2(x)=a_1+a_3 x+a+5 x^2+... \]顯然有:
\[A(x)=A_1(x^2)+xA_2(x^2) \]\[\therefore A(\omega_{n}^{k})=A_1(\omega_{n}^{2k})+\omega_{n}^{k}A_2(\omega_{n}^{2k}) \]\[=A_1(\omega_{n\over 2}^{k})+\omega_{n}^{k}A_2(\omega_{n\over 2}^{k}) \]\[A(\omega_{n}^{k+{n\over 2}})=A_1(\omega_{n}^{2k+n})-\omega_{n}^{k}A_2(\omega_{n}^{2k+n}) \]\[=A_1(\omega_{n\over 2}^{k})-\omega_{n}^{k}A_2(\omega_{n\over 2}^{k}) \]這兩個式子只有後面一項是相反的,可以遞迴求解。
於是給出程式碼:
inline void FFT(complex<double> *a, int len) {
if (!len) return ; complex<double> a1[len], a2[len];
for (int i = 0; i < len; ++i) a1[i] = a[i << 1], a2[i] = a[i << 1 | 1];
FFT(a1, len >> 1); FFT(a2, len >> 1);
complex<double> w(cos(PI / len), sin(PI / len)), wk(1, 0);
for (int i = 0; i < len; ++i, wk *= w)
a[i] = a1[i] + wk * a2[i], a[i + len] = a1[i] - wk * a2[i];
}
考慮怎麼從點值多項式轉換到係數多項式。
我們欽定\(y_i=A(\omega_n^i)\),在有一多項式\(C\),滿足:
\[C(x)=\sum y_i x^i \]則我們帶入\(\omega_n^{-k}\),得到:
\[C(\omega_n^{-k})=c_k=\sum_{i=0}^{n-1} y_i (\omega_n^{-k})^i \]\[=\sum_{i=0}^{n-1} [\sum_{j=0}^{n-1} a_j(\omega_n^{i})^j ](\omega_n^{-k})^i \]\[=\sum_{i=0}^{n-1} \sum_{j=0}^{n-1} a_j(\omega_n^{j})^i(\omega_n^{-k})^i \]\[=\sum_{i=0}^{n-1} \sum_{j=0}^{n-1} a_j(\omega_n^{j-k})^i \]\[=\sum_{i=0}^{n-1} a_i \sum_{j=0}^{n-1}(\omega_n^{i-k})^j \]設:
\[S(\omega_n^k)=\sum_{i=0}^{n-1} (\omega_n^k)^i={(\omega_n^k)^{n}-1\over \omega_n^k-1} \]當\(k\neq 0\)時為0,否則為\(n\)
則:
\[\sum_{j=0}^{n-1}(\omega_n^{i-k})^j=S(\omega_n^{i-k}) \]即當\(i=k\)時為\(n\),所以:
\[c_k=\sum_{i=0}^{n-1} a_i \sum_{j=0}^{n-1}(\omega_n^{i-k})^j=na_k \]我們驚訝的發現這樣對\(C\)做一次FFT之後點值除以n就是多項式的係數了。
程式碼結合一下:
inline void FFT(complex<double> *a, int len, int flag) {
if (!len) return ; complex<double> a1[len], a2[len];
for (int i = 0; i < len; ++i) a1[i] = a[i << 1], a2[i] = a[i << 1 | 1];
FFT(a1, len >> 1, flag); FFT(a2, len >> 1, flag);
complex<double> w(cos(PI / len), flag * sin(PI / len)), wk(1, 0);
for (int i = 0; i < len; ++i, wk *= w)
a[i] = a1[i] + wk * a2[i], a[i + len] = a1[i] - wk * a2[i];
}
發現遞迴版的碼會T,手玩一下發現實際上奇偶變換後下標的操作相當於二進位制反過來,可以改成非遞迴來模擬,自己對著碼手玩看看就明白。
inline void FFT(complex *a, int type) {
for (int i = 0; i < lim; ++i)
if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int mid = 1; mid < lim; mid <<= 1) {
complex wn; wn = complex(cos(pi / mid), type * sin(pi / mid));
for (int j = 0; j < lim; j += mid << 1) {
complex bas; bas = complex(1, 0);
for (int k = 0; k < mid; ++k, bas = bas * wn) {
complex x = a[j + k], y = bas * a[j + mid + k];
a[j + k] = x + y;
a[j + mid + k] = x - y;
}
}
}
}
解釋一下,第一層迴圈列舉的是遞迴的層數,即當前合併的兩個多項式的長度。第二層就是列舉當前要合併多項式的起點,第三層就是列舉的具體的那一個係數。這麼說不是很清楚,還是自己造樣例跟著程式碼手玩一下就明白了。
而NTT呢?設模數為p,g是p的原根,則不需要證明的給出,\(\omega_n^1\)等價於\(g^{p-1\over n} \bmod p\)。把上面的碼程式碼裡的wn換成這個就行了。一般p=998244353,此時g=3。
給個板子:
struct poly {
int n;
vector<ll> x;
inline void NTT(int flag) {
for (int i = 0; i < n; ++i)
if (i < rev[i]) swap(x[i], x[rev[i]]);
for (int mid = 1; mid < n; mid <<= 1) {
ll wn = power(flag == 1 ? G : Gi, (mod - 1) / (mid << 1));
for (int j = 0; j < n; j += mid << 1) {
ll bas = 1;
for (int k = 0; k < mid; ++k, bas = (bas * wn) % mod) {
ll xx = x[j + k], y = (bas * x[j + mid + k]) % mod;
x[j + k] = (xx + y) % mod;
x[j + mid + k] = ((xx - y) % mod + mod) % mod;
}
} cerr << endl;
}
}
};
inline int max_(int a, int b) {
return a > b ? a : b;
}
inline poly mul(poly A, poly B) {
poly a, b; a = A; b = B;
int tmp = a.n + b.n; a.n = 1;
int L = 0;
while (a.n <= tmp) a.n <<= 1, ++L;
for (int i = 0; i <= a.n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << L - 1);
b.n = a.n;
a.NTT(1); b.NTT(1);
for (int i = 0; i < a.n; ++i) a.x[i] = (a.x[i] * b.x[i]) % mod;
a.NTT(-1);
const ll inv = power(a.n, mod - 2);
a.n = tmp;
for (int i = 0; i <= a.n; ++i) a.x[i] = (a.x[i] * inv) % mod;
return a;
}