1. 程式人生 > >FFT,NTT學習筆記

FFT,NTT學習筆記

快速傅立葉變換,可以將多項式相乘的時間複雜度從最簡單的O(n^2)優化到O(nlgn),詳細過程參考演算法導論.

FFT的流程大致是:

1):構造多項式,複雜度O(n)

2):求兩個多項式的DFT,複雜度O(nlgn)

3):構造多項式乘積的點值表示式,複雜度O(n)

4):求點值表示式的IDFT,複雜度O(nlgn).

下面是兩道最簡單的習題:

求兩個大數乘積.

因為一個大數可以看成是一個多項式,每一位上的值都表示對應次數下的係數,因此可以用FFT加速.

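本題的一個坑點是輸出前尋找最高位時的這一句:

len = l1+l2-1;
一個 l1 位數乘以一個 l2 位數,結果最多只有 l1+l2 位,所以最高位的下標是 l1+l2-1。應從這裡開始往下跳過前導零,而不能從 FFT 的長度 len 開始。作者推測是精度問題讓 len 更高位的地方出現了非 0 值。另一個可能的原因是:ans[] 是全域陣列,各組資料之間沒有清零,而 ans[len] 在本組資料中未必會被重新寫入,可能殘留上一組資料的值。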
#include <bits/stdc++.h>
using namespace std;
// Numeric constants for the FFT big-integer multiplier.
const double pi = std::acos(-1.0);  // pi, used to build the roots of unity in fft()
const int maxn = 200010;            // capacity of the digit / coefficient buffers

/// Minimal complex number used by the FFT: x is the real part, y the imaginary part.
struct plex {
    double x, y;

    plex (double re = 0.0, double im = 0.0) : x (re), y (im) {}

    /// Component-wise sum.
    plex operator + (const plex &rhs) const {
        plex out;
        out.x = x + rhs.x;
        out.y = y + rhs.y;
        return out;
    }

    /// Component-wise difference.
    plex operator - (const plex &rhs) const {
        plex out;
        out.x = x - rhs.x;
        out.y = y - rhs.y;
        return out;
    }

    /// Complex product (x + i*y) * (rhs.x + i*rhs.y).
    plex operator * (const plex &rhs) const {
        const double re = x * rhs.x - y * rhs.y;
        const double im = x * rhs.y + y * rhs.x;
        return plex (re, im);
    }
};

/// Reorders y[0..len) into bit-reversed index order (Rader's algorithm), the
/// input order required by the in-place butterflies in fft().
/// Precondition: len is a power of two.
void change (plex *y, int len) {
    const int top = len / 2;   // highest bit of an index
    int rev = top;             // bit-reversal of idx, starting with idx = 1
    for (int idx = 1; idx < len - 1; ++idx) {
        if (idx < rev)         // swap each pair only once
            swap (y[idx], y[rev]);
        // Increment rev in reversed bit order: clear its leading ones from
        // the top bit down, then set the first zero bit reached.
        int bit = top;
        while (rev >= bit) {
            rev -= bit;
            bit /= 2;
        }
        rev += bit;
    }
}

/// In-place iterative radix-2 FFT of y[0..len).
/// on == 1 : forward transform, angle -2*pi/h per stage;
/// on == -1: inverse transform; afterwards the real parts are divided by len.
/// Precondition: len is a power of two.
void fft(plex y[], int len, int on)
{
    change(y, len);
    for (int h = 2; h <= len; h <<= 1) {
        const int half = h / 2;
        // Principal h-th root of unity; the sign of `on` selects the direction.
        const double ang = -on * 2 * pi / h;
        const plex wn(cos(ang), sin(ang));
        for (int start = 0; start < len; start += h) {
            plex w(1, 0);
            for (int k = start; k < start + half; k++) {
                const plex even = y[k];
                const plex odd = w * y[k + half];
                y[k] = even + odd;
                y[k + half] = even - odd;
                w = w * wn;
            }
        }
    }
    if (on == -1) {
        // Only the real parts are rescaled; imaginary parts are left as-is.
        for (int i = 0; i < len; i++)
            y[i].x /= len;
    }
}
char a[maxn];      // first factor, as read (decimal string)
char b[maxn];      // second factor, as read (decimal string)
plex x1[maxn];     // digits of a (least significant first), then its transform
plex x2[maxn];     // digits of b (least significant first), then its transform
int ans[maxn];     // digits of the product, least significant first

int main () {
    while (scanf ("%s%s", a, b) == 2) {
        int len = 2, l1 = strlen (a), l2 = strlen (b);
        while (len < l1*2 || len < l2*2)
            len <<= 1;
        for (int i = 0; i < l1; i++) {
            x1[i] = plex (a[l1-1-i]-'0', 0);
        }
        for (int i = l1; i < len; i++)
            x1[i] = plex (0, 0);
        for (int i = 0; i < l2; i++) {
            x2[i] = plex (b[l2-1-i]-'0', 0);
        }
        for (int i = l2; i < len; i++)
            x2[i] = plex (0, 0);
        fft (x1, len, 1);
        fft (x2, len, 1);
        for (int i = 0; i < len; i++)
            x1[i] = x1[i]*x2[i];
        fft (x1, len, -1);
        for (int i = 0; i < len; i++) {
            ans[i] = (int)(x1[i].x+0.5);
        }
        for (int i = 0; i < len; i++) {
            if (ans[i] >= 10) {
                ans[i+1] += ans[i]/10;
                ans[i] %= 10;
            }
        }
        len = l1+l2-1;
        while (ans[len] <= 0 && len > 0)
            len--;
        for (int i = len; i >= 0; i--) {
            printf ("%d", ans[i]);
        }
        printf ("\n");
    }
    return 0;
}

HDU 4609:點選開啟連結

題意是給出n個長度,任取3個求能組成三角形的概率.

首先記錄下每個長度的木棍數量 num[],然後用 FFT 求 num 與自身的卷積,得到任取兩根木棍時每一種長度和的方案數。這裡面有重複:

首先減去兩次都取同一根木棍的情況。另外,(i, j) 和 (j, i) 被算了兩次,所以減完之後的結果都要 /2。

最後只需要用所有的情況減去不能組成三角形的情況。做法是將最初的長度序列排序後,從小到大列舉下標,並假設這條邊是最長邊;另外兩條邊長度之和小於等於這條邊的所有情況都應該減去。這裡用字首和統計一下就好了。

#include <bits/stdc++.h>
using namespace std;
// Numeric constants for the HDU 4609 solution.
const double pi = std::acos(-1.0);  // pi, used to build the roots of unity in fft()
const int maxn = 611111;            // capacity of the count / coefficient buffers

/// Minimal complex number used by the FFT: x is the real part, y the imaginary part.
struct plex {
    double x, y;

    plex (double re = 0.0, double im = 0.0) : x (re), y (im) {}

    /// Component-wise sum.
    plex operator + (const plex &rhs) const {
        plex out;
        out.x = x + rhs.x;
        out.y = y + rhs.y;
        return out;
    }

    /// Component-wise difference.
    plex operator - (const plex &rhs) const {
        plex out;
        out.x = x - rhs.x;
        out.y = y - rhs.y;
        return out;
    }

    /// Complex product (x + i*y) * (rhs.x + i*rhs.y).
    plex operator * (const plex &rhs) const {
        const double re = x * rhs.x - y * rhs.y;
        const double im = x * rhs.y + y * rhs.x;
        return plex (re, im);
    }
};

/// Permutes y[0..len) into bit-reversed index order: afterwards y[k] holds the
/// element that was at index rev(k). This is the input order required by the
/// iterative butterflies in fft().
///
/// The previous version recursed on the even/odd halves and copied them into
/// variable-length arrays on the stack (`plex a1[len/2], a2[len/2]`). VLAs are
/// not ISO C++, and at len = 2^18 the top frame alone needed about 4 MB of
/// stack, which can overflow a judge's default stack. This version permutes
/// in place with O(1) extra memory and produces the same permutation.
///
/// Templated on the element type so existing `change(y, len)` calls on
/// `plex` arrays keep working unchanged.
/// Precondition: len is a power of two (len <= 1 is a no-op).
template <class T>
void change (T y[], int len) {
    for (int i = 1, j = 0; i < len; i++) {
        // Advance j to rev(i): clear the leading ones of j from the top bit
        // down, then set the first zero bit reached.
        int bit = len >> 1;
        for (; j & bit; bit >>= 1)
            j ^= bit;
        j ^= bit;
        if (i < j)     // swap each pair only once
            std::swap (y[i], y[j]);
    }
}

/// In-place iterative radix-2 FFT of y[0..len).
/// on == 1 : forward transform, angle +2*pi/h per stage (sign convention of this file);
/// on == -1: inverse transform; afterwards the real parts are divided by len.
/// The two directions use opposite angles, so they form a consistent transform pair.
/// Precondition: len is a power of two.
void fft(plex y[], int len, int on)
{
    change(y, len);
    for (int h = 2; h <= len; h <<= 1) {
        const int half = h / 2;
        const double ang = on * 2 * pi / h;
        const plex wn(cos(ang), sin(ang));
        for (int start = 0; start < len; start += h) {
            plex w(1, 0);
            for (int k = start; k < start + half; k++) {
                const plex even = y[k];
                const plex odd = w * y[k + half];
                y[k] = even + odd;
                y[k + half] = even - odd;
                w = w * wn;
            }
        }
    }
    if (on == -1) {
        // Only the real parts are rescaled; imaginary parts are left as-is.
        for (int i = 0; i < len; i++)
            y[i].x /= len;
    }
}
long long num[maxn];   // sticks per length; later reused for pair-sum counts
long long sum[maxn];   // prefix sums of num (pair counts by length sum)
int a[maxn];           // stick lengths, 1-indexed
plex x[maxn];          // FFT work buffer
long long n;           // number of sticks in the current test case

int main () {
    //freopen ("in.txt", "r", stdin);
    int t;
    scanf ("%d", &t);
    while (t--) {
        scanf ("%lld", &n);
        memset (num, 0, sizeof num);

        // Count sticks per length and track the longest one.
        long long Max = 0;
        for (int i = 1; i <= n; i++) {
            scanf ("%d", &a[i]);
            num[a[i]]++;
            if (a[i] > Max)
                Max = a[i];
        }

        // Every pair sum is < 2*Max after the increment, so a transform
        // length >= 2*Max avoids cyclic wrap-around.
        Max++;
        int len = 2;
        while (len < Max * 2)
            len <<= 1;

        // Square the length-count polynomial: coefficient s becomes the number
        // of ordered stick pairs (i, j), i == j included, with a[i] + a[j] == s.
        for (int i = 0; i < len; i++)
            x[i] = plex (num[i], 0);
        fft (x, len, 1);
        for (int i = 0; i < len; i++)
            x[i] = x[i] * x[i];
        fft (x, len, -1);
        for (int i = 0; i < len; i++)
            num[i] = static_cast<long long>(x[i].x + 0.5);

        // Drop the pairs that use the same stick twice.
        for (int i = 1; i <= n; i++)
            num[a[i] + a[i]]--;

        // Ordered -> unordered pairs, and prefix sums:
        // sum[s] = unordered pairs of distinct sticks with length sum in [1, s].
        for (int i = 0; i < len; i++) {
            num[i] /= 2;
            sum[i] = (i == 0) ? 0 : sum[i - 1] + num[i];
        }

        // Triples whose two shorter sides sum to at most the longest side
        // cannot form a triangle; count them per candidate longest side.
        sort (a + 1, a + 1 + n);
        const long long tot = n * (n - 1) * (n - 2) / 6;
        long long bad = 0;
        for (int i = 3; i <= n; i++)
            bad += sum[a[i]];
        const long long ans = tot - bad;
        printf ("%.7f\n", ans * 1.0 / tot);
    }
    return 0;
}


但是FFT有一個很致命的弱點就是會產生精度誤差,在換成long double都不行的時候就需要用到NTT。

NTT 就是在模質數 p 的數論域中,用原根 g 的冪 g^((p-1)/n) 代替 FFT 中的 n 次單位複根 e^(2πi/n)。除了改用模運算外,其他程式碼的結構完全相同。

求原根的程式碼:

#include <cstdio>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;
int pNum = 0;          // number of valid entries in prime[]
int P;                 // modulus read in main(); primitive roots are searched modulo P
const int NUM = 32170; // sieve bound: getPrime() collects the primes below NUM
int prime[NUM/4];      // primes found by getPrime(), in increasing order
bool f[NUM] = {};      // sieve marks: set once an index has been visited or crossed off
/// Linear (Euler) sieve: fills prime[0..pNum) with every prime below NUM.
/// Each composite v*p is crossed off only for primes p up to the smallest
/// prime factor of v (the inner loop stops once p divides v).
void getPrime () {
    for (int v = 2; v < NUM; ++v) {
        if (!f[v]) {
            // v was never crossed off, so it is prime.
            prime[pNum] = v;
            ++pNum;
            f[v] = true;
        }
        for (int t = 0; t < pNum; ++t) {
            const int p = prime[t];
            if (v * p >= NUM)
                break;
            f[v * p] = true;
            if (v % p == 0)
                break;
        }
    }
}
/// Modular exponentiation by squaring: returns a^b mod P.
/// When b == 0 the result is 1 (not reduced modulo P).
long long getProduct(int a,int b,int P) {
    long long result = 1;
    for (long long base = a; b != 0; b >>= 1) {
        if (b & 1)
            result = result * base % P;
        base = base * base % P;
    }
    return result;
}

/// Returns true iff num is a primitive root modulo the global prime P.
///
/// num generates the multiplicative group mod P exactly when
/// num^((P-1)/q) != 1 (mod P) for every distinct prime factor q of P-1.
/// So P-1 is factored first (with the primes from getPrime()), then each q is checked.
///
/// Fix: the old factorisation pushed prime[i] instead of the remaining
/// cofactor k when k/prime[i] < prime[i]. The largest prime factor of P-1 was
/// therefore often never tested. For P = 41 (P-1 = 2^3 * 5) it accepted 3,
/// whose order is only 8, instead of the correct answer 6. It also ignored a
/// cofactor left over after the prime table ran out.
///
/// Assumes P is prime. If getPrime() has not run, the fallback trial
/// division below still factors P-1 correctly, just more slowly.
bool judge (int num) {
    int elem[32];           // distinct prime factors of P-1 (P-1 < 2^31 has at most 9)
    int elemNum = 0;
    long long k = P - 1;    // part of P-1 not yet factored

    // Trial division by the sieved primes while p*p <= k.
    int i = 0;
    for (; i < pNum && (long long)prime[i] * prime[i] <= k; ++i) {
        if (k % prime[i] == 0) {
            elem[elemNum++] = prime[i];
            do {
                k /= prime[i];
            } while (k % prime[i] == 0);
        }
    }
    // The table ran out before reaching sqrt(k): continue past its last prime.
    // Any d that divides k here is prime, since all smaller factors are gone.
    if (i == pNum) {
        for (long long d = (pNum > 0 ? prime[pNum - 1] + 1 : 2); d * d <= k; ++d) {
            if (k % d == 0) {
                elem[elemNum++] = (int)d;
                do {
                    k /= d;
                } while (k % d == 0);
            }
        }
    }
    // A cofactor > 1 that survived trial division up to its square root is prime.
    if (k > 1)
        elem[elemNum++] = (int)k;

    // num must be a unit mod P (only matters for P == 2, where num = 2 is 0 mod P).
    if (num % P == 0)
        return false;
    for (int j = 0; j < elemNum; ++j) {
        if (getProduct (num, (P - 1) / elem[j], P) == 1)
            return false;
    }
    return true;
}

int main()
{
    // Build the prime table once; judge() uses it to factor P-1.
    getPrime();
    // For each modulus P on input, print the first g >= 2 accepted by judge().
    while (cin >> P) {
        int g = 2;
        while (!judge(g))
            ++g;
        cout << g << endl;
    }
    return 0;
}



HDU 1402:

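選一個形如 c·2^k+1 的質數模數和它的原根就好了,例如下面程式用的 1004535809 = 479·2^21+1,原根為 3。要求有兩點:一是 2^k 不小於變換長度 len;二是模數要大於乘積係數的最大可能值(每個係數最多為 81 乘以較短那個數的位數),否則取模後無法還原真實結果。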

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <cmath>
#include <map>
#include <vector>
#include <stack>
using namespace std;
const long long mod = 1004535809LL;  // NTT prime: 479 * 2^21 + 1
const int G = 3;                     // generator used to build the NTT roots modulo mod
const int maxn = 400005;             // capacity of the digit / coefficient buffers

/// Modular exponentiation by squaring: returns a^b mod `mod` (1 when b == 0).
long long qpow (long long a, long long b) {
    long long result = 1;
    for (; b != 0; b >>= 1) {
        if (b & 1)
            result = result * a % mod;
        a = a * a % mod;
    }
    return result;
}

/// Reorders y[0..len) into bit-reversed index order (Rader's algorithm), the
/// input order required by the in-place butterflies in ntt().
/// Precondition: len is a power of two.
void change (long long y[], int len) {
    const int top = len / 2;   // highest bit of an index
    int rev = top;             // bit-reversal of idx, starting with idx = 1
    for (int idx = 1; idx < len - 1; ++idx) {
        if (idx < rev)         // swap each pair only once
            swap (y[idx], y[rev]);
        // Increment rev in reversed bit order: clear its leading ones from
        // the top bit down, then set the first zero bit reached.
        int bit = top;
        while (rev >= bit) {
            rev -= bit;
            bit /= 2;
        }
        rev += bit;
    }
}
void ntt(long long y[], int len, int on) {
    change (y, len);
    for(int h = 2; h <= len; h <<= 1) {
        long long wn = qpow(G, (mod-1)/h);
        if(on == -1) wn = qpow(wn, mod-2);
        for(int j = 0; j < len; j += h) {
            long long w = 1;
            for(int k = j; k < j + h / 2; k++) {
                long long u = y[k];
                long long t = w * y[k + h / 2] % mod;
                y[k] = (u + t) % mod;
                y[k+h/2] = (u - t + mod) % mod;
                w = w * wn % mod;
            }
        }
    }
    if(on == -1) {
        long long t = qpow (len, mod - 2);
        for(int i = 0; i < len; i++)
            y[i] = y[i] * t % mod;
    }
}

char a[maxn];          // first factor, as read (decimal string)
char b[maxn];          // second factor, as read (decimal string)
long long x1[maxn];    // digits of a (least significant first), then its NTT
long long x2[maxn];    // digits of b (least significant first), then its NTT
long long ans[maxn];   // digits of the product, least significant first

int main () {
    while (scanf ("%s%s", a, b) == 2) {
        int len = 2, l1 = strlen (a), l2 = strlen (b);
        while (len < l1*2 || len < l2*2)
            len <<= 1;
        //cout << len << endl;
        for (int i = 0; i < l1; i++) {
            x1[i] = a[l1-1-i]-'0';
        }
        for (int i = l1; i < len; i++)
            x1[i] = 0;
        for (int i = 0; i < l2; i++) {
            x2[i] = b[l2-1-i]-'0';
        }
        for (int i = l2; i < len; i++)
            x2[i] = 0;
        ntt(x1, len, 1);
        ntt(x2, len, 1);
        for (int i = 0; i < len; i++)
            x1[i] = x1[i]*x2[i]%mod;
        ntt(x1, len, -1);
        for (int i = 0; i < len; i++) {
            ans[i] = x1[i];
        }
        for (int i = 0; i < len; i++) {
            if (ans[i] >= 10) {
                ans[i+1] += ans[i]/10;
                ans[i] %= 10;
            }
        }
        len = l1+l2-1;
        while (ans[len] <= 0 && len > 0)
            len--;
        for (int i = len; i >= 0; i--) {
            printf ("%lld", ans[i]);
        }
        printf ("\n");
    }
    return 0;
}