1. 程式人生 > 程式設計 >PyTorch 中的傅立葉卷積實現示例

PyTorch 中的傅立葉卷積實現示例

卷積

卷積在資料分析中無處不在。幾十年來,它們一直被用於訊號和影象處理。最近,它們成為現代神經網路的重要組成部分。如果你處理資料的話,你可能會遇到錯綜複雜的問題。

數學上,卷積表示為:

PyTorch 中的傅立葉卷積實現示例

儘管離散卷積在計算應用程式中更為常見,但在本文的大部分內容中我將使用連續形式,因為使用連續變數來證明卷積定理(下面討論)要容易得多。之後,我們將回到離散情況,並使用傅立葉變換在 PyTorch 中實現它。離散卷積可以看作是連續卷積的近似,其中連續函式離散在規則網格上。因此,我們不會為這個離散的案例重新證明卷積定理。

卷積定理

從數學上來說,卷積定理可以這樣描述:

PyTorch 中的傅立葉卷積實現示例

其中的連續傅立葉變換是(達到正常化常數) :

PyTorch 中的傅立葉卷積實現示例

換句話說,位置空間中的卷積等價於頻率空間中的直乘。這個想法是相當不直觀的,但是對於連續的情況來說,證明卷積定理是驚人的容易。要做到這一點,首先要寫出等式的左邊。

PyTorch 中的傅立葉卷積實現示例

現在切換積分的順序,替換變數(x = y + z) ,並分離兩個被積函式。

PyTorch 中的傅立葉卷積實現示例

我們為什麼要關心這一切?

因為快速傅立葉變換的演算法複雜度低於卷積。直接卷積運算具有複雜度 O(n^2) ,因為在 f 中,我們傳遞 g 中的每個元素,所以可以在 O(nlogn)時間內計算出快速傅立葉變換。當輸入陣列很大時,它們比卷積要快得多。在這些情況下,我們可以使用卷積定理計算頻率空間中的卷積,然後執行逆傅立葉變換回到位置空間。

當輸入較小時(例如3x3卷積核心) ,直接卷積仍然更快。在機器學習應用程式中,使用小核心更為常見,因此像 PyTorch 和 Tensorflow 這樣的深度學習庫只提供直接卷積的實現。但是在現實世界中有很多使用大核心的用例,其中傅立葉卷積演算法更有效。

PyTorch 實現

現在,我將演示如何在 PyTorch 中實現傅立葉卷積函式。它應該模仿 torch.nn.functional.convNd 的功能,並利用 fft,而不需要使用者做任何額外的工作。因此,它應該接受三個 Tensors (signal、kernel 和可選 bias)和應用於輸入的 padding。從概念上講,這個函式的內部工作原理是:

def fft_conv(
  signal: Tensor,kernel: Tensor,bias: Tensor = None,padding: int = 0,) -> Tensor:
  # 1. Pad the input signal & kernel tensors
  # 2. Compute FFT for both signal & kernel
  # 3. Multiply the transformed Tensors together
  # 4. Compute inverse FFT
  # 5. Add bias and return

讓我們按照上面顯示的操作順序逐步構建 FFT 卷積。對於這個例子,我將構建一個一維傅立葉卷積,但是將其擴充套件到二維和三維卷積是很簡單的。

1. 填充輸入陣列

我們需要確保 signal 和 kernel 在填充之後有相同的大小。應用初始填充 signal,然後調整 kernel 的填充以匹配。

# 1. Pad the input signal & kernel tensors
signal = f.pad(signal,[padding,padding])
kernel_padding = [0,signal.size(-1) - kernel.size(-1)]
padded_kernel = f.pad(kernel,kernel_padding)

注意,我只在一邊填充 kernel。我們希望原始核心位於填充陣列的左側,這樣它就可以與 signal 陣列的開始對齊。

2. 計算傅立葉變換

這非常簡單,因為 n 維 fft 已經在 PyTorch 中實現了。我們簡單地使用內建函式,並計算沿每個張量的最後一個維數的 FFT。

# 2. Perform fourier convolution
signal_fr = rfftn(signal,dim=-1)
kernel_fr = rfftn(padded_kernel,dim=-1)

3. 變換張量相乘

令人驚訝的是,這是我們功能中最複雜的部分。這有兩個原因。(1) PyTorch 卷積運行於多維張量上,因此我們的 signal 和 kernel 張量實際上是三維的。從 PyTorch 文件中的這個方程式,我們可以看到矩陣乘法是在前兩個維度上執行的(不包括偏差項) :

PyTorch 中的傅立葉卷積實現示例

我們將需要包括這個矩陣乘法,以及對轉換後的維度的直接乘法。

PyTorch 實際上實現了互相關/值方法而不是卷積方法。(TensorFlow 和其他深度學習庫也是如此。)互相關與卷積密切相關,但有一個重要的標誌變化:

PyTorch 中的傅立葉卷積實現示例

與卷積相比,這有效地逆轉了核的方向(g)。我們不是手動翻轉核心,而是在傅立葉空間中利用核心的共軛複數來糾正這個問題。由於我們不需要建立一個全新的 Tensor,所以這樣做的速度明顯更快,記憶體效率也更高。(本文末尾的附錄中簡要說明了這種方法的工作原理。)

# 3. Multiply the transformed matrices
 
def complex_matmul(a: Tensor,b: Tensor) -> Tensor:
  """Multiplies two complex-valued tensors."""
  # Scalar matrix multiplication of two tensors,over only the first two dimensions.
  # Dimensions 3 and higher will have the same shape after multiplication.
  scalar_matmul = partial(torch.einsum,"ab...,cb... -> ac...") 
 
  # Compute the real and imaginary parts independently,then manually insert them
  # into the output Tensor. This is fairly hacky but necessary for PyTorch 1.7.0,# because Autograd is not enabled for complex matrix operations yet. Not exactly
  # idiomatic PyTorch code,but it should work for all future versions (>= 1.7.0).
  real = scalar_matmul(a.real,b.real) - scalar_matmul(a.imag,b.imag)
  imag = scalar_matmul(a.imag,b.real) + scalar_matmul(a.real,b.imag)
  c = torch.zeros(real.shape,dtype=torch.complex64)
  c.real,c.imag = real,imag
  return c 

# Conjugate the kernel for cross-correlation
kernel_fr.imag *= -1
output_fr = complex_matmul(signal_fr,kernel_fr)

PyTorch 1.7改進了對複數的支援,但是在 autograd 中還不支援對複數張量的許多操作。現在,我們必須編寫我們自己的複雜 matmul 方法作為一個補丁。雖然不是很理想,但是它確實有效,並且在未來的版本中不會出現問題。

4. 計算逆變換

使用 torch.irfftn 可以直接計算逆變換,然後裁剪出額外的陣列填充。

# 4. Compute inverse FFT,and remove extra padded values
output = irfftn(output_fr,dim=-1)
output = output[:,:,:signal.size(-1) - kernel.size(-1) + 1]

5. 新增偏執項並返回

新增偏差項也很容易。請記住,對於輸出陣列中的每個通道,偏置項都有一個元素,並相應地調整其形狀。

# 5. Optionally,add a bias term before returning.
if bias is not None:
  output += bias.view(1,-1,1)

將上述程式碼整合在一起

為了完整起見,讓我們將所有這些程式碼片段編譯成一個內聚函式。

def fft_conv_1d(
  signal: Tensor,) -> Tensor:
  """
  Args:
    signal: (Tensor) Input tensor to be convolved with the kernel.
    kernel: (Tensor) Convolution kernel.
    bias: (Optional,Tensor) Bias tensor to add to the output.
    padding: (int) Number of zero samples to pad the input on the last dimension.
  Returns:
    (Tensor) Convolved tensor
  """
  # 1. Pad the input signal & kernel tensors
  signal = f.pad(signal,padding])
  kernel_padding = [0,signal.size(-1) - kernel.size(-1)]
  padded_kernel = f.pad(kernel,kernel_padding)
 
  # 2. Perform fourier convolution
  signal_fr = rfftn(signal,dim=-1)
  kernel_fr = rfftn(padded_kernel,dim=-1)
 
  # 3. Multiply the transformed matrices
  kernel_fr.imag *= -1
  output_fr = complex_matmul(signal_fr,kernel_fr)
 
  # 4. Compute inverse FFT,and remove extra padded values
  output = irfftn(output_fr,dim=-1)
  output = output[:,:signal.size(-1) - kernel.size(-1) + 1]
 
  # 5. Optionally,add a bias term before returning.
  if bias is not None:
    output += bias.view(1,1)
 
 
  return output

直接卷積測試

最後,我們將使用 torch.nn.functional.conv1d 來確認這在數值上等同於直接一維卷積。我們為所有輸入構造隨機張量,並測量輸出值的相對差異。

import torch
import torch.nn.functional as f 
 
torch.manual_seed(1234)
kernel = torch.randn(2,3,1025)
signal = torch.randn(3,4096)
bias = torch.randn(2)
 
y0 = f.conv1d(signal,kernel,bias=bias,padding=512)
y1 = fft_conv_1d(signal,padding=512)
 
abs_error = torch.abs(y0 - y1)
print(f'\nAbs Error Mean: {abs_error.mean():.3E}')
print(f'Abs Error Std Dev: {abs_error.std():.3E}')
 
# Abs Error Mean: 1.272E-05

考慮到我們使用的是32位精度,每個元素相差大約1e-5ー相當精確!讓我們也執行一個快速的基準來測量每個方法的速度:

from timeit import timeit
direct_time = timeit(
  "f.conv1d(signal,padding=512)",globals=locals(),number=100
) / 100
fourier_time = timeit(
  "fft_conv_1d(signal,number=100
) / 100
print(f"Direct time: {direct_time:.3E} s")
print(f"Fourier time: {fourier_time:.3E} s")
 
# Direct time: 1.523E-02 s
# Fourier time: 1.149E-03 s

測量的基準將隨著您使用的機器而發生顯著的變化。(我正在用一臺非常舊的 Macbook Pro 進行測試。)對於1025的核心,傅立葉卷積似乎要快10倍以上。

總結

我希望這已經提供了一個徹底的介紹傅立葉卷積。我認為這是一個非常酷的技巧,在現實世界中有很多應用程式可以使用它。我也喜歡數學,所以看到程式設計和純數學的結合是很有趣的。歡迎和鼓勵所有的評論和建設性的批評,如果你喜歡這篇文章,請鼓掌!

附錄:

卷積 vs. 互相關

在本文的前面,我們通過在傅立葉空間中取得核心的互相關共軛複數來實現。這實際上顛倒了 kernel 的方向,現在我想演示一下為什麼會這樣。首先,記住卷積和互相關的公式:

PyTorch 中的傅立葉卷積實現示例

然後,讓我們來看看 g(x) 的傅立葉變換:

PyTorch 中的傅立葉卷積實現示例

注意,g(x)是實值的,所以它不受共軛複數變化的影響。然後,更改變數(y =-x)並簡化表示式。

PyTorch 中的傅立葉卷積實現示例

到此這篇關於PyTorch 中的傅立葉卷積實現示例的文章就介紹到這了,更多相關PyTorch 傅立葉卷積內容請搜尋我們以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援我們!