1. 程式人生 > >tensorflow相關tensor計算函式

tensorflow相關tensor計算函式

1. tf.split

該函式主要用於對tensor進行分割,一般在設定多GPU平行計算時經常會被用到,主要是將一個batch資料集進行平分,分配給各個GPU,最後再彙總各個GPU得到的損失,從而加快模型的訓練速度,其主要引數的定義如下:

  • value:待分割的 `Tensor` .
  • num_or_size_splits: 可以是一個整數,表示分割的後的數量,也可以是一個整數列表,表示分割後每一份的size
  • axis:分割的維度,預設的第一維 
import tensorflow as tf

tf.split(
    value, 
    num_or_size_splits, 
    axis=0, 
    num=None, 
    name="split"
)

2. tf.add_n

該函式主要是對輸入的tensor列表中每一個tensor進行加總,要求每個tensor的維度必須相同,當開啟平行計算時,該函式也經常被用來計算各個GPU得到的損失,其主要引數定義如下:

  • inputs:一個tensor列表
import tensorflow as tf

tf.add_n(inputs, name=None)