1. 程式人生 > >tensorflow:常用API-'a'

tensorflow:常用API-'a'

1.加法操作
tf.accumulate_n、tf.add_n、tf.add

import tensorflow as tf
a = tf.constant([[1, 2], [3, 4]])
b = tf.constant([[5, 0], [0, 6]])
c = tf.constant([2, 3])

sess = tf.InteractiveSession()

print(tf.accumulate_n([a, b]).eval())
print(tf.add_n([a, b]).eval())
print(tf.add(a, b).eval())
# 輸出的結果都是
# [[ 6 2] # [ 3 10]] # 但是tf.add支援broadcasting print(tf.add(a, c).eval()) # [[6 2] # [8 4]] # print(tf.add_n([a,c]).eval()) #不支援廣播-error

小結:多個tensor對應相加,推薦tf_add_n,若需要支援廣播(2個shape不一樣的tensor進行操作),請使用tf.add

2.argmax和argmin

a1 = tf.constant([[4, 2, 3], [1, 6, 5]])
print(tf.argmax(a1, axis=0).eval())
# 預設,按列取最大值的下標[0 1 1]
print(tf.argmax(a1, axis=1).eval()) # 按行取最大值的下標[0 1]

tf還提供了arg_max函式,其功能和argmax一樣,arg_max是一個待拋棄的函式,推薦使用argmax,argmin和argmax類似

3.assign
assign對tensor的引用進行重新賦值

a2 = tf.Variable(3, dtype=tf.float32)
sess.run(tf.global_variables_initializer())
print(a2.eval())  # 3
tf.assign(a2, 5).eval()  # 必須得eval執行下
print(a2.eval()) # 5 tf.assign_add(a2,2).eval() #加上一個值再從新賦值 print(a2.eval()) # 7

4.as_string
類似tostring函式,不過作用在tensor上

a3 = tf.constant([[1.13, 2.02], [3.5, 4.433]])
print(tf.as_string(a3,precision=2).eval()) #保留2位小數
# [[b'1.13' b'2.02']
# [b'3.50' b'4.43']]