FFT,NTT學習筆記
阿新 • • 發佈:2019-01-01
快速傅立葉變換,可以將多項式相乘的時間複雜度從最簡單的O(n^2)優化到O(nlgn),詳細過程參考演算法導論.
FFT的流程大致是:
1):構造多項式,複雜度O(n)
2):求兩個多項式的DFT,複雜度O(nlgn)
3):構造多項式乘積的點值表示式,複雜度O(n)
4):求點值表示式的IDFT,複雜度O(nlgn).
下面是兩道最簡單的習題:
求兩個大數乘積.
因為一個大數可以看成是一個多項式,每一位上的值都表示對應次數下的係數,因此可以用FFT加速.
本題的一個坑點就是
len = l1+l2-1;
這句程式碼,可能是精度問題在len更加高位的地方出現了非0值.
// Big-integer multiplication with a floating-point FFT.
//
// Input : pairs of whitespace-separated decimal strings until EOF.
// Output: the product of each pair on its own line.
//
// Every preprocessor directive must sit on its own line: anything that
// follows `#include <...>` on the same line is discarded by the preprocessor.
#include <algorithm>
#include <cmath>
#include <iostream>
#include <string>
#include <vector>
using namespace std;

const double pi = acos(-1.0);

// Minimal complex number: x is the real part, y the imaginary part.
struct plex {
    double x, y;
    plex(double _x = 0.0, double _y = 0.0) : x(_x), y(_y) {}
    plex operator+(const plex &a) const { return plex(x + a.x, y + a.y); }
    plex operator-(const plex &a) const { return plex(x - a.x, y - a.y); }
    plex operator*(const plex &a) const {
        return plex(x * a.x - y * a.y, x * a.y + y * a.x);
    }
};

// Reorders y[0..len) into bit-reversed index order, in place.
// len is expected to be a power of two.
void change(plex *y, int len) {
    for (int i = 1, j = len / 2; i < len - 1; i++) {
        if (i < j) swap(y[i], y[j]);
        // Advance j to the bit-reversed image of i + 1: clear leading set
        // bits from the top down, then set the first clear one.
        int k = len / 2;
        while (j >= k) {
            j -= k;
            k /= 2;
        }
        j += k;
    }
}

// In-place iterative radix-2 transform of y[0..len).
//   on ==  1 : forward transform
//   on == -1 : inverse transform, including the 1/len scaling
// len must be a power of two.
void fft(plex y[], int len, int on) {
    change(y, len);
    for (int h = 2; h <= len; h <<= 1) {
        const plex wn(cos(-on * 2 * pi / h), sin(-on * 2 * pi / h));
        for (int j = 0; j < len; j += h) {
            plex w(1, 0);
            for (int k = j; k < j + h / 2; k++) {
                const plex u = y[k];
                const plex t = w * y[k + h / 2];
                y[k] = u + t;
                y[k + h / 2] = u - t;
                w = w * wn;
            }
        }
    }
    if (on == -1) {
        // Scale both components so the result is a complete inverse DFT.
        for (int i = 0; i < len; i++) {
            y[i].x /= len;
            y[i].y /= len;
        }
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    // std::string input and per-case vectors: no fixed buffer can be
    // overrun by a long operand or by the padded transform length.
    string a, b;
    while (cin >> a >> b) {
        const int l1 = static_cast<int>(a.size());
        const int l2 = static_cast<int>(b.size());

        // Smallest power of two that is at least twice the longer operand,
        // so the cyclic convolution does not wrap around.
        int len = 2;
        while (len < l1 * 2 || len < l2 * 2) len <<= 1;

        // Coefficient i is the i-th least significant digit; the tail is
        // zero-padded by the default plex constructor.
        vector<plex> x1(len), x2(len);
        for (int i = 0; i < l1; i++) x1[i] = plex(a[l1 - 1 - i] - '0', 0);
        for (int i = 0; i < l2; i++) x2[i] = plex(b[l2 - 1 - i] - '0', 0);

        fft(&x1[0], len, 1);
        fft(&x2[0], len, 1);
        for (int i = 0; i < len; i++) x1[i] = x1[i] * x2[i];
        fft(&x1[0], len, -1);

        // One extra slot receives the carry written by the last iteration.
        vector<int> ans(len + 1);
        for (int i = 0; i < len; i++) {
            ans[i] = static_cast<int>(x1[i].x + 0.5);  // round to nearest
        }
        for (int i = 0; i < len; i++) {
            if (ans[i] >= 10) {
                ans[i + 1] += ans[i] / 10;
                ans[i] %= 10;
            }
        }

        // A product of l1 and l2 digits has at most l1 + l2 digits, so the
        // top digit is at index l1 + l2 - 1 or lower; skip leading zeros
        // but always keep one digit.
        int top = l1 + l2 - 1;
        while (top > 0 && ans[top] <= 0) top--;
        for (int i = top; i >= 0; i--) cout << ans[i];
        cout << '\n';
    }
    return 0;
}
HDU 4609:點選開啟連結
題意是給出n個長度,任取3個求能組成三角形的概率.
首先記錄下每個長度的數量,然後用FFT加速求出任取兩個長度下的情況,這裡面有重複:
首先減去兩次都取同一根的情況,減完之後的結果都/2.
最後只需要所有的情況減去不能組成三角形的情況,將最初的長度序列排序後從小到大列舉下標,假設這條邊是最長邊,那麼如果所有兩條邊長度小於等於這條邊的情況就應該減去,這裡用字首和統計下就好了.
// HDU 4609: probability that three sticks picked at random form a triangle.
//
// Input : t test cases; each gives n followed by n stick lengths.
// Output: the probability for each case, printed with 7 decimals.
//
// Array capacities are fixed by maxn and inputs are not range-checked; the
// code assumes n and every length stay within those capacities.
//
// Every preprocessor directive must sit on its own line: anything that
// follows `#include <...>` on the same line is discarded by the preprocessor.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
using namespace std;

const double pi = acos(-1.0);
const int maxn = 611111;

// Minimal complex number: x is the real part, y the imaginary part.
struct plex {
    double x, y;
    plex(double _x = 0.0, double _y = 0.0) : x(_x), y(_y) {}
    plex operator+(const plex &a) const { return plex(x + a.x, y + a.y); }
    plex operator-(const plex &a) const { return plex(x - a.x, y - a.y); }
    plex operator*(const plex &a) const {
        return plex(x * a.x - y * a.y, x * a.y + y * a.x);
    }
};

// Reorders y[0..len) into bit-reversed index order, in place.
// len is expected to be a power of two.
//
// This yields the same permutation as recursively splitting the sequence
// into even- and odd-indexed halves, but needs no temporary arrays: the
// recursive form relied on variable-length arrays (a compiler extension)
// and put megabytes of scratch data on the stack for large len.
void change(plex y[], int len) {
    for (int i = 1, j = len / 2; i < len - 1; i++) {
        if (i < j) swap(y[i], y[j]);
        // Advance j to the bit-reversed image of i + 1.
        int k = len / 2;
        while (j >= k) {
            j -= k;
            k /= 2;
        }
        j += k;
    }
}

// In-place iterative radix-2 transform of y[0..len).
//   on ==  1 : forward transform
//   on == -1 : inverse transform, including the 1/len scaling
// len must be a power of two.
void fft(plex y[], int len, int on) {
    change(y, len);
    for (int h = 2; h <= len; h <<= 1) {
        const plex wn(cos(on * 2 * pi / h), sin(on * 2 * pi / h));
        for (int j = 0; j < len; j += h) {
            plex w(1, 0);
            for (int k = j; k < j + h / 2; k++) {
                const plex u = y[k];
                const plex t = w * y[k + h / 2];
                y[k] = u + t;
                y[k + h / 2] = u - t;
                w = w * wn;
            }
        }
    }
    if (on == -1) {
        // Scale both components so the result is a complete inverse DFT.
        for (int i = 0; i < len; i++) {
            y[i].x /= len;
            y[i].y /= len;
        }
    }
}

long long num[maxn], sum[maxn];
int a[maxn];
plex x[maxn];
long long n;

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        scanf("%lld", &n);
        long long Max = 0;
        memset(num, 0, sizeof num);
        for (int i = 1; i <= n; i++) {
            scanf("%d", &a[i]);
            num[a[i]]++;  // num[v] = how many sticks have length v
            Max = max(Max, static_cast<long long>(a[i]));
        }
        Max++;

        int len = 2;
        while (len < Max * 2) len <<= 1;

        // Squaring the length histogram as a polynomial gives, at index s,
        // the number of ordered (first, second) picks whose lengths add to s.
        for (int i = 0; i < len; i++) {
            x[i] = plex(num[i], 0);
        }
        fft(x, len, 1);
        for (int i = 0; i < len; i++) {
            x[i] = x[i] * x[i];
        }
        fft(x, len, -1);
        for (int i = 0; i < len; i++) {
            num[i] = static_cast<long long>(x[i].x + 0.5);  // round to nearest
        }

        // Remove the picks that used the same stick twice.
        for (int i = 1; i <= n; i++) {
            num[a[i] + a[i]]--;
        }
        // Every unordered pair was counted twice (once per order).
        for (int i = 0; i < len; i++) {
            num[i] /= 2;
        }

        // sum[s] = number of unordered pairs whose lengths add up to <= s.
        sum[0] = 0;
        for (int i = 1; i < len; i++) {
            sum[i] = sum[i - 1] + num[i];
        }

        sort(a + 1, a + 1 + n);
        const long long tot = n * (n - 1) * (n - 2) / 6;
        long long ans = tot;
        // Treat a[i] as the longest side: every pair whose sum does not
        // exceed it cannot close a triangle with it.
        for (int i = 3; i <= n; i++) {
            ans -= sum[a[i]];
        }

        if (tot == 0) {
            // Fewer than three sticks: no triple exists, avoid 0/0.
            printf("%.7f\n", 0.0);
        } else {
            printf("%.7f\n", ans * 1.0 / tot);
        }
    }
    return 0;
}
但是FFT有一個很致命的弱點就是會產生精度誤差,在換成long double都不行的時候就需要用到NTT。
NTT就是用數論域中的原根代替FFT中的單位複根,其他的程式碼完全相同。
求原根的程式碼:
// Finds the smallest primitive root of a prime modulus.
//
// Input : prime numbers P (each fits in int), one per read, until EOF.
// Output: for each P, the smallest g in [1, P) whose multiplicative order
//         modulo P is P - 1, or -1 if no candidate passes. The result is
//         only meaningful when P is prime.
//
// Every preprocessor directive must sit on its own line: anything that
// follows `#include <...>` on the same line is discarded by the preprocessor.
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;

int P;

// The sieve covers every prime below 46342. Since 46341^2 > 2^31 - 1, that
// is enough to fully factor any positive int by trial division.
const int NUM = 46342;
int prime[NUM / 4];
bool f[NUM];
int pNum = 0;

// Linear sieve: fills prime[0..pNum) with all primes below NUM.
void getPrime() {
    for (int i = 2; i < NUM; ++i) {
        if (!f[i]) {
            f[i] = 1;
            prime[pNum++] = i;
        }
        for (int j = 0; j < pNum && i * prime[j] < NUM; ++j) {
            f[i * prime[j]] = 1;
            if (i % prime[j] == 0) {
                break;  // each composite is crossed out by its smallest prime only
            }
        }
    }
}

// Returns a^b mod P using binary exponentiation.
long long getProduct(int a, int b, int P) {
    long long ans = 1;
    long long tmp = a;
    while (b) {
        if (b & 1) {
            ans = ans * tmp % P;
        }
        tmp = tmp * tmp % P;
        b >>= 1;
    }
    return ans;
}

// Returns the distinct prime factors of n in increasing order
// (empty when n <= 1). Requires getPrime() to have been called.
static vector<int> primeFactors(int n) {
    vector<int> factors;
    // Trial-divide while prime[i]^2 <= n; written as a division so the
    // comparison cannot overflow.
    for (int i = 0; i < pNum && n / prime[i] >= prime[i]; ++i) {
        if (n % prime[i] != 0) continue;
        factors.push_back(prime[i]);
        while (n % prime[i] == 0) n /= prime[i];
    }
    // What is left has no prime factor up to its own square root, so it is
    // either 1 or a single prime. That prime, not the last trial divisor,
    // is the factor that must be recorded.
    if (n > 1) factors.push_back(n);
    return factors;
}

// Returns true when num is a primitive root modulo the global P, i.e. when
// num^((P-1)/q) != 1 (mod P) for every prime factor q of P - 1.
bool judge(int num) {
    const vector<int> factors = primeFactors(P - 1);
    for (size_t i = 0; i < factors.size(); ++i) {
        if (getProduct(num, (P - 1) / factors[i], P) == 1) {
            return false;
        }
    }
    return true;
}

int main() {
    getPrime();
    while (cin >> P) {
        // Start at 1 so that P = 2 yields its only primitive root, 1; for
        // larger primes 1 is always rejected by judge(). The search is
        // bounded by P so invalid input cannot loop forever.
        int root = -1;
        for (int i = 1; i < P; ++i) {
            if (judge(i)) {
                root = i;
                break;
            }
        }
        cout << root << '\n';
    }
    return 0;
}
HDU 1402:
隨便選一個不太大的模數和他的原根就好了。
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <stack>
#include <string>
#include <vector>
using namespace std;
// NTT modulus. 1004535809 = 479 * 2^21 + 1, so every power-of-two transform
// length up to 2^21 divides mod - 1.
#define mod 1004535809LL
// Base from which ntt() derives its roots of unity (the accompanying notes
// describe it as a primitive root of the modulus).
#define G 3
// Upper bound used to size fixed-length buffers.
#define maxn 400005
// Modular exponentiation by repeated squaring: returns a^b reduced modulo
// `mod`. Each pass consumes the lowest bit of the exponent.
long long qpow (long long a, long long b) {
    long long result = 1;
    for (; b; b >>= 1) {
        if (b & 1) {
            result = (result * a) % mod;
        }
        a = (a * a) % mod;
    }
    return result;
}
// Permutes y[0..len) in place so that each element moves to the
// bit-reversed image of its index (len is expected to be a power of two).
// `rev` always holds the bit-reversed counterpart of `pos`.
void change (long long y[], int len) {
    const int half = len / 2;
    int rev = half;
    for (int pos = 1; pos < len - 1; ++pos) {
        if (pos < rev) {
            std::swap(y[pos], y[rev]);
        }
        // Reversed increment: strip set bits from the top down, then set
        // the first clear one.
        int bit = half;
        for (; rev >= bit; bit /= 2) {
            rev -= bit;
        }
        rev += bit;
    }
}
// In-place number-theoretic transform of y[0..len) modulo `mod`.
//   on ==  1 : forward transform
//   on == -1 : inverse transform (uses inverse roots and multiplies by the
//              modular inverse of len at the end)
// The input is first put into bit-reversed order by change(); the stages
// then merge blocks of size 2, 4, ..., len.
void ntt(long long y[], int len, int on) {
    change (y, len);
    const bool inverse = (on == -1);
    for (int span = 2; span <= len; span <<= 1) {
        // Root of unity for this stage, derived from G.
        long long root = qpow(G, (mod-1)/span);
        if (inverse) {
            root = qpow(root, mod-2);
        }
        const int half = span / 2;
        for (int base = 0; base < len; base += span) {
            long long twiddle = 1;
            for (int lo = base; lo < base + half; ++lo) {
                const int hi = lo + half;
                const long long even = y[lo];
                const long long odd = twiddle * y[hi] % mod;
                y[lo] = (even + odd) % mod;
                // Add mod before reducing so the difference stays non-negative.
                y[hi] = (even - odd + mod) % mod;
                twiddle = twiddle * root % mod;
            }
        }
    }
    if (inverse) {
        const long long invLen = qpow (len, mod - 2);
        for (int i = 0; i < len; ++i) {
            y[i] = y[i] * invLen % mod;
        }
    }
}
char a[maxn], b[maxn];
long long x1[maxn], x2[maxn];
long long ans[maxn];
int main () {
while (scanf ("%s%s", a, b) == 2) {
int len = 2, l1 = strlen (a), l2 = strlen (b);
while (len < l1*2 || len < l2*2)
len <<= 1;
//cout << len << endl;
for (int i = 0; i < l1; i++) {
x1[i] = a[l1-1-i]-'0';
}
for (int i = l1; i < len; i++)
x1[i] = 0;
for (int i = 0; i < l2; i++) {
x2[i] = b[l2-1-i]-'0';
}
for (int i = l2; i < len; i++)
x2[i] = 0;
ntt(x1, len, 1);
ntt(x2, len, 1);
for (int i = 0; i < len; i++)
x1[i] = x1[i]*x2[i]%mod;
ntt(x1, len, -1);
for (int i = 0; i < len; i++) {
ans[i] = x1[i];
}
for (int i = 0; i < len; i++) {
if (ans[i] >= 10) {
ans[i+1] += ans[i]/10;
ans[i] %= 10;
}
}
len = l1+l2-1;
while (ans[len] <= 0 && len > 0)
len--;
for (int i = len; i >= 0; i--) {
printf ("%lld", ans[i]);
}
printf ("\n");
}
return 0;
}