pytorch系列6 -- activation_function 啟用函式 relu, leakly_relu, tanh, sigmoid及其優缺點
阿新 • • 發佈:2018-11-11
主要包括:
- 為什麼需要非線性啟用函式?
- 常見的啟用函式有哪些?
- python程式碼視覺化啟用函式線上性迴歸中的變現
- pytorch啟用函式的原始碼
為什麼需要非線性的啟用函式呢?
只是將兩個或多個線性網路層疊加,並不能學習一個新的東西,接下來通過簡單的例子來說明一下:
假設
- 輸入
- 第一層網路引數:
- 第二層網路引數:
經過第一層後輸出為 經過第二層後的輸出為:
是不是等同於一層網路:
所以說簡單的堆疊網路層,而不經過非線性啟用函式啟用,並不能學習到新的特徵學到的仍然是線性關係。
接下來看一下經過啟用函式呢?
仍假設
- 輸入
- 第一層網路引數:
- 經過啟用函式Relu:
- 第二層網路引數:
通過啟用函式的加入可以學到非線性的關係,這對於特徵提取具有更強的能力。接下來結合函式看一下,輸入的 在經過兩個網路後的輸出結果:
# -*- coding: utf-8 -*-
"""
Spyder Editor
This is a temporary script file.
"""
import matplotlib.pyplot as plt
import numpy as np
x = np.arange(-3,3, step=0.5)
def non_activation_function_model(x):
y_1 = x * 3 + 1
y_2 = y_1 * 2 + 2
print(y_2)
return y_2
def activation_function_model(x):
y_1 = x * 3 + 1
y_relu =np.where( y_1 > 0, y_1, 0)
# print(y_relu)
y_2 = y_relu * 2 + 1
print(y_2)
return y_2
y_non = non_activation_function_model(x)
y_ = activation_function_model(x)
plt.plot(x, y_non, label='non_activation_function')
plt.plot(x, y_, label='activation_function')
plt.legend()
plt.show()
out:
可以看出,通過啟用函式,網路結構學到了非線性特徵,而不使用啟用函式,只能得到學到線性特徵。
常用的啟用函式有:
- Sigmoid
- Tanh
- ReLU
- Leaky ReLU
分式函式的求導函式:
- Sigmoid函式
其導函式為:
兩者的函式影象:
import numpy as np
import matplotlib.pyplot as plt
def sigma(x):
return 1 / (1 + np.exp(-x))
def sigma_diff(x):
return sigma(x) * (1 - sigma(x))
x = np.arange(-6, 6, step=0.5)
y_sigma = sigma(x)
y_sigma_diff = sigma_diff(x)
axes = plt.subplot(111)
axes.plot(x, y_sigma, label='sigma')
axes.plot(x, y_sigma_diff, label='sigma_diff')
axes.spines['bottom'].set_position(('data',0))
axes.spines['left'].set_position(('data',0))
axes.legend()
plt.show()
優點:
-
是便於求導的平滑函式;
-
能壓縮資料,保證資料幅度不會趨於
缺點:
-
容易出現梯度消失(gradient vanishing)的現象:當啟用函式接近飽和區時,變化太緩慢,導數接近0,根據後向傳遞的數學依據是微積分求導的鏈式法則,當前導數需要之前各層導數的乘積,幾個比較小的數相乘,導數結果很接近0,從而無法完成深層網路的訓練。
-
Sigmoid的輸出均值不是0(zero-centered)的:這會導致後層的神經元的輸入是非0均值的訊號,這會對梯度產生影響。以 f=sigmoid(wx+b)為例, 假設輸入均為正數(或負數),那麼對w的導數總是正數(或負數),這樣在反向傳播過程中要麼都往正方向更新,要麼都往負方向更新,使得收斂緩慢。
-
指數運算相對耗時
- tanh函式
其導函式:
兩者的函式影象:
import numpy as np
import matplotlib.pyplot as plt
def tanh(x):
return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))
def tanh_diff(x):
return 4 / np.power(np.exp(x) + np.exp(-x), 2)
x = np.arange(-6, 6, step=0.5)
y_sigma = tanh(x)
y_sigma_diff = tanh_diff(x)
axes = plt.subplot(111)
axes.plot(x, y_sigma, label='sigma')
axes.plot