程式人生 > caffe原始碼分析-BlockingQueue

caffe原始碼分析-BlockingQueue

BlockingQueue 是一個執行緒安全的佇列，是 caffe 訓練時資料同步的重要資料結構，本文對其做簡要分析。


template<typename T>
class BlockingQueue {
public:
    explicit BlockingQueue();

    void push(const T& t);
    bool try_pop(T* t);// non-blocking
    // This logs a message if the threads needs to be blocked
    // useful for detecting e.g. when data feeding is too slow
T pop(const string& log_on_wait = ""); bool try_peek(T* t); // Return element without removing it T peek(); size_t size() const; protected: class sync; // class froward decalration std::queue<T> queue_; shared_ptr<sync> sync_; DISABLE_COPY_AND_ASSIGN
(BlockingQueue); };

用於執行緒同步的互斥鎖與條件變數（定義為巢狀類別 sync）：


// Synchronization state shared by all BlockingQueue operations.
template<typename T>
class BlockingQueue<T>::sync {
public:
    // Locked by try_pop()/pop() before they access queue_. It is mutable so it
    // can also be locked from const member functions.
    mutable boost::mutex mutex_;

    // pop() waits on this while queue_ is empty.
    boost::condition_variable condition_;
};

建構函式:


// Constructs an empty queue and allocates its synchronization state
// (mutex + condition variable).
template<typename T>
BlockingQueue<T>::BlockingQueue() : sync_(new sync()) {}

下面僅給出阻塞的 pop 與非阻塞的 try_pop 的實作，以說明 BlockingQueue 的用法。（pop 會在條件變數上等待，因此 push 在放入元素後必須呼叫 condition_.notify_one() 喚醒等待中的執行緒。）

// Non-blocking pop. If an element is available, copies the front element into
// *t, removes it and returns true; otherwise returns false without waiting.
template<typename T>
bool BlockingQueue<T>::try_pop(T* t) {
    boost::mutex::scoped_lock guard(sync_->mutex_);

    if (!queue_.empty()) {
        *t = queue_.front();
        queue_.pop();
        return true;
    }
    return false;  // return immediately instead of blocking
}

// Blocking pop: waits on the condition variable until the queue is non-empty,
// then removes and returns the front element. If log_on_wait is non-empty, it
// is logged (throttled by LOG_EVERY_N(INFO, 1000)) on each pass that must wait.
template<typename T>
T BlockingQueue<T>::pop(const string& log_on_wait) {
    boost::mutex::scoped_lock guard(sync_->mutex_);

    const bool report_waits = !log_on_wait.empty();
    // Re-check emptiness after every wake-up: condition variables may wake
    // spuriously, and another consumer may have taken the element first.
    while (queue_.empty()) {
        if (report_waits) {
            LOG_EVERY_N(INFO, 1000)<< log_on_wait;
        }
        sync_->condition_.wait(guard);  // unlocks the mutex while blocked
    }

    T front = queue_.front();
    queue_.pop();
    return front;
}

模板顯式實例化（explicit instantiation，而非模板特化）：

// Explicitly instantiate BlockingQueue for the element types below, so that
// its out-of-line member functions are emitted in this translation unit.
template class BlockingQueue<Datum*>;
template class BlockingQueue<shared_ptr<DataReader::QueuePair> >;
template class BlockingQueue<Batch<float>*>;
template class BlockingQueue<Batch<double>*>;

其中 Batch 的定義如下（資料與標籤各一個 Blob）：

// One prefetched unit of work: a data Blob and its matching label Blob.
template <typename Dtype>
class Batch {
public:
    Blob<Dtype> data_;   // input data
    Blob<Dtype> label_;  // corresponding labels
};

DataLayer 使用預取執行緒讀取 Batch（image, label），並將其 push 到 prefetch_full_ 佇列中；前向傳播時再從 prefetch_full_ pop 出來複製到輸出 Blob，用完後 push 回 prefetch_free_ 以便重新填充：

// Forward pass: takes the next filled batch from prefetch_full_ (blocking if
// none is ready yet), copies its data (and, if output_labels_ is set, its
// labels) into the top blobs, then returns the batch to prefetch_free_.
template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
        const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
    Batch<Dtype>* const loaded =
            prefetch_full_.pop("Data layer prefetch queue empty");

    // Data: match the output shape to the loaded batch, then copy.
    Blob<Dtype>* const data_top = top[0];
    data_top->ReshapeLike(loaded->data_);
    caffe_copy(loaded->data_.count(), loaded->data_.cpu_data(),
               data_top->mutable_cpu_data());

    // Labels go to the second top blob only when output_labels_ is set.
    if (this->output_labels_) {
        Blob<Dtype>* const label_top = top[1];
        label_top->ReshapeLike(loaded->label_);
        caffe_copy(loaded->label_.count(), loaded->label_.cpu_data(),
                   label_top->mutable_cpu_data());
    }

    // The batch's contents have been copied out; make it available for refill.
    prefetch_free_.push(loaded);
}

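BasePrefetchingDataLayer 中的 BlockingQueue 成員：prefetch_free_ 存放可重新填充的空閒 Batch，prefetch_full_ 存放已填好資料、等待前向傳播取用的 Batch：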

// Data layer that also derives from InternalThread. Batches are handed between
// threads through two BlockingQueues of Batch pointers; Forward_cpu consumes
// from prefetch_full_ and returns batches to prefetch_free_.
template <typename Dtype>
class BasePrefetchingDataLayer :
        public BaseDataLayer<Dtype>, public InternalThread {
	//....... (other members elided in this excerpt)
protected:
    // Fixed pool of batch buffers; presumably the two queues below hold
    // pointers into this array (TODO confirm against the constructor).
    Batch<Dtype> prefetch_[PREFETCH_COUNT];
    // Batches ready to be (re)filled; Forward_cpu pushes consumed batches here.
    BlockingQueue<Batch<Dtype>*> prefetch_free_;
    // Filled batches waiting to be consumed; Forward_cpu pops from here.
    BlockingQueue<Batch<Dtype>*> prefetch_full_;
};