1. 程式人生 > >Tensorflow原始碼解析2 -- 前後端連線的橋樑

Tensorflow原始碼解析2 -- 前後端連線的橋樑

1 Session概述

Session是TensorFlow前後端連線的橋樑。使用者利用session使得client能夠與master的執行引擎建立連線,並通過session.run()來觸發一次計算。它建立了一套上下文環境,封裝了operation計算以及tensor求值的環境。

session建立時,系統會分配一些資源,比如graph引用、要連線的計算引擎的名稱等。故計算完畢後,需要使用session.close()關閉session,避免引起記憶體洩漏,特別是graph無法釋放的問題。可以顯式呼叫session.close(),或利用with上下文管理器,或者直接使用InteractiveSession。

session之間採用共享graph的方式來提高執行效率。一個session只能執行一個graph例項,但一個graph可以執行在多個session中。一般情況下,建立session時如果不指定Graph例項,則會使用系統預設Graph。常見情況下,我們都是使用一個graph,即預設graph。當session建立時,不會重新建立graph例項,而是預設graph引用計數加1。當session close時,引用計數減1。只有引用計數為0時,graph才會被回收。這種graph共享的方式,大大減少了graph建立和回收的資源消耗,優化了TensorFlow執行效率。

2 預設session

op運算和tensor求值時,如果沒有指定執行在哪個session中,則會執行在預設session中。通過session.as_default()可以將自己設定為預設session。但個人建議最好還是通過session.run(operator)和session.run(tensor)來進行op運算和tensor求值。

operation.run()

operation.run()等價於tf.get_default_session().run(operation)

@tf_export("Operation")
class Operation(object):
    # 通過operation.run()呼叫,進行operation計算
    def run(self, feed_dict=None, session=None):
        _run_using_default_session(self, feed_dict, self.graph, session)
        
    def
_run_using_default_session(operation, feed_dict, graph, session=None): # 沒有指定session,則獲取預設session if session is None: session = get_default_session() # 最終還是通過session.run()進行執行的。tf中任何運算,都是通過session來run的。 # 通過session來建立client和master的連線,並將graph傳送給master,master再進行執行 session.run(operation, feed_dict)

tensor.eval()

tensor.eval()等價於tf.get_default_session().run(tensor), 如下

@tf_export("Tensor")
class Tensor(_TensorLike):
    # 通過tensor.eval()呼叫,進行tensor運算
	def eval(self, feed_dict=None, session=None):
		return _eval_using_default_session(self, feed_dict, self.graph, session)
		
	def _eval_using_default_session(tensors, feed_dict, graph, session=None):
           # 如果沒有指定session,則獲取預設session
		  if session is None:
    		session = get_default_session()
  		  return session.run(tensors, feed_dict)

預設session的管理

tf通過執行時維護的session本地執行緒棧,來管理預設session。故不同的執行緒會有不同的預設session,預設session是執行緒作用域的。

# session棧
_default_session_stack = _DefaultStack()

# 獲取預設session的介面
@tf_export("get_default_session")
def get_default_session():
  return _default_session_stack.get_default()

# _DefaultStack預設session棧是執行緒相關的
class _DefaultStack(threading.local):
      # 預設session棧的建立,其實就是一個list
      def __init__(self):
    	super(_DefaultStack, self).__init__()
    	self._enforce_nesting = True
    	self.stack = []  
        
      # 獲取預設session
      def get_default(self):
    	return self.stack[-1] if len(self.stack) >= 1 else None

3 前端Session型別

session類圖

會話Session的UML類圖如下

分為兩種型別,普通Session和互動式InteractiveSession。InteractiveSession和Session基本相同,區別在於

  1. InteractiveSession建立後,會將自己替換為預設session。使得之後operation.run()和tensor.eval()的執行通過這個預設session來進行。特別適合Python互動式環境。
  2. InteractiveSession自帶with上下文管理器。它在建立時和關閉時會呼叫上下文管理器的enter和exit方法,從而進行資源的申請和釋放,避免記憶體洩漏問題。這同樣很適合Python互動式環境。

Session和InteractiveSession的程式碼邏輯不多,主要邏輯均在其父類BaseSession中。主要程式碼如下

@tf_export('Session')
class Session(BaseSession):
    def __init__(self, target='', graph=None, config=None):
        # session建立的主要邏輯都在其父類BaseSession中
        super(Session, self).__init__(target, graph, config=config)
    	self._default_graph_context_manager = None
    	self._default_session_context_manager = None
@tf_export('InteractiveSession')
class InteractiveSession(BaseSession):
    def __init__(self, target='', graph=None, config=None):
        self._explicitly_closed = False
        
        # 將自己設定為default session
    	self._default_session = self.as_default()
    	self._default_session.enforce_nesting = False
        
        # 自動呼叫上下文管理器的__enter__()方法
    	self._default_session.__enter__()
    	self._explicit_graph = graph
       
    def close(self):
    	super(InteractiveSession, self).close()
    	## 省略無關程式碼
        ## 自動呼叫上下文管理器的__exit__()方法,避免記憶體洩漏
        self._default_session.__exit__(None, None, None)
        self._default_session = None

BaseSession

BaseSession基本包含了所有的會話實現邏輯。包括會話的整個生命週期,也就是建立 執行 關閉和銷燬四個階段。生命週期後面詳細分析。BaseSession包含的主要成員變數有graph引用,序列化的graph_def, 要連線的tf引擎target,session配置資訊config等。

4 後端Session型別

在後端master中,根據前端client呼叫tf.Session(target=’’, graph=None, config=None)時指定的target,來建立不同的Session。target為要連線的tf後端執行引擎,預設為空字串。Session建立採用了抽象工廠模式,如果為空字串,則建立本地DirectSession,如果以grpc://開頭,則建立分散式GrpcSession。類圖如下

DirectSession只能利用本地裝置,將任務建立到本地的CPU GPU上。而GrpcSession則可以利用遠端分散式裝置,將任務建立到其他機器的CPU GPU上,然後通過grpc協議進行通訊。grpc協議是谷歌發明並開源的遠端通訊協議。

5 Session生命週期

Session作為前後端連線的橋樑,以及上下文執行環境,其生命週期尤其關鍵。大致分為4個階段

  1. 建立:通過tf.Session()建立session例項,進行系統資源分配,特別是graph引用計數加1
  2. 執行:通過session.run()觸發計算的執行,client會將整圖graph傳遞給master,由master進行執行
  3. 關閉:通過session.close()來關閉,會進行系統資源的回收,特別是graph引用計數減1.
  4. 銷燬:Python垃圾回收器進行GC時,呼叫session.__del__()進行回收。

生命週期方法入口基本都在前端Python的BaseSession中,它會通過swig自動生成的函式符號對映關係,呼叫C層的實現。

5.1 建立

先從BaseSession類的init方法看起,只保留了主要程式碼。

def __init__(self, target='', graph=None, config=None):
    # graph表示構建的圖。TensorFlow的一個session會對應一個圖。這個圖包含了所有涉及到的運算元
    # graph如果沒有設定(通常都不會設定),則使用預設graph
    if graph is None:
      self._graph = ops.get_default_graph()
    else:
      self._graph = graph

    self._opened = False
    self._closed = False

    self._current_version = 0
    self._extend_lock = threading.Lock()

    # target為要連線的tf執行引擎
    if target is not None:
      self._target = compat.as_bytes(target)
    else:
      self._target = None

    self._delete_lock = threading.Lock()
    self._dead_handles = []

    # config為session的配置資訊
    if config is not None:
      self._config = config
      self._add_shapes = config.graph_options.infer_shapes
    else:
      self._config = None
      self._add_shapes = False

    self._created_with_new_api = ops._USE_C_API

    # 呼叫C層來建立session
    self._session = None
    opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
    self._session = tf_session.TF_NewSession(self._graph._c_graph, opts, status)
        

BaseSession先進行成員變數的賦值,然後呼叫TF_NewSession來建立session。TF_NewSession()方法由swig自動生成,在bazel-bin/tensorflow/python/pywrap_tensorflow_internal.py中

def TF_NewSession(graph, opts, status):
    return _pywrap_tensorflow_internal.TF_NewSession(graph, opts, status)

_pywrap_tensorflow_internal包含了C層函式的符號表。在swig模組import時,會載入pywrap_tensorflow_internal.so動態連結庫,從而得到符號表。在pywrap_tensorflow_internal.cc中,註冊了供Python呼叫的函式的符號表,從而實現Python到C的函式對映和呼叫。

// c++函式呼叫的符號表,Python通過它可以呼叫到C層程式碼。符號表和動態連結庫由swig自動生成
static PyMethodDef SwigMethods[] = {
      // .. 省略其他函式定義
	 // TF_NewSession的符號表,通過這個對映,Python中就可以呼叫到C層程式碼了。
	 { (char *)"TF_NewSession", _wrap_TF_NewSession, METH_VARARGS, NULL},
     // ... 省略其他函式定義
}

最終呼叫到c_api.c中的TF_NewSession()

// TF_NewSession建立session的新實現,在C層後端程式碼中
TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
                          TF_Status* status) {
  Session* session;
  // 建立session
  status->status = NewSession(opt->options, &session);
  if (status->status.ok()) {
    TF_Session* new_session = new TF_Session(session, graph);
    if (graph != nullptr) {
      // 採用了引用計數方式,多個session共享一個圖例項,效率更高。
      // session建立時,引用計數加1。session close時引用計數減1。引用計數為0時,graph才會被回收。
      mutex_lock l(graph->mu);
      graph->sessions[new_session] = Status::OK();
    }
    return new_session;
  } else {
    DCHECK_EQ(nullptr, session);
    return nullptr;
  }
}

session建立時,並建立graph,而是採用共享方式,只是引用計數加1了。這種方式減少了session建立和關閉時的資源消耗,提高了執行效率。NewSession()根據前端傳遞的target,使用sessionFactory建立對應的TensorFlow::Session例項。

Status NewSession(const SessionOptions& options, Session** out_session) {
  SessionFactory* factory;
  const Status s = SessionFactory::GetFactory(options, &factory);

  // 通過sessionFactory建立多型的Session。本地session為DirectSession,分散式為GRPCSession
  *out_session = factory->NewSession(options);
  if (!*out_session) {
    return errors::Internal("Failed to create session.");
  }
  return Status::OK();
}

建立session採用了抽象工廠模式。根據client傳遞的target,來建立不同的session。如果target為空字串,則建立本地DirectSession。如果以grpc://開頭,則建立分散式GrpcSession。TensorFlow包含本地執行時和分散式執行時兩種執行模式。

下面來看DirectSessionFactory的NewSession()方法

class DirectSessionFactory : public SessionFactory {
 public:
  Session* NewSession(const SessionOptions& options) override {
    std::vector<Device*> devices;
      
    // job在本地執行
    const Status s = DeviceFactory::AddDevices(
        options, "/job:localhost/replica:0/task:0", &devices);
    if (!s.ok()) {
      LOG(ERROR) << s;
      return nullptr;
    }

    DirectSession* session =
        new DirectSession(options, new DeviceMgr(devices), this);
    {
      mutex_lock l(sessions_lock_);
      sessions_.push_back(session);
    }
    return session;
  }

GrpcSessionFactory的NewSession()方法就不詳細分析了,它會將job任務建立在分散式裝置上,各job通過grpc協議通訊。

5.2 執行

通過session.run()可以啟動graph的執行。入口在BaseSession的run()方法中, 同樣只列出關鍵程式碼

class BaseSession(SessionInterface):
    def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
        # fetches可以為單個變數,或者陣列,或者元組。它是圖的一部分,可以是操作operation,也可以是資料tensor,或者他們的名字String
        # feed_dict為對應placeholder的實際訓練資料,它的型別為字典
        result = self._run(None, fetches, feed_dict, options_ptr,run_metadata_ptr)
        return result
    
    def _run(self, handle, fetches, feed_dict, options, run_metadata):
    	# 建立fetch處理器fetch_handler
        fetch_handler = _FetchHandler(
            self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)

        # 經過不同型別的fetch_handler處理,得到最終的fetches和targets
        # targets為要執行的operation,fetches為要執行的tensor
        _ = self._update_with_movers(feed_dict_tensor, feed_map)
        final_fetches = fetch_handler.fetches()
        final_targets = fetch_handler.targets()

        # 開始執行
        if final_fetches or final_targets or (handle and feed_dict_tensor):
          results = self._do_run(handle, final_targets, final_fetches,
                                 feed_dict_tensor, options, run_metadata)
        else:
          results = []

        # 輸出結果到results中
        return fetch_handler.build_results(self, results)

    def _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata):
        # 將要執行的operation新增到graph中
        self._extend_graph()
        
        # 執行一次執行run,會呼叫底層C來實現
        return tf_session.TF_SessionPRunSetup_wrapper(
              session, feed_list, fetch_list, target_list, status)
      
    # 將要執行的operation新增到graph中
    def _extend_graph(self):
        with self._extend_lock:
          if self._graph.version > self._current_version:
            # 生成graph_def物件,它是graph的序列化表示
            graph_def, self._current_version = self._graph._as_graph_def(
                from_version=self._current_version, add_shapes=self._add_shapes)

            # 通過TF_ExtendGraph將序列化後的graph,也就是graph_def傳遞給後端
            with errors.raise_exception_on_not_ok_status() as status:
              tf_session.TF_ExtendGraph(self._session,
                                        graph_def.SerializeToString(), status)
            self._opened = True

邏輯還是十分複雜的,主要有一下幾步

  1. 入參處理,建立fetch處理器fetch_handler,得到最終要執行的operation和tensor
  2. 對graph進行序列化,生成graph_def物件
  3. 將序列化後的grap_def物件傳遞給後端master。
  4. 通過後端master來run。

我們分別來看extend和run。

5.2.1 extend新增節點到graph中

TF_ExtendGraph()會呼叫到c_api中,這個邏輯同樣通過swig工具自動生成。下面看c_api.cc中的TF_ExtendGraph()方法

// 增加節點到graph中,proto為序列化後的graph
void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
                    size_t proto_len, TF_Status* status) {
  GraphDef g;
  // 先將proto反序列化,得到client傳遞的graph,放入g中
  if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
    status->status = InvalidArgument("Invalid GraphDef");
    return;
  }

  // 再呼叫session的extend方法。根據建立的不同session型別,多型呼叫不同方法。
  status->status = s->session->Extend(g);
}

後端系統根據生成的Session型別,多型的呼叫Extend方法。如果是本地session,則呼叫DirectSession的Extend()方法。如果是分散式session,則呼叫GrpcSession的相關方法。下面來看GrpcSession的Extend方法。

Status GrpcSession::Extend(const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return ExtendImpl(&call_options, graph);
}

Status GrpcSession::ExtendImpl(CallOptions* call_options,
                               const GraphDef& graph) {
  bool handle_is_empty;
  {
    mutex_lock l(mu_);
    handle_is_empty = handle_.empty();
  }
  if (handle_is_empty) {
    // 如果graph控制代碼為空,則表明graph還沒有建立好,此時extend就等同於create
    return Create(graph);
  }
  mutex_lock l(mu_);
  ExtendSessionRequest req;
  req.set_session_handle(handle_);
  *req.mutable_graph_def() = graph;
  req.set_current_graph_version(current_graph_version_);
  ExtendSessionResponse resp;
    
  // 呼叫底層實現,來新增節點到graph中
  Status s = master_->ExtendSession(call_options, &req, &resp);
  if (s.ok()) {
    current_graph_version_ = resp.new_graph_version();
  }
  return s;
}

Extend()方法中要注意的一點是,如果是首次執行Extend(), 則要先呼叫Create()方法進行graph的註冊。否則才是執行新增節點到graph中。

5.2.2 run執行圖的計算

同樣,Python通過swig自動生成的程式碼,來實現對C API的呼叫。C層實現在c_api.cc的TF_Run()中。

// session.run()的C層實現
void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
            // Input tensors,輸入的資料tensor
            const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
            // Output tensors,執行計算後輸出的資料tensor
            const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
            // Target nodes,要執行的節點
            const char** c_target_oper_names, int ntargets,
            TF_Buffer* run_metadata, TF_Status* status) {
  // 省略一段程式碼
  TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
                c_outputs, target_oper_names, run_metadata, status);
}

// 真正的實現了session.run()
static void TF_Run_Helper() {
    RunMetadata run_metadata_proto;
    // 呼叫不同的session實現類的run方法,來執行
    result = session->Run(run_options_proto, input_pairs, output_tensor_names,
                          target_oper_names, &outputs, &run_metadata_proto);
    // 省略程式碼
}

最終會呼叫建立的session來執行run方法。DirectSession和GrpcSession的Run()方法會有所不同。後面很複雜,就不接著分析了。

5.3 關閉session

通過session.close()來關閉session,釋放相關資源,防止記憶體洩漏。

class BaseSession(SessionInterface):
	def close(self):
        tf_session.TF_CloseSession(self._session, status)

會呼叫到C API的TF_CloseSession()方法。

void TF_CloseSession(TF_Session* s, TF_Status* status) {
  status->status = s->session->Close();
}

最終根據建立的session,多型的呼叫其Close()方法。同樣分為DirectSession和GrpcSession兩種。

::tensorflow::Status DirectSession::Close() {
    cancellation_manager_->StartCancel();
  {
    mutex_lock l(closed_lock_);
    if (closed_) return ::tensorflow::Status::OK();
    closed_ = true;
  }
    
  // 登出session
  if (factory_ != nullptr) factory_->Deregister(this);
  return ::tensorflow::Status::OK();
}

DirectSessionFactory中的Deregister()方法如下

void Deregister(const DirectSession* session) {
    mutex_lock l(sessions_lock_);
    // 釋放相關資源
    sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
                    sessions_.end());
  }

5.4 銷燬session

session的銷燬是由Python的GC自動執行的。python通過引用計數方法來判斷是否回收物件。當物件的引用計數為0,且虛擬機器觸發了GC時,會呼叫物件的__del__()