TensorFlow函式之tf.truncated_normal()
阿新 • • 發佈:2018-11-17
tf.truncated_normal()函式是一種“截斷”方式生成正太分佈隨機值,“截斷”意思指生成的隨機數值與均值的差不能大於兩倍中誤差,否則會重新生成。
此函式有別於tf.random_normal()正太函式,請參考本部落格關於tf.random_normal()函式的介紹
(TensorFlow函式之tf.random_normal())
tf.truncated_normal()函式的格式為:
tf.truncated_normal(shape, mean, stddev, dtype, seed, name)
引數說明:
- shape:表示生成隨機數的維度
- mean:正太分佈的均值,預設為0
- stddev:正太分佈的標準差
- dtype:生成正太分佈資料的型別
- seed:一個整數,當設定之後,每次生成的隨機數都一樣
- name:正太分佈的名字
下邊舉兩個例子說明函式的用法:
1、下邊例子,均值mean=0,stddev=0.2,則生成的隨機數與均值差不能大於兩倍中誤差,即範圍為:[-0.4,0.4]
import tensorflow as tf v = tf.truncated_normal([2, 2], mean=0, stddev=0.2, dtype=tf.float32, seed=1, name='v') sess = tf.Session() print(sess.run(v)) sess.close()
輸出為:
[[-0.16226365 0.29691976]
[ 0.01306587 0.01984968]]
2、修改標準差,檢視生成隨機數的差別,這裡設定stddev=0.1,則生成的範圍:[-0.2,0.2]
import tensorflow as tf
v = tf.truncated_normal([2, 2], mean=0, stddev=0.1, dtype=tf.float32, seed=1, name='v')
sess = tf.Session()
print(sess.run(v))
sess.close()
輸出為:
[[-0.08113182 0.14845988] [ 0.00653294 0.00992484]]