1. 程式人生 > >Caffe2原始碼理解系列之IO

Caffe2原始碼理解系列之IO

Caffe2 IO

本文主要記錄下我對Caffe2的輸入輸出部分原始碼的理解。資料是以什麼樣的形式輸入進網路的,訓練過程中如何儲存網路模型。與資料輸入相關的Operator是DBReader, ImageInputOp, 與儲存訓練過程中儲存模型相關資訊的是SaveOp, LoadOp,以及一系列與序列化相關的工具類,比如BlobSerializer。下面分別介紹一下,如有理解錯誤,歡迎指出。PS,Caffe2的程式碼寫得真心贊啊。

  • DBReader
  • ImageInputOp
  • SaveOp
  • LoadOp
  • 總結

DBReader

如同Caffe1一樣,一般情況下,在進行模型訓練的時候,Caffe2也需要事先將資料轉成特定格式的資料庫,比如lmdb, leveldb。只不過Caffe2支援的資料庫格式更加豐富,除了上述兩種格式的db外,還有minidb, zmqdb, protodb, rocksdb等等。Caffe2中對lmdb的實現跟Caffe1有所不同,但功能是一樣的。PS,個人以為Caffe1中的實現要優雅些,因為我直接在windows上用Caffe2自帶的lmdb.cc來生成資料庫時執行不通過,直接改成Caffe1中的就OK了。另外由於Caffe2在預設儲存模型時候使用的是minidb, 所以簡單地介紹下minidb。

DBReader封裝瞭如何讀取資料庫的操作。注意在單機多GPU情況下DBReader只有一個例項,為各個GPU共享。在多機的情況下,每臺機器有一個DBReader例項,通過DBReader中的成員變數shard_id_來標識該節點負責讀取哪一部分的資料庫。通常,每一臺機器都會有一份完整的相同的資料庫,當然也可以通過nfs將資料庫從一臺機器對映給其他機器。讀取同一個資料庫的時候。DBReader自動會對資料進行切片,保證每個節點的每個GPU讀取資料庫的不同部分,以此達到資料並行。DBReader的摘要如下:

class DBReader {
...
private:
  string db_type_; //資料庫的型別,包括minidb,leveldb,lmdb等等
string source_; //資料庫的路徑 unique_ptr<DB> db_; //資料庫物件 unique_ptr<Cursor> cursor_; //資料庫遊標 mutable std::mutex reader_mutex_;//單機多GPU環境下,應該是多執行緒進行訓練,多執行緒共享同一個DBReader例項,因此需要用這個reader_mutex來控制對共享變數的訪問。 uint32_t num_shards_; //單機環境下,該值為0,分散式環境下,該值為節點數目。 uint32_t shard_id_; //節點id,從0開始,單機情況下為0,依次遞增,
DISABLE_COPY_AND_ASSIGN(DBReader); public: void Open(const string& db_type, const string& source, const int32_t num_shards = 1, const int32_t shard_id = 0) { //開啟資料庫,該函式會在建構函式裡被呼叫 cursor_.reset(); db_.reset(); db_type_ = db_type; source_ = source; db_ = CreateDB(db_type_, source_, READ); CAFFE_ENFORCE(db_, "Cannot open db: ", source_, " of type ", db_type_); InitializeCursor(num_shards, shard_id); } // for i = 0: batch_size, call Read void Read(string* key, string* value) const { CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized."); std::unique_lock<std::mutex> mutex_lock(reader_mutex_);//這裡注意,只對單機多GPU會阻塞,不同機器之間不會阻塞,因為是不同的DBReader例項,多機通訊會通過rendezvous進行同步,比如redis _store_handler等。 *key = cursor_->key(); *value = cursor_->value(); // 在分散式環境下,由於一次有num_shards臺機器參與讀取資料,因此一次計算讀取的資料量有num_shards * 每臺機器讀取的資料量,所以對於每一臺機器而言,這裡要跳過num_shards個記錄,才是它下一次迭代應該讀取的資料庫位置 for (int s = 0; s < num_shards_; s++) { cursor_->Next(); if (!cursor_->Valid()) { MoveToBeginning(); break; } } } ... };

DB, Transaction, Cursor三個介面類定義瞭如何操作資料庫。對於不同型別的資料庫,會有相應的實現,比如針對lmdb,就有LMDB, LMDBTransaction, LMDBCursor,針對minidb,就有MiniDB, MiniDBTransaction, MiniDBCursor。從Caffe2中實現的lmdb,minidb, leveldb來看,讀資料庫只支援順序讀取,即cursor從頭到尾順序訪問資料庫,當訪問到資料庫末尾時候,cursor又從頭開始,因此並不支援對資料庫的隨機訪問。DB的摘要如下:

class DB {
 public:
  DB(const string& /*source*/, Mode mode) : mode_(mode) {}
  virtual ~DB() { }
  /**
   * Closes the database.
   */
  virtual void Close() = 0;
  /**
   * Returns a cursor to read the database. The caller takes the ownership of
   * the pointer.
   */
  virtual std::unique_ptr<Cursor> NewCursor() = 0;
  /**
   * Returns a transaction to write data to the database. The caller takes the
   * ownership of the pointer.
   */
  virtual std::unique_ptr<Transaction> NewTransaction() = 0;

 protected:
  Mode mode_; //這個mode定義為enum Mode { READ, WRITE, NEW };

  DISABLE_COPY_AND_ASSIGN(DB);
};

minidb相關操作

minidb其實就是簡單地封裝了C語言中的檔案IO呼叫, 沒啥特別之處,直接把caffe2/core/db.cc中的程式碼貼出來。因為有這個minidb的存在,因此Caffe2就不像Caffe1中有辣麼多依賴軟體了。lmdb和leveldb對Caffe2來說就是可選的了。不過,minidb的功能肯定不如lmdb了(個人猜測,minidb的讀寫效率啊,估計也沒有lmdb高)。


class MiniDBCursor : public Cursor {
 public:
  explicit MiniDBCursor(FILE* f, std::mutex* mutex)
    : file_(f), lock_(*mutex), valid_(true) {
    // We call Next() to read in the first entry.
    Next();
  }
  ~MiniDBCursor() {}

  void Seek(const string& /*key*/) override {
    LOG(FATAL) << "MiniDB does not support seeking to a specific key.";
  }

  void SeekToFirst() override {
    fseek(file_, 0, SEEK_SET);
    CAFFE_ENFORCE(!feof(file_), "Hmm, empty file?");
    // Read the first item.
    valid_ = true;
    Next();
  }

  void Next() override {
    // First, read in the key and value length.
    if (fread(&key_len_, sizeof(int), 1, file_) == 0) {
      // Reaching EOF.
      VLOG(1) << "EOF reached, setting valid to false";
      valid_ = false;
      return;
    }
    CAFFE_ENFORCE_EQ(fread(&value_len_, sizeof(int), 1, file_), 1);
    CAFFE_ENFORCE_GT(key_len_, 0);
    CAFFE_ENFORCE_GT(value_len_, 0);
    // Resize if the key and value len is larger than the current one.
    if (key_len_ > key_.size()) {
      key_.resize(key_len_);
    }
    if (value_len_ > value_.size()) {
      value_.resize(value_len_);
    }
    // Actually read in the contents.
    CAFFE_ENFORCE_EQ(
        fread(key_.data(), sizeof(char), key_len_, file_), key_len_);
    CAFFE_ENFORCE_EQ(
        fread(value_.data(), sizeof(char), value_len_, file_), value_len_);
    // Note(Yangqing): as we read the file, the cursor naturally moves to the
    // beginning of the next entry.
  }

  string key() override {
    CAFFE_ENFORCE(valid_, "Cursor is at invalid location!");
    return string(key_.data(), key_len_);
  }

  string value() override {
    CAFFE_ENFORCE(valid_, "Cursor is at invalid location!");
    return string(value_.data(), value_len_);
  }

  bool Valid() override { return valid_; }

 private:
  FILE* file_;
  std::lock_guard<std::mutex> lock_;
  bool valid_;
  int key_len_;
  vector<char> key_;
  int value_len_;
  vector<char> value_;
};

class MiniDBTransaction : public Transaction {
 public:
  explicit MiniDBTransaction(FILE* f, std::mutex* mutex)
    : file_(f), lock_(*mutex) {}
  ~MiniDBTransaction() {
    Commit();
  }

  void Put(const string& key, const string& value) override {
    int key_len = key.size();
    int value_len = value.size();
    CAFFE_ENFORCE_EQ(fwrite(&key_len, sizeof(int), 1, file_), 1);
    CAFFE_ENFORCE_EQ(fwrite(&value_len, sizeof(int), 1, file_), 1);
    CAFFE_ENFORCE_EQ(
        fwrite(key.c_str(), sizeof(char), key_len, file_), key_len);
    CAFFE_ENFORCE_EQ(
        fwrite(value.c_str(), sizeof(char), value_len, file_), value_len);
  }

  void Commit() override {
    if (file_ != nullptr) {
      CAFFE_ENFORCE_EQ(fflush(file_), 0);
      file_ = nullptr;
    }
  }

 private:
  FILE* file_;
  std::lock_guard<std::mutex> lock_;

  DISABLE_COPY_AND_ASSIGN(MiniDBTransaction);
};

class MiniDB : public DB {
 public:
  MiniDB(const string& source, Mode mode) : DB(source, mode), file_(nullptr) {
    switch (mode) {
      case NEW:
        file_ = fopen(source.c_str(), "wb");
        break;
      case WRITE:
        file_ = fopen(source.c_str(), "ab");
        fseek(file_, 0, SEEK_END);
        break;
      case READ:
        file_ = fopen(source.c_str(), "rb");
        break;
    }
    CAFFE_ENFORCE(file_, "Cannot open file: " + source);
    VLOG(1) << "Opened MiniDB " << source;
  }
  ~MiniDB() { Close(); }

  void Close() override {
    if (file_) {
      fclose(file_);
    }
    file_ = nullptr;
  }

  unique_ptr<Cursor> NewCursor() override {
    CAFFE_ENFORCE_EQ(this->mode_, READ);
    return make_unique<MiniDBCursor>(file_, &file_access_mutex_);
  }

  unique_ptr<Transaction> NewTransaction() override {
    CAFFE_ENFORCE(this->mode_ == NEW || this->mode_ == WRITE);
    return make_unique<MiniDBTransaction>(file_, &file_access_mutex_);
  }

 private:
  FILE* file_;
  // access mutex makes sure we don't have multiple cursors/transactions
  // reading the same file.
  std::mutex file_access_mutex_;
};

ImageInputOp

編譯這個ImageInputOp需要opencv的支援。這個operator就是真正把資料庫中儲存的資料轉換成CNN訓練用的圖片了。它就類似於Caffe1中的BasePrefetchingDataLayer,但ImageInputOp的功能比BasePrefetchDataLayer強大得多。除了支援像BasePrefetchDataLayer那樣的隨機裁剪,映象,resize圖片大小之外,還支援更加豐富的Data Augmentation, 比如顏色擾動,對比度,飽和度等,googlenet和resnet中做的資料增廣都已經實現了。另一個顯著的地方是,ImageInputOp除了支援單標籤外,也支援多標籤。ImageInputOp的輸出資料格式是NHWC的形式,雖然Caffe2支援NHWC,NCHW兩種資料格式,它預設支援的資料格式Caffe1的資料格式,即NCHW。預設情況下,當用python來訓練時,呼叫ImageInput時,arg_scope的order預設情況是NCHW, Caffe2的python介面會自動新增NHWC2NCHWOp進行資料排布轉換。

ImageInputOp是一個典型的單生產者,單消費者,只有一個緩衝區容量的非同步OP。對於batchsize個樣本的解碼進行資料增廣操作又是多執行緒並行的。下面分別介紹一下:

生產者消費者模式體現在ImageInputOp的父類PrefetchOp中。

消費者

//每次前傳時候,會呼叫這個Run方法,通知生產者進行生產資料。但這裡為啥沒有一個如同Caffe1一樣設定一個大小為PREFETCH_COUNT容量的緩衝區,讓生產者不停生產,緩衝區滿了後再停止,而是每次都現消費現來生產了?想不通哈,請大神幫助解釋一下。 PS,不過我自己在訓練的時候,並沒法先卡在IO,估計這個IO過程很快吧。
  bool Run(int /* unused */ /*stream_id*/) override { 
    if (!prefetch_thread_) {
      prefetch_thread_.reset(new std::thread([this] { this->PrefetchWorker(); }));
    }
    context_.SwitchToDevice(0);
    std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
    while (!prefetched_)
      consumer_.wait(lock);
    if (!prefetch_success_) {
      LOG(ERROR) << "Prefetching failed.";
      return false;
    }
    if (!CopyPrefetched()) { //CopyPrefetched表示消費,正確返回就表示消費完了,然後就通知生產者繼續生產。
      LOG(ERROR) << "Error when copying prefetched data.";
      return false;
    }
    prefetched_ = false;
    context_.FinishDeviceComputation();
    producer_.notify_one();
    return true;
  }

生產者

  void PrefetchWorker() {
    context_.SwitchToDevice();
    std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
    while (prefetched_)
      producer_.wait(lock);
    while (!finalize_) {
      // We will need to run a FinishDeviceComputation() call because the
      // prefetcher thread and the main thread are potentially using different
      // streams (like on GPU).
      try {
        prefetch_success_ = Prefetch();//Prefetch就代表生產資料了,它是個虛擬函式,ImageInputOp會實現之。
        context_.FinishDeviceComputation();
      } catch (const std::exception& e) {
        // TODO: propagate exception_ptr to the caller side
        LOG(ERROR) << "Prefetching error " << e.what();
        prefetch_success_ = false;
      }
      prefetched_ = true;
      consumer_.notify_one();
      while (prefetched_) //我理解的單生產者單消費者就在此,Pretch完,就等待消費者消費,直到消費完了,難道說,可以在Pretch中擴大緩衝區麼?
        producer_.wait(lock);
    }
  }

在ImageInputOp中的多執行緒解碼轉換資料部分,就體現在成員變數thread_pool_了,它是個執行緒池TaskThreadPool的智慧指標。下面是對解碼部分的一個簡單說明

  for (int item_id = 0; item_id < batch_size_; ++item_id) { 
  .....
  //先做一些必要的準備操作
    thread_pool_->runTaskWithID(std::bind(&ImageInputOp<Context>::DecodeAndTransform,
            this, std::string(value), image_data, item_id, channels, std::placeholders::_1));//往執行緒池裡面新增任務,啟動執行緒來計算。
  }
  thread_pool_->waitWorkComplete();//等待解碼完成
  ...

來看看Caffe2中這個執行緒池是咋實現的吧,直接複製之。以前聽說過執行緒池,但從未實現過,看看這程式碼,也學到不少東西。

class TaskThreadPool{
 private:
    struct task_element_t {
        bool run_with_id;
        const std::function< void() > no_id;
        const std::function< void(std::size_t) > with_id;

        explicit task_element_t(const std::function< void() >& f) :
            run_with_id(false), no_id(f), with_id(nullptr) { }
        explicit task_element_t(const std::function< void(std::size_t) >& f) :
            run_with_id(true), no_id(nullptr), with_id(f) { }
    };
    std::queue<task_element_t> tasks_;
    std::vector<std::thread> threads_;
    std::mutex mutex_;
    std::condition_variable condition_;
    std::condition_variable completed_;
    bool running_;
    bool complete_;
    std::size_t available_;
    std::size_t total_;

 public:
    /// @brief Constructor.
    explicit TaskThreadPool(std::size_t pool_size)
        :  threads_(pool_size), running_(true), complete_(true),
           available_(pool_size), total_(pool_size) {
        for ( std::size_t i = 0; i < pool_size; ++i ) {//執行緒池裡面共有pool_size個工作執行緒在等待tasks_中的任務
            threads_[i] = std::thread(
                std::bind(&TaskThreadPool::main_loop, this, i));
        }
    }

    /// @brief Destructor.
    ~TaskThreadPool() {
        // Set running flag to false then notify all threads.
        {
            std::unique_lock< std::mutex > lock(mutex_);
            running_ = false;
            condition_.notify_all();
        }

        try {
            for (auto& t : threads_) {
              t.join();
            }
        }
        // Suppress all exceptions.
        catch (const std::exception&) {}
    }

    /// @brief Add task to the thread pool if a thread is currently available.
    template <typename Task>
    void runTask(Task task) {
        std::unique_lock<std::mutex> lock(mutex_);

        // Set task and signal condition variable so that a worker thread will
        // wake up and use the task.
        tasks_.push(task_element_t(static_cast<std::function< void() >>(task)));
        complete_ = false;
        condition_.notify_one();
    }

    template <typename Task>
    void runTaskWithID(Task task) {
      std::unique_lock<std::mutex> lock(mutex_);

      // Set task and signal condition variable so that a worker thread will
      // wake up and use the task.
      tasks_.push(task_element_t(static_cast<std::function< void(std::size_t) >>(
                                   task)));
      complete_ = false;
      condition_.notify_one();
    }

    /// @brief Wait for queue to be empty
    void waitWorkComplete() {
        std::unique_lock<std::mutex> lock(mutex_);
        while (!complete_)
          completed_.wait(lock);
    }

 private:
    /// @brief Entry point for pool threads.
    void main_loop(std::size_t index) {
        while (running_) {
            // Wait on condition variable while the task is empty and
            // the pool is still running.
            std::unique_lock<std::mutex> lock(mutex_);
            while (tasks_.empty() && running_) {
                condition_.wait(lock);
            }
            // If pool is no longer running, break out of loop.
            if (!running_) break;

            // Copy task locally and remove from the queue.  This is
            // done within its own scope so that the task object is
            // destructed immediately after running the task.  This is
            // useful in the event that the function contains
            // shared_ptr arguments bound via bind.
            {
                auto tasks = tasks_.front();
                tasks_.pop();
                // Decrement count, indicating thread is no longer available.
                --available_;

                lock.unlock();//由於mutex已經被鎖住了,需要釋放之,以讓其他執行緒能夠獲得任務,不然執行緒就串行了,無法並行。

                // Run the task.
                try {
                  if (tasks.run_with_id) {
                      tasks.with_id(index);
                  } else {
                      tasks.no_id();
                  }
                }
                // Suppress all exceptions.
                catch ( const std::exception& ) {}

                // Update status of empty, maybe
                // Need to recover the lock first
                lock.lock();

                // Increment count, indicating thread is available.
                ++available_;
                if (tasks_.empty() && available_ == total_) {
                    complete_ = true;
                    completed_.notify_one();
                }
            }
        }  // while running_
    }
};

SaveOp

在訓練過程中,一般會每隔一定的迭代次數儲存將當前模型儲存到硬碟上。在Caffe2中與儲存模型有關的save_to_db函式,它是一個Python函式,封裝了應該儲存的資訊,以方便載入用。save_to_db呼叫的是C++端的SaveOp。

模型中需要儲存的資訊有:

模型引數 ——通過ModelHelper的params屬性可以獲得模型引數,比如卷積的卷積核bias,FC的weight,bias, BN的estimated mean和estimated var等等。
模型定義 ——網路的Op集合。如果是depoly的話,就不需要儲存gradient operators, 否則需要儲存graients operators。

當建立一個Operators,比如建立ConvOp, 該ConvOp需要卷積核conv_w以及bias conv_b會自動建立。這些引數名字會自動新增進ModelHelper的param_init_net中,而這個param_init_net就可以被視為包含網路引數的定義以及如何初始化這些引數的prototxt,比如呼叫各種具體的初始化演算法如Xavier, Gaussian, MSRA等等來填充引數tensor。這就是為啥在訓練真正開始之前,需要先呼叫workspace.RunNetOnce(model_helper_obj.param_init_net)的原因。

SaveOp Run方法被呼叫時,其實就是對輸入的vecor<const Blob*> 依次呼叫Serialize進行序列化,儲存到硬碟上。Serialize的函式原型為:

//每次儲存的都是帶名字的鍵值對,這也符合Caffe2的儲存設計思想,即所有的記憶體區域都要有名字,比如workspace的map<string, unique_ptr<Blob> > blob_map_。這個acceptor就是負責和具體的DB打交道的函式,它將轉化後的字串輸出到真正的DB中,完成儲存到硬碟的操作,完全類似Caffe1中那個convert_imageset。
void Blob::Serialize(const string& name, BlobSerializerBase::SerializationAcceptor acceptor, int chunk_size) const

在Caffe2儲存部分說過,Blob是一個容器,它可以容納任意型別,比如string, tensor,網路的具體定義比如prototxt就是Blob<string>,它的序列化就是呼叫StringSerialzer,沒啥特別之處。比較有意思部分是當Blob儲存的是Tensor時的序列化,畢竟網路引數都是Tensor。這個序列化的過程其實就是把Tensor中的資料轉換成google protobuf。上面那個acceptor的輸入就是這個google protobuf的序列化字串。下面是Tensor的序列化的程式碼,直接複製過來。

template <class Context>
void TensorSerializer<Context>::SerializeWithChunkSize(const Blob& blob, const string& name, BlobSerializerBase::SerializationAcceptor acceptor, int chunk_size) {
  CAFFE_ENFORCE(blob.IsType<Tensor<Context>>());
  const auto& tensor = blob.template Get<Tensor<Context>>();
  if (chunk_size == kNoChunking) {
    chunk_size = tensor.size() + 1; // to account for empty tensors
  } else if (chunk_size == kDefaultChunkSize) {
    chunk_size = FLAGS_caffe2_tensor_chunk_size;
  }

  auto processChunk = [&](int64_t chunkStart) {
    BlobProto blob_proto;
    blob_proto.set_name(name);
    blob_proto.set_type(kTensorBlobType);
    TensorProto& proto = *blob_proto.mutable_tensor();
    proto.set_name(name);
    this->Serialize(
        tensor, name, blob_proto.mutable_tensor(), chunkStart, chunk_size);
    acceptor(
        MakeString(name, kChunkIdSeparator, chunkStart / chunk_size),
        blob_proto.SerializeAsString());
  };

#ifndef __ANDROID__
  std::vector<std::future<void>> futures;
  // Poorman's IOBound ThreadPool
  //對於超大的Tensor儲存,又是多執行緒並行序列化啊,真是追求效能到極致了。再一次膜拜fb的工程師,貢獻了一份如此漂亮的工業級程式碼。學習了。
  SimpleQueue<size_t> chunkQueue;//執行緒安全佇列
  auto task = [&]() {
    size_t chunkStart;
    while (chunkQueue.Pop(&chunkStart)) {
      processChunk(chunkStart);
    }
  };
  if (tensor.size() > chunk_size) {//這裡就是多執行緒開始的地方
    for (int i = 0; i < FLAGS_caffe2_max_tensor_serializer_threads; ++i) {
      futures.emplace_back(std::async(std::launch::async, task));
    }
  }
#endif

  VLOG(1) << "Serializing blob " << name;
  // Serialize whole vector. If vector is empty, it's shape still needs to be
  // serialized in empty proto
  for (size_t chunkBegin = 0;
       chunkBegin < std::max(tensor.size(), static_cast<TIndex>(1));
       chunkBegin += chunk_size) {
    VLOG(2) << "Starting a chunk at " << chunkBegin;
#ifndef __ANDROID__
    if (tensor.size() > chunk_size) {
      chunkQueue.Push(chunkBegin);//Tensor太大了,分塊,對於每一塊都扔給執行緒池來去序列化。
    } else {
      // Sync mode for small tensors
      processChunk(chunkBegin);
    }
#else
    // Since Android does not have std::future, we will always do sync mode
    processChunk(chunkBegin);
#endif
  }

#ifndef __ANDROID__
  chunkQueue.NoMoreJobs();
  for (auto& fut : futures) {
    fut.get();
  }
#endif
}

LoadOp

LoadOp對應於反序列化,主要。這裡主要說明下python端的prepare_prediction_net函式。儲存到資料庫中的模型有3個net,即global_init_net, predict_init_net, predict_net。

global_init_net ——模型引數載入進workspace就是通過這個global_init_net的,它儲存了網路引數的名字。
predict_init_net——輸入輸出,指定網路輸入blob,輸出blob的名字,以及輸入輸出形狀的定義和置0,同樣,它儲存在workspace中。
predict_net——作用跟Caffe1中的deploy.prototxt差不多。

進行預測時候,載入順序是global_init_net,先將引數載入進workspace,然後載入predict_init_net,初始化輸入輸出,並置0。最後是根據predict_net來構件網路結構,建立一個一個Op。

總結

Caffe2中的ImageInputOp中的decode多執行緒部分,儲存時的TensorSerializer中多執行緒序列化,是值得學習的地方。