Caffe 原始碼分析：BlockingQueue
阿新 • 發佈：2018-12-13
BlockingQueue
BlockingQueue 是一個執行緒安全的佇列，也是 Caffe 訓練時用於資料同步的重要資料結構，本文對其做簡要分析。
// Thread-safe FIFO queue. All operations on queue_ are guarded by the mutex
// held in sync_; consumers that must wait block on sync_'s condition variable.
template<typename T>
class BlockingQueue {
 public:
  explicit BlockingQueue();

  // Producer side: append t to the back of the queue.
  void push(const T& t);

  // Consumers that remove the front element.
  //   try_pop: non-blocking; returns false right away when the queue is empty.
  //   pop:     blocks until an element is available. While blocked it logs
  //            log_on_wait (when non-empty), which is useful for noticing
  //            that the data feed is too slow.
  bool try_pop(T* t);
  T pop(const string& log_on_wait = "");

  // Consumers that return the front element without removing it.
  //   try_peek: presumably the non-blocking variant (definition not shown).
  bool try_peek(T* t);
  T peek();

  size_t size() const;

 protected:
  class sync;  // forward declaration; holds the mutex and condition variable
  std::queue<T> queue_;
  shared_ptr<sync> sync_;

  DISABLE_COPY_AND_ASSIGN(BlockingQueue);
};
用於執行緒同步的內部類別 sync（包含互斥鎖與條件變數）：
// Synchronization state shared by every BlockingQueue<T> operation:
//   mutex_     - guards queue_; declared mutable, presumably so const members
//                such as size() can lock it.
//   condition_ - pop() waits on it while the queue is empty.
template<typename T>
class BlockingQueue<T>::sync {
 public:
  mutable boost::mutex mutex_;
  boost::condition_variable condition_;
};
建構函式（配置共享的 sync 同步物件）：
// Creates an empty queue and allocates its synchronization state.
template<typename T>
BlockingQueue<T>::BlockingQueue() {
  sync_.reset(new sync());
}
下面僅列出阻塞的 pop 與非阻塞的 try_pop，藉此說明 BlockingQueue 的用法：
// Non-blocking pop. Under the queue mutex:
//   - if an element is available, copies the front into *t, removes it and
//     returns true;
//   - otherwise returns false immediately, leaving *t untouched.
// t must point to valid storage.
template<typename T>
bool BlockingQueue<T>::try_pop(T* t) {
  boost::mutex::scoped_lock guard(sync_->mutex_);
  const bool has_item = !queue_.empty();
  if (has_item) {
    *t = queue_.front();
    queue_.pop();
  }
  return has_item;
}
// Blocking pop: waits until the queue is non-empty, then removes and returns
// the front element. While waiting, a non-empty log_on_wait is emitted through
// LOG_EVERY_N(INFO, 1000), i.e. on one of every 1000 wait iterations.
template<typename T>
T BlockingQueue<T>::pop(const string& log_on_wait) {
  boost::mutex::scoped_lock guard(sync_->mutex_);
  // The emptiness test is repeated after every wake-up, so spurious wake-ups
  // simply lead to another wait.
  while (queue_.empty()) {
    if (!log_on_wait.empty()) {
      // Braces are required: LOG_EVERY_N expands to several statements.
      LOG_EVERY_N(INFO, 1000)<< log_on_wait;
    }
    // Releases the mutex while blocked and re-acquires it before returning.
    sync_->condition_.wait(guard);
  }
  T front_item = queue_.front();
  queue_.pop();
  return front_item;
}
顯式實例化（explicit instantiation；注意這並非模板特化）：
// Explicit instantiation definitions: force the compiler to generate every
// BlockingQueue member for these element types in this translation unit.
template class BlockingQueue<Datum*>;
template class BlockingQueue<Batch<float>*>;
template class BlockingQueue<Batch<double>*>;
template class BlockingQueue<shared_ptr<DataReader::QueuePair> >;
其中 Batch 的定義如下：
// One prefetched batch: the input data blob and its matching label blob.
template <typename Dtype>
class Batch {
 public:
  Blob<Dtype> data_;
  Blob<Dtype> label_;
};
DataLayer 使用執行緒讀取 Batch（image 與 label）並 push 到佇列 prefetch_full_ 中；前向傳播時從該佇列 pop 出一個 Batch，將資料複製到 top blob 後，再把它 push 回 prefetch_free_ 以便重複使用：
// Forward pass on the CPU:
//   1. blocks until a filled batch is available in prefetch_full_;
//   2. reshapes top[0] like the batch data and copies the data into it;
//   3. if output_labels_ is set, does the same for the labels into top[1];
//   4. returns the batch to prefetch_free_ so it can be filled again.
// bottom is not used.
template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  Batch<Dtype>* const full_batch =
      prefetch_full_.pop("Data layer prefetch queue empty");

  // Data: shape top[0] like the loaded blob, then copy its contents.
  Blob<Dtype>& loaded_data = full_batch->data_;
  Blob<Dtype>* const data_top = top[0];
  data_top->ReshapeLike(loaded_data);
  caffe_copy(loaded_data.count(), loaded_data.cpu_data(),
             data_top->mutable_cpu_data());

  // Labels: only when this layer is configured to output them.
  if (this->output_labels_) {
    Blob<Dtype>& loaded_label = full_batch->label_;
    Blob<Dtype>* const label_top = top[1];
    label_top->ReshapeLike(loaded_label);
    caffe_copy(loaded_label.count(), loaded_label.cpu_data(),
               label_top->mutable_cpu_data());
  }

  // The batch's contents have been copied out; recycle the buffer.
  prefetch_free_.push(full_batch);
}
BasePrefetchingDataLayer 中的 BlockingQueue 成員：
// Data layer that prefetches batches; it derives from InternalThread, so the
// loading presumably runs on that internal thread (its code is not shown).
// Batch buffers circulate between two queues:
//   prefetch_full_ - filled batches; Forward_cpu pops from here.
//   prefetch_free_ - Forward_cpu pushes consumed batches back here;
//                    presumably the loader thread takes buffers to refill
//                    from this queue (TODO confirm).
template <typename Dtype>
class BasePrefetchingDataLayer :
public BaseDataLayer<Dtype>, public InternalThread {
//....... (other members omitted in this excerpt)
protected:
// Fixed pool of PREFETCH_COUNT batch buffers; presumably the two queues hold
// pointers into this array (initialization not shown).
Batch<Dtype> prefetch_[PREFETCH_COUNT];
BlockingQueue<Batch<Dtype>*> prefetch_free_;
BlockingQueue<Batch<Dtype>*> prefetch_full_;
};