tf.py_func()函式
阿新 • • 發佈:2018-11-01
tensorflow由於構建的是靜態圖,所以導致在tf.Session().run()之前是沒有實際值的,因此,在網路搭建的時候,是不能對tensor進行判值操作的,即不能插入if…else…之類的程式碼。第二,相較於numpy array,Tensorflow中對tensor的操作介面靈活性並沒有那麼高,使得Tensorflow的靈活性減弱。
在筆者使用Tensorflow的一年中積累的程式設計經驗來看,擴充套件Tensorflow程式的靈活性,有一個重要的手段,就是使用tf.py_func介面。 介面解析
程式碼測試:
def my_func(array1,array2) :
return array1 + array2, array1 - array2
if __name__ =='__main__':
array1 = np.array([[1, 2], [3, 4]])
array2 = np.array([[1, 2], [3, 4]])
a1 = tf.placeholder(tf.float32,[2,2],name = 'array1')
a2 = tf.placeholder(tf.float32,[2,2],name = 'array2')
y1,y2 = tf.py_func(my_func, [a1,a2],[tf.float32, tf.float32])
with tf.Session() as sess:
y1_,y2_ = sess.run([y1,y2],feed_dict={a1:array1,a2:array2})
print(y1_)
print('*'*10)
print(y2_)
從上面的程式碼我們可以看出,tf.py_func()接收的是tensor,然後將其轉化為numpy array送入我們自定義的my_func函式,最後再將my_func函式輸出的numpy array轉化為tensor返回
如果不用tf.py_func()實現的話,我們還可以這樣直接用array的方式操作:
def my_func(array1,array2):
return array1 + array2, array1 - array2
with tf.Session() as sess:
array1 = np.array([[1, 2], [3, 4]])
array2 = np.array([[1, 2], [3, 4]])
y1,y2 = my_func(array1,array2)
print(y1)
print('*' * 10)
print(y2)