1. 程式人生 > 程式設計 >Tensorflow進行多維矩陣的拆分與拼接例項

Tensorflow進行多維矩陣的拆分與拼接例項

最近在使用tensorflow進行網路訓練的時候,需要提取出別人訓練好的卷積核的部分層的資料。由於tensorflow中的tensor和python中的list不同,無法直接使用加法進行拼接,後來發現一個函式可以完成tensor的拼接。

函式形式如下:

tf.concat(concat_dim,values,name='concat')

其中,第一個引數表示需要拼接的多維tensor,並且可以將多個tensor同事拼接,第二個表示按照哪一個維度拼接(從數字0開始)。

例子:建立一個三維的tensor,然後分別取出最後一個維度(注意:tensor支援與python中list相似的切片操作,可以使用這種方式進行拆分),然後在拼接在一起。

import tensorflow as tf

weights=tf.Variable(tf.truncated_normal([2,3,4],dtype=tf.float32,stddev=1e-1),name='weights')

weight1=weights[0:2,0:3,1:2]
weight2=weights[0:2,2:3]
weight3=weights[0:2,1:2]
weight4=tf.concat([weight1,weight2,weight3],2) #2表示最後一個維度

with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 print(sess.run(weights))
 print("****************")
 print(sess.run(weight4))

Tensorflow進行多維矩陣的拆分與拼接例項

以上這篇Tensorflow進行多維矩陣的拆分與拼接例項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。