1. 程式人生 > >【learning】多項式乘法&fft

【learning】多項式乘法&fft

type 表示 了解 這就是 () 學習 a記錄 b+ wap

[吐槽]

  以前一直覺得這個東西十分高端完全不會qwq

  但是向lyy、yxq、yww、dtz等dalao們學習之後發現這個東西的代碼實現其實極其簡潔

  於是趁著還沒有忘記趕緊來寫一篇博

  (說起來這篇東西的文字好像有點多呀qwq啊話癆是真的qwq)

[正題]

  一些預備知識(有了解的就可以直接跳啦,mainly from 算導)

  fft的話,用來解決與多項式乘法有關的問題

  關於多項式

  一個以x為變量的多項式定義在一個代數域$F$上,將函數$A(x)$表示為形式和:

  $A(x) = \sum\limits_{i=0}^{n-1} a_ix_i$

  顯然該多項式有$n$項,我們稱 $a_0, a_1, a_2 ... a_{n-1}$ 為該多項式的系數

  如果說一個多項式$A(x)$的最高次的非零系數是$a_k$,那麽稱$A(x)$的次數是$k$

  任何嚴格大於一個多項式次數的整數都是該多項式的次數界

  那麽顯然,我們可以用一個$n$次多項式的$n$個系數來表示這個多項式

  

  在多項式上定義的運算

  對於加法,就是直接對應系數相加就好了

  對於乘法,

  如果說$A(x)$ 和 $B(x)$ 皆是次數為$n$的多項式,則它們的乘積$C(x)$ 是一個次數界為$2n-1$的多項式

  對於所有屬於定義域的$x$,都有$C(x)=A(x)*B(x)$

  那麽如何快速求兩個多項式的乘積呢?

  我們知道對於一個$n$次多項式,知道了其函數圖像上的$n$個點之後,就能將這個多項式確定下來了

  所以就考慮通過計算求出到$C(x)$函數圖像上的$2n-1$個點,從而確定$C(x)$

  那麽如何用一種高效的方式解決這個問題呢?就是接下來要講的東西啦

  總的來說. . . 我們要幹什麽?

  顯然,我們現在要尋找一種快速的方法來求兩個多項式的乘積

  接下來介紹的方法思路就是上面提到的:先求出點值,再確定多項式

  根據乘積$C(x)$滿足的性質,我們可以取$2n-1$個不同的$x1$,將$A(x1)$和$B(x1)$分別算出來

  然後再用兩者相乘得到這個位置的$C(x1)$

  直接算效率是極低的,但是如果說我們選擇的點有一些特殊性質呢?

  如果說我們選擇的位置滿足某種性質,使得我們在計算系數的時候能夠省掉一些步驟

  (比如說系數中滿足某種關系啊之類的)

  那麽我們的效率就會相對來說高一些了

  接下來介紹的方法,用到一個叫做DFT的東西(說白了就是選擇一些特殊的點),通過求兩個多項式的系數向量的DFT,得到確定$C(x)$所需要的點值,然後再通過其逆運算,得到$C(x)$

  這就是接下來的內容的大概思路

  特殊的點?單位復數根

  (在下文的敘述中用$i$來表示$-1$的平方根)

  $n$次單位復數根是滿足$\omega^n=1$的復數$\omega$

  $n$次單位復數根恰好有$n$個,對於$k=0, 1, ... , n-1$,這些根是$e^{2\pi ik/n}$

  對於這個表達式的計算,我們可以利用復數的指數形式的定義:

$e^{iu} = cos(u) + i sin(u)$

  我們考慮將一個復數在坐標系上用一個點來表示

  對於一個復數$x$,我們可以將其表示為這種形式:

  $x = a+b*i $ $(a,b\in R)$

  考慮這樣的一個坐標系,其橫軸為實數軸,縱軸為虛數軸

  那麽我們可以將$x$這個數表示為該坐標系(其實就是復平面)中的點$(a,b)$

  

  那麽將$n$個$n$次單位復數根畫出來的話(以$n=8$為例),大概是長這樣:

  技術分享

  其實如果畫得足夠標準,這個些單位復數根應該分布在一個以原點為圓心的圓上。。。

  (這個好像一點都不像一個圓啊餵qwq)

  (嗯好像這點在接下來的講述中並不會用到,不過碼上來總是好的ovo)

  

  那麽接下來給出一些關於$n$次單位復數根的基本性質

  (註意這也是後面FFT之所以能在O(nlogn)時間內求得的重要原因)

  消去引理(好像有點像。。約分?哈哈哈)

    對任何整數$n >= 0, k >=0, $以及$d > 0$,有

    $\omega_{dn}^{dk} = \omega_{n}^{k}$

    證明就直接將其定義帶進去就好:$\omega_{dn}^{dk} = (e^{2\pi i/dn})^{dk} = (e^{2\pi i/n})^k = \omega_n^k$

    那麽由這條式子我們可以得到一個推論:

     $\omega_{n}^{n/2} = \omega_2 = -1$

  折半引理

    如果 $n>0$ 為偶數,那麽 $n$ 個 $n$ 次單位復數根的平方的集合就是 $n/2$ 個 $n/2$ 次單位復數根的集合

   

    證明的話:

    首先,根據消去引理,對於任意的非負整數 $k$ ,有

    $(\omega_{n}^{k})^2 = \omega_{n/2}^{k}$

    然後我們會發現,如果對於所有的$n$次單位復數根平方,會得到每個$n/2$次單位根正好2次,因為

    $(\omega_{n}^{k+n/2})^2 = \omega_{n}^{2k+n} = \omega_{n}^{2k} * \omega_{n}^{n} = \omega_{n}^{2k} = (\omega_{n}^{k})^2$

    所以還可以得到這樣一條式子

    $(\omega_{n}^{k+n/2} )^2= (\omega_{n}^{k})^2 $

  求和引理

    對任意整數$n>=1$ 和不能被$n$整除的非負整數$k$,有

    $\sum\limits_{i=0}^{n-1} (\omega_{n}^{k})^i = 0$

    證明:

    $\sum\limits_{i=0}^{n-1} (\omega_{n}^{k})^i = \frac{(\omega_{n}^{k})^n-1}{\omega_{n}^{k}-1} = \frac{(\omega_{n}^{n})^k-1}{\omega_{n}^{k}-1} = \frac{(1)^k-1}{\omega_{n}^{k}-1} =0$

  於是乎我們開始真正步入正題來求……

  DFT

    在介紹完什麽是單位復數根之後,就可以引入DFT的概念了

    計算一個次數界為$n$的多項式 :

$A(x) = \sum\limits_{i=0}^{n-1} a_i x_i$

    在$\omega_{n}^{0},\omega_{n}^{1} ... \omega_{n}^{n-1}$處的取值(也就是在$n$個$n$次單位復數根處),

    定義其結果$y_k$:

    $y_k = A(\omega_n^k)$

    向量$y$就是系數向量 $a = (a_0, a_1, a_2 ,..., a_{n-1})$(也就是A的系數)的DFT(離散傅裏葉變換)

    我們記為$y=DFT_n(a)$

  FFT

    嗯?名字是不是和上面長得很像啊?

    原因是因為,FFT其實就是快速求DFT,叫做快速傅裏葉變換

    利用復數單位根的特殊性質,我們就可以在$O(nlogn)$時間內算出$DFT_n(a)$

    接下來就是算法部分啦

     

    一些必須先約定的東西:接下來的內容中$n$都恰好是2的整數冪

    (如果說實際處理中出現次數界不是2的整數冪呢?強行補成就好啦,不存在的那些項系數=0即可)

    

    我們考慮分治策略,根據$A(x)$中系數下標的奇偶性分成兩組,變成兩個新的次數界為$n/2$的多項式

    這裏分別定義兩個新的多項式:

    $A_0(x) = a_0 + a_2x + a_4x^2 + ... +a_{n-2}x^{n/2-1}$

    $A_1(x) = a_1 + a_3x + a_5x^2 + ... +a_{n-1}x^{n/2-1}$

    ($A_0(x)$中包含了$A$中所有下標為偶數的系數,$A_1(x)$中包含了所有下標為奇數的系數)

    那麽顯然有:

    $A(x) = A_0(x^2) + x*A_1(x^2)$

    至此,發現我們的問題直接就轉化為了:

    求次數界為$n/2$的多項式$A_0(x)$和$A_1(x)$在點$(\omega_n^0)^2 , (\omega_n^1)^2 , ... , (\omega_n^{n-1})^2$的取值

    

    於是乎這樣好像就把我們原來的問題成功拆成了兩個形式與原問題相同的子問題

    

    假設我們現在已經求得了$A_0(\omega_{n/2}^k)$和$A_1(\omega_{n/2}^k)$

    如何得到由它們快速得到$A$中的系數呢?

    這時候就要用到單位復數根的奇妙性質啦

    根據消去引理,有$\omega_{n/2}^k = \omega_{n}^{2k}$

    於是

    $A_0(\omega_{n/2}^k) = A_0(\omega_{n}^{2k})$

    $A_1(\omega_{n/2}^k) = A_1(\omega_{n}^{2k})$

    這個時候我們用表達$A_0(x)$和$A_1(x)$與$A(x)$之間的關系的那條式子推一下,會發現

    $A_0(\omega_{n}^{2k}) + \omega_{n}^{k} * A_1(\omega_{n}^{2k}) = A(\omega_{n}^{k})$

    稍微繞一下彎,還可以得到這樣的一條式子

    $A_0(\omega_{n}^{2k}) - \omega_{n}^{k} * A_1(\omega_{n}^{2k}) = A(\omega_{n}^{k+n/2})$

    為什麽呢?

    一步步來的話是這樣的:

    首先,我們知道$\omega_{n}^{n/2} = \omega_2 = -1$ (消去引理的推論)

    然後有:

    $ - \omega_n^k = -1 * \omega_n^k = \omega_{n}^{n/2} * \omega_n^k = \omega_{n}^{k+n/2}$

    所以第二條式子其實是等於

    $A_0(\omega_{n}^{2k}) + \omega_{n}^{k+n/2} * A_1(\omega_{n}^{2k})$

    然後根據折半引理,我們可以知道$\omega_n^{2k+n} = \omega_n^{2k}$

    所以上面的式子又等於

    $A_0(\omega_{n}^{2k+n}) + \omega_{n}^{k+n/2} * A_1(\omega_{n}^{2k+n}) $

    然後我們會發現這其實就是一個$A_0(x) + x * A_1(x) $的形式

    這樣這條式子最終就等於$A(\omega_{n}^{k+n/2})$啦

    總結一下,如果說我們得到了$A_0(\omega_{n/2}^{k})$(記為$y_0$)以及$A_1(\omega_{n/2}^{k})$(記為$y_1$)

    那麽我們就可以得到$A(\omega_{n}^{k})$以及$A(\omega_{n}^{k+n/2})$了

    其中

    $A(\omega_{n}^{k}) = y_0 + y_1$

    $A(\omega_{n}^{k+n/2}) = y_0 - y_1$

    至此,我們就完成了將原來的問題拆成了兩個規模為一半的問題的求解

    就可以在$O(nlogn)$的時間內求出DFT啦

    遞歸版的代碼如下(這裏是非完整的代碼,完整版會在後面給出)

 1 struct cmplx
 2 {
 3     double a,b;//a記錄這個復數的實數部分,b記錄這個復數的i的系數
 4     cmplx(){}
 5     cmplx(double x,double y){a=x,b=y;}
 6     friend cmplx operator + (cmplx x,cmplx y)
 7     {return cmplx(x.a+y.a,x.b+y.b);}
 8     friend cmplx operator - (cmplx x,cmplx y)
 9     {return cmplx(x.a-y.a,x.b-y.b);}
10     friend cmplx operator * (cmplx x,cmplx y)
11     {return cmplx(x.a*y.a-x.b*y.b,x.a*y.b+x.b*y.a);}
12 };
13 typedef vector<cmplx> vc
14 
15 vc fft(vc ans)
16 {
17     int n=ans.size();
18     if (n==1) return ans;
19     cmplx w_n=cmplx(cos(2*pi/n),sin(2*pi/n)),w=cmplx(1,0);
20     vc a0,a1;
21     for (int i=0;i<n;i+=2)
22         a0.push_back(ans[i]),a1.push_back(ans[i+1]);
23     //得到A0和A1
24 
25     a0=fft(a0,op);
26     a1=fft(a1,op);
27     //遞歸求出將單位復數根帶入得到的值
28 
29     for (int i=0;i<(n>>1);++i)
30     {
31         ans[i]=a0[i]+a1[i]*w;
32         ans[i+(n>>1)]=a0[i]-a1[i]*w;
33         w=w*w_n;
34         //利用得到的關系式由A0和A1推得A
35     }
36     return ans;
37 }

  所以說...我們要怎麽求回來?

  現在我們已經成功滴把DFT搞出來了,也就可以求得我們所需要的用來確定$C(x)$的點值了,剩下的工作就是插值啦

  插值的方法有很多,這裏考慮將DFT寫成一個矩陣方程 $y = V_n a$

  其中向量$y$表示的是DFT,向量$a$為原多項式的系數

  $V_n$是一個由$\omega_n$適當冪次填充成的範德蒙德矩陣

  那麽現在問題來了:

  範德蒙德矩陣又是什麽高端玩意?!

  其實這個東西大概長這樣:

\begin{bmatrix}
1&x_0&x_0^2&...&x_0^{n-1}\\
1&x_1&x_1^2&...&x_1^{n-1}\\
1&...&...&...&...\\
1&x_{n-1}&x_{n-1}^2&...&x_{n-1}^{n-1}
\end{bmatrix}

  (所謂的“$\omega_n$適當冪次填充”其實就是把$x_0, x_1, x_2, ... ,x_{n-1}$換成$n$次單位根)

  所以如果說我們想要由$y$得到$a$,只需要乘上逆矩陣$V_n^{-1}$就好了

  ($V_n^{-1} * V_n = $單位矩陣)

  考慮逆矩陣中的元素的特點

  然後根據求和引理(中間的過程有點。。看算導的話好像會更加清晰一些),可以得出這樣的結論:

  $a_i = \frac{1}{n} \sum\limits_{k=0}^{n-1} y_k \omega_n^{-kj}$

  說得簡單一點就是,

  由DFT反過來求原來的系數只要用$\omega_n^{-1}$替換掉$\omega_n$,並在最後將每個元素除以$n$就好啦

  實現的話,會發現其實與DFT的區別僅僅在於一個負號,其他部分的代碼實現是完全一樣的

  (爽到了爽到了哈哈哈qwq)

  所以說其實完全可以在調用函數的時候多帶一個參數,表示是否是求$DFT^-1$,這樣就十分方便滴將兩個函數合並成一個啦

  最後在這裏附上遞歸版完整的代碼(求的是兩個多項式$a$和$b$的乘積)

技術分享
 1 #include<iostream>
 2 #include<cstdio>
 3 #include<cstring>
 4 #include<cmath>
 5 #define ll long long
 6 using namespace std;
 7 const double pi=acos(-1);
 8 const int MAXN=(1<<17)+10;
 9 struct cmplx
10 {
11     double a,b;//a記錄這個復數的實數部分,b記錄這個復數的i的系數
12     cmplx(){}
13     cmplx(double x,double y){a=x,b=y;}
14     friend cmplx operator + (cmplx x,cmplx y)
15     {return cmplx(x.a+y.a,x.b+y.b);}
16     friend cmplx operator - (cmplx x,cmplx y)
17     {return cmplx(x.a-y.a,x.b-y.b);}
18     friend cmplx operator * (cmplx x,cmplx y)
19     {return cmplx(x.a*y.a-x.b*y.b,x.a*y.b+x.b*y.a);}
20 }a[MAXN],b[MAXN];
21 int n,m,k;
22 int fft(cmplx *ans,int n,int op);
23 
24 int main()
25 {
26 //    freopen("a.in","r",stdin);
27 //    freopen("a.out","w",stdout);
28 
29     int type;
30     scanf("%d%d%d",&n,&m,&type);
31     for (int i=0;i<=n;++i) scanf("%lf",&a[i].a);
32     for (int i=0;i<=m;++i) scanf("%lf",&b[i].a);
33     k=1;
34     while (k<n+m) k<<=1;
35     fft(a,k,1); 
36     fft(b,k,1);
37     for (int i=0;i<=k;++i) a[i]=a[i]*b[i];
38     fft(a,k,-1);
39     for (int i=0;i<=n+m;++i)
40         printf("%lld ",(ll)(a[i].a/k+0.5));//最後一定要記得除
41 }
42 
43 int fft(cmplx *ans,int n,int op)
44 {
45     if (n==0) return 0;
46     cmplx a0[n>>1],a1[n>>1],w_n=cmplx(cos(2*pi/n),op*sin(2*pi/n)),w=cmplx(1,0);
47         //註意在求逆DFT的時候,也就是在w_n的i的系數那裏多了一個負號罷了
48     for (int i=0;i<=n;i+=2)
49         a0[i>>1]=ans[i],a1[i>>1]=ans[i+1];
50     fft(a0,n>>1,op);
51     fft(a1,n>>1,op);
52     for (int i=0;i<(n>>1);++i)
53     {
54         ans[i]=a0[i]+a1[i]*w;
55         ans[i+(n>>1)]=a0[i]-a1[i]*w;
56         w=w*w_n;
57     }
58 }
遞歸版

  然後其實還有一種非遞歸的寫法,常數會小很多,寫起來也是十分的簡潔

  但是因為裏面的一些操作的需要用到一些關於二進制的知識講述清楚可能還是需要一定的篇幅

  而這篇東西的篇幅本來就夠長的了。。所以說就先挖個坑貼上代碼,具體就留在下一篇再講吧qwq

  (隨處挖坑 然後不填 系列)

技術分享
 1 #include<iostream>
 2 #include<cstdio>
 3 #include<cstring>
 4 #include<cmath>
 5 #include<vector>
 6 #define ll long long
 7 using namespace std;
 8 const double pi=acos(-1);
 9 const int MAXN=(1<<17)+10;
10 struct cmplx
11 {
12     double a,b;
13     cmplx(){}
14     cmplx(double x,double y){a=x,b=y;}
15     friend cmplx operator + (cmplx x,cmplx y)
16     {return cmplx(x.a+y.a,x.b+y.b);}
17     friend cmplx operator - (cmplx x,cmplx y)
18     {return cmplx(x.a-y.a,x.b-y.b);}
19     friend cmplx operator * (cmplx x,cmplx y)
20     {return cmplx(x.a*y.a-x.b*y.b,x.a*y.b+x.b*y.a);}
21 }a[MAXN],b[MAXN],ans[MAXN];
22 int rev[MAXN];
23 int n,m,k,lg;
24 //vc fft(vc ans,int op);
25 int fft(cmplx *a,int op);
26 int get_rev(cmplx *a,int n);
27 
28 int main()
29 {
30 //    freopen("a.in","r",stdin);
31 //    freopen("a.out","w",stdout);
32 
33     int type,x;
34     scanf("%d%d%d",&n,&m,&type);
35     ++n,++m;
36     for (int i=0;i<n;++i) scanf("%lf",&a[i].a);
37     for (int i=0;i<m;++i) scanf("%lf",&b[i].a);
38     k=1;
39     while (k<n+m) k<<=1;
40     fft(a,1);
41     fft(b,1);
42     for (int i=0;i<k;++i) a[i]=a[i]*b[i];
43     fft(a,-1);
44     for (int i=0;i<n+m-1;++i)
45         printf("%lld ",(ll)(a[i].a/k+0.5));
46 }
47 
48 int fft(cmplx *a,int op)
49 {
50     int step,bit=0;
51     cmplx w_n,w,t,u;
52     for (int i=1;i<k;i<<=1,++bit);
53     rev[0]=0;
54     for (int i=0;i<k;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
55     for (int i=0;i<k;++i) 
56         if (i<rev[i]) swap(a[i],a[rev[i]]);
57     //簡單說一下:就是因為我們會發現其實遞歸到最下層的順序是可以確定的
58     //然後就可以通過奇妙的方式(用到有關二進制的東西)得到這個順序,然後就直接模擬向上更新的過程就好啦
59     for (int step=2;step<=k;step<<=1)
60     {
61         w_n=cmplx(cos(2*pi/step),op*sin(2*pi/step));
62         for (int st=0;st<k;st+=step)
63         {
64             w=cmplx(1,0);
65             for (int i=0;i<(step>>1);++i)
66             {
67                 t=a[st+i+(step>>1)]*w;
68                 u=a[st+i];
69                 a[st+i]=u+t;
70                 a[st+i+(step>>1)]=u-t;
71                 w=w*w_n;
72             }
73         }
74     }
75 }
非遞歸版

[總結]

  其實fft這個東西仔細想想還是很有意思的(特別是代碼的簡潔哈哈)

  難得更了一篇這麽長的博,希望對這方面的理解能夠有所幫助吧ovo

【learning】多項式乘法&fft