tensorflow源碼解析之framework-tensor
目錄
- 核心概念
- tensor
- tensor_reference
- tensor_shape
- tensor_slice
- protos
1. 核心概念
TF的核心數據結構Tensor表示一個張量,它基於eigen3庫,並提供了豐富的API。為了方便引用張量的底層數據,設計了TensorInference類。TensorShape用於表示張量的形狀和數據類型等信息,TensorSlice用於表示張量的索引。
2. tensor
TF全稱叫做TensorFlow,可見tensor的重要性。TF中的tensor基於eigen3庫,是對多維數據的一個封裝。Tensor類包含的數據成員非常簡單:
class Tensor { //... private: TensorShape shape_; TensorBuffer* buffer_; }
顧名思義,一個是張量的形狀,一個是指向底層數據的指針。Tensor作為一個核心數據結構,必然提供了很多API接口,比如常規的構造、析構、賦值、復制、數值屬性獲取等。除此之外,還提供了兩類比較特殊的接口,我們舉例說明:
class Tensor { public: //... //與proto數據的相互轉化 bool FromProto(const TensorProto& other); void AsProtoField(TensorProto* proto); //為底層數據創建新視圖 template <typename T> typename TTypes<T>::Vec vec(); template <typename T> typename TTypes<T>::Matrix matrix(); template <typename T> typename TTypes<T, NDIMS>::Tensor tensor(); }
其中第一類將Tensor與序列化的proto之間相互轉化,在設備間相互傳遞Tensor時,需要先將其序列化。第二類是為當前的Tensor的底層數據提供另外一種視圖,我們重點來說一下視圖的概念。
回顧Tensor包含的私有數據,TensorBuffer* buffer_是一個指向底層數據的指針,關於它的結構在下文中會詳細說明。也就是說,Tensor並不包含實際的底層數據,它實際上只是對底層數據的一種視圖。同樣一份底層數據,可以提供多種視圖。比如對於一個長度為12的數組,如果把它看做向量,它是一個1x12的向量,如果把它看作矩陣,可以認為是3x4或者2x6的矩陣,如果把它當作張量,可以認為是3x2x2的張量。通過這種方法,我們可以對同一份底層數據進行復用,避免了重復申請內存空間,提升了效率。numpy中對多維數組的實現,也是同樣的道理。
接下來我們看一下TensorBuffer到底是什麽樣的結構。找到它的定義,發現它只是一個繼承自引用計數類的虛擬接口,不包含任何實現:
class TensorBuffer : public core::RefCounted {
//...
}
因此懷疑,TensorBuffer只是一個提供接口的基類,實際上能用的只是它的子類。我們看下它的繼承結構:
class BufferBase : public TensorBuffer {
//...
}
class Buffer : public BufferBase {
//...
private:
T* data_;
int64 elem_;
}
結構已經非常清晰了,BufferBase類繼承自TensorBuffer,它除了包含一個內存分配器指針外,還對基類中的部分API進行了實現。而Buffer類是實際可用的,它包含了指向實際數據的指針data_以及元素數量elem_。
另外還要說明一點,Buffer除了申請內存之外,還能調用目標類的構造和析構函數,初始化Buffer的內容,TF為此設計了很多輔助類和函數,這裏就不一一贅述了。
3. tensor_reference
Tensor類的對象除了包含指向底層數據的指針外,還包含了對數據形狀和類型的描述,如果我們並不關心這些,直接使用Tensor會增加構建或者移動的負擔。因此TF推出了tensor_reference這個類,它僅包含了一個指向TensorBuffer的指針,並且每增加一個TensorReference對象,就會增加一個針對底層TensorBuffer的引用計數。因此針對TensorReference來說,我們唯一能做的就是在用完之後Unref掉,否則會造成內存泄漏。
4. tensor_shape
TensorShape相關的核心類繼承體系如下:
graph LR
TensorShape-->TensorShapeBase
TensorShapeBase-->TensorShapeRep
首先來看一下,最底層的TensorShapeRep的私有數據成員:
class TensorShapeRep {
//...
private:
union {
uint8 buf[16];
Rep64* unused_aligner;//除了強制u_與指針對齊外,沒有任何作用
} u_;
int64 num_elements_;
}
}
buf這個數組很有意思,它的前12個元素用來存儲形狀,雖然Tensor最高能支持到256維的張量,但最常用的不超過3維,為了效率,TF提供了三種利用這12個字節的方式,如下:
struct Rep16 {
uint16 dims_[6];//最多可表示6維的張量,每一維的長度不超過2^16-1
};
struct Rep32 {
uint32 dims_[3];//最多可表示3維的張量,每一維的長度不超過2^32-1
};
struct Rep64 {
gtl::InlinedVector<int64, 4>* dims_;//支持任意維度的張量
};
剩下的4個字節也不能浪費,在第14-16個字節中,分別存儲了張量中的數據類型編號、張量的維度數目、張量維度的表示類型(Rep16, Rep32, Rep64)。由於張量維度的數目是用一個字節存儲的,因此最多支持256維。可惜筆者目前仍沒有發現第13個字節的作用,有發現的讀者歡迎告知我。
TensorShapeBase類並沒有添加額外的數據成員,它只是添加了一些允許我們修改張量維度的API接口。
最後再來看下PartialTensorShape類,在構造一個張量的形狀時,如果對於某些維度我們還不知道具體的維度值,可以把這個維度設為未知,因此就會用到PartialTensorShape類,這個類中也包含了一些未知維度操作的API,這裏就不詳述了。
5. tensor_slice
TensorSlice類表示一個張量的索引,它的數據結構非常簡單:
class TensorSlice {
//...
private:
gtl::InlinedVector<int64,4> starts_;
gtl::InlinedVector<int64,4> lengths_;
}
分別是每一個維度索引的開始位置和索引長度,由此我們也知道,TF對Tensor只支持連續索引,不支持間隔索引。
由於TensorSlice用途廣泛,對其進行初始化的方法也多種多樣,包括:
- 創建空索引
- 從單個維度創建(當創建全索引時)
- 從一個整數對數組創建
- 從一個TensorSliceProto創建
- 從一個字符串描述中創建
6. protos
為了方便對張量和與之相關的數據結構進行序列化,TF設計了很多protos,理解起來相對簡單,現只說明下它們的用途,感興趣的讀者可以去看源代碼。
message TensorDescription;//張量的描述,包括數據類型、形狀、內存分配信息
message TensorProto;//張量的數據類型,版本,原始數據等
message VariantTensorDataProto;//對DT_VARIANT類型的序列化表示
message TensorShapeProto;//張量形狀
message TensorSliceProto;//張量索引
tensorflow源碼解析之framework-tensor