tensorflow:常用API-'a'
阿新 • • 發佈:2018-11-10
1.加法操作
tf.accumulate_n、tf.add_n、tf.add
import tensorflow as tf
a = tf.constant([[1, 2], [3, 4]])
b = tf.constant([[5, 0], [0, 6]])
c = tf.constant([2, 3])
sess = tf.InteractiveSession()
print(tf.accumulate_n([a, b]).eval())
print(tf.add_n([a, b]).eval())
print(tf.add(a, b).eval())
# 輸出的結果都是
# [[ 6 2]
# [ 3 10]]
# 但是tf.add支援broadcasting
print(tf.add(a, c).eval())
# [[6 2]
# [8 4]]
# print(tf.add_n([a,c]).eval()) #不支援廣播-error
小結:多個tensor對應相加,推薦tf_add_n,若需要支援廣播(2個shape不一樣的tensor進行操作),請使用tf.add
2.argmax和argmin
a1 = tf.constant([[4, 2, 3], [1, 6, 5]])
print(tf.argmax(a1, axis=0).eval())
# 預設,按列取最大值的下標[0 1 1]
print(tf.argmax(a1, axis=1).eval())
# 按行取最大值的下標[0 1]
tf還提供了arg_max函式,其功能和argmax一樣,arg_max是一個待拋棄的函式,推薦使用argmax,argmin和argmax類似
3.assign
assign對tensor的引用進行重新賦值
a2 = tf.Variable(3, dtype=tf.float32)
sess.run(tf.global_variables_initializer())
print(a2.eval()) # 3
tf.assign(a2, 5).eval() # 必須得eval執行下
print(a2.eval()) # 5
tf.assign_add(a2,2).eval() #加上一個值再從新賦值
print(a2.eval()) # 7
4.as_string
類似tostring函式,不過作用在tensor上
a3 = tf.constant([[1.13, 2.02], [3.5, 4.433]])
print(tf.as_string(a3,precision=2).eval()) #保留2位小數
# [[b'1.13' b'2.02']
# [b'3.50' b'4.43']]