Tensorflow進行多維矩陣的拆分與拼接例項
阿新 • • 發佈:2020-02-09
最近在使用tensorflow進行網路訓練的時候,需要提取出別人訓練好的卷積核的部分層的資料。由於tensorflow中的tensor和python中的list不同,無法直接使用加法進行拼接,後來發現一個函式可以完成tensor的拼接。
函式形式如下:
tf.concat(concat_dim,values,name='concat')
其中,第一個引數表示需要拼接的多維tensor,並且可以將多個tensor同事拼接,第二個表示按照哪一個維度拼接(從數字0開始)。
例子:建立一個三維的tensor,然後分別取出最後一個維度(注意:tensor支援與python中list相似的切片操作,可以使用這種方式進行拆分),然後在拼接在一起。
import tensorflow as tf weights=tf.Variable(tf.truncated_normal([2,3,4],dtype=tf.float32,stddev=1e-1),name='weights') weight1=weights[0:2,0:3,1:2] weight2=weights[0:2,2:3] weight3=weights[0:2,1:2] weight4=tf.concat([weight1,weight2,weight3],2) #2表示最後一個維度 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(weights)) print("****************") print(sess.run(weight4))
以上這篇Tensorflow進行多維矩陣的拆分與拼接例項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。