python(1):tf.transpose函式
阿新 • • 發佈:2018-12-18
tf.transpose(a, perm = None, name = 'transpose')
a是一個張量(Tensor),實際上就是一個數組。
perm是a置換的維度
name是操作的名稱
最後返回一個轉置的張量
影象處理時資料集中儲存資料的形式為:[channel,image_height,image_width],在intel GPU加速的情況下,因為GPU對於影象的處理比較多,希望在訪問同一個channel的畫素是連續的,一般儲存選用NCHW【參考連結:NCHW和NHWC】。而在tensorflow中使用CNN時我們需要將其轉化為[image_height,image_width,channel]的形式,這個時候我們可以使用tf.transpose函式,即tf.transpose(a.[1,2,0])
如果perm沒有給定,那麼預設是perm是=[n-1,n-2,n-3.....,0],其中rank(a)=n
對於二維輸入資料,預設perm就是常規的矩陣轉置操作。
tf.transpose的第二個引數perm=[0,1,2],0代表三維陣列的高(即為二維陣列的個數),1代表二維陣列的行,2代表二維陣列的列。
tf.transpose(x, perm=[1,0,2])代表將三位陣列的高和行進行轉置。
wiki文件如下:
tf.transpose(a, perm=None, name='transpose') Transposes a. Permutes the dimensions according to perm. The returned tensor's dimension i will correspond to the input dimension perm[i]. If perm is not given, it is set to (n-1...0), where n is the rank of the input tensor. Hence by default, this operation performs a regular matrix transpose on 2-D input Tensors. For example: # 'x' is [[1 2 3] # [4 5 6]] tf.transpose(x) ==> [[1 4] [2 5] [3 6]] # Equivalently tf.transpose(x perm=[1, 0]) ==> [[1 4] [2 5] [3 6]] # 'perm' is more useful for n-dimensional tensors, for n > 2 # 'x' is [[[1 2 3] # [4 5 6]] # [[7 8 9] # [10 11 12]]] # Take the transpose of the matrices in dimension-0 tf.transpose(b, perm=[0, 2, 1]) ==> [[[1 4] [2 5] [3 6]] [[7 10] [8 11] [9 12]]] Args: •a: A Tensor. •perm: A permutation of the dimensions of a. •name: A name for the operation (optional). Returns: A transposed Tensor.
測試程式碼如下:
import tensorflow as tf #x = tf.constant([[1, 2 ,3],[4, 5, 6]]) x = [[[1,2,3,4],[5,6,7,8],[9,10,11,12]],[[21,22,23,24],[25,26,27,28],[29,30,31,32]]] #a=tf.constant(x) a=tf.transpose(x, [0, 1, 2]) b=tf.transpose(x, [0, 2, 1]) c=tf.transpose(x, [1, 0, 2]) d=tf.transpose(x, [1, 2, 0]) e=tf.transpose(x, [2, 1, 0]) f=tf.transpose(x, [2, 0, 1]) # 'perm' is more useful for n-dimensional tensors, for n > 2 # 'x' is [[[1 2 3] # [4 5 6]] # [[7 8 9] # [10 11 12]]] # Take the transpose of the matrices in dimension-0 #tf.transpose(b, perm=[0, 2, 1]) with tf.Session() as sess: print ('---------------') print (sess.run(a)) print ('---------------') print (sess.run(b)) print ('---------------') print (sess.run(c)) print ('---------------') print (sess.run(d)) print ('---------------') print (sess.run(e)) print ('---------------') print (sess.run(f)) print ('---------------')
測試結果如下:
---------------
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]]
[[21 22 23 24]
[25 26 27 28]
[29 30 31 32]]]
---------------
[[[ 1 5 9]
[ 2 6 10]
[ 3 7 11]
[ 4 8 12]]
[[21 25 29]
[22 26 30]
[23 27 31]
[24 28 32]]]
---------------
[[[ 1 2 3 4]
[21 22 23 24]]
[[ 5 6 7 8]
[25 26 27 28]]
[[ 9 10 11 12]
[29 30 31 32]]]
---------------
[[[ 1 21]
[ 2 22]
[ 3 23]
[ 4 24]]
[[ 5 25]
[ 6 26]
[ 7 27]
[ 8 28]]
[[ 9 29]
[10 30]
[11 31]
[12 32]]]
---------------
[[[ 1 21]
[ 5 25]
[ 9 29]]
[[ 2 22]
[ 6 26]
[10 30]]
[[ 3 23]
[ 7 27]
[11 31]]
[[ 4 24]
[ 8 28]
[12 32]]]
---------------
[[[ 1 5 9]
[21 25 29]]
[[ 2 6 10]
[22 26 30]]
[[ 3 7 11]
[23 27 31]]
[[ 4 8 12]
[24 28 32]]]
---------------