1. 程式人生 > >mxnet程式碼解析之mshadow

mxnet程式碼解析之mshadow

mshadow採用了表示式模板的技巧增強了c++矩陣庫的效能。
mshadow用於資料儲存結構的主要繼承脈絡如下:
Tensor->TRValue->RValueExp->Exp
繼承鏈的頂端是所有表示式的基類Exp:

/*!
 * \brief base class of every expression in mshadow.
 *
 * This uses the CRTP idiom: a concrete expression passes itself as SubType,
 * i.e. it derives from Exp<SubType, DType, exp_type>. The static_casts below
 * are only valid under that requirement.
 *
 * \tparam SubType  the concrete expression type deriving from this class
 * \tparam DType    element data type of the expression
 * \tparam exp_type expression type flag (e.g. type::kRValue, type::kMapper)
 */
template<typename SubType, typename DType, int exp_type>
struct Exp {
 public:
  /*! \return const reference to this object viewed as its concrete SubType */
  inline const SubType& self(void) const {
    return *static_cast<const SubType*>(this);
  }
  /*! \return pointer (not a reference) to this object viewed as SubType */
  inline SubType* ptrself(void) {
    return static_cast<SubType*>(this);
  }
};

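這裡Exp定義的精髓在於:通過self或ptrself可以獲得SubType的const引用或指標。這就是CRTP(Curiously Recurring Template Pattern,奇異遞迴模板模式)——子類把自身作為SubType傳給Exp(例如Tensor、BinaryMapExp),因此基類中static_cast到SubType是合法的。這為後面將SubType表示式作為模板引數傳遞之後,再取回具體的SubType提供了途徑。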
RValueExp繼承於Exp,是所有右值的基類:

/*!
 * \brief base class of all rvalue containers, i.e. types that can be
 *        assigned to (such as Tensor).
 *        Each assignment operator below hands its right-hand side to
 *        ExpEngine together with a saver type (sv::saveto, sv::plusto, ...)
 *        and then returns the container itself.
 * \tparam Container concrete container type; also the SubType passed to Exp
 * \tparam DType     element data type
 */
template<typename Container, typename DType>
class RValueExp: public Exp<Container, DType, type::kRValue> {
 public:
  /*! \return a TransposeExp wrapping this container */
  inline const TransposeExp<Container, DType> T(void) const {
    return TransposeExp<Container, DType>(this->self());
  }
  /*!
   * \brief operator overload: evaluates scalar<DType>(s) into *this with the
   *        saver sv::plusto (presumably element-wise +=; saver defined elsewhere)
   */
  inline Container &operator+=(DType s) {
    ExpEngine<sv::plusto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
    return *(this->ptrself());
  }
  /*! \brief operator overload (body elided in this excerpt) */
  inline Container &operator-=(DType s) {......}
  /*! \brief operator overload (body elided in this excerpt) */
  inline Container &operator*=(DType s) {......}
  /*! \brief operator overload (body elided in this excerpt) */
  inline Container &operator/=(DType s) {......}
  /*!
   * \brief assigns scalar<DType>(s) to *this via ExpEngine with sv::saveto.
   * NOTE(review): identifiers containing a double underscore are reserved
   * for the implementation in C++; the name __assign is kept unchanged.
   */
  inline Container &__assign(DType s) {
    ExpEngine<sv::saveto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
    return *(this->ptrself());
  }
  /*!
   * \brief we can not define container = container, so assignment from an
   *        arbitrary expression goes through __assign instead (derived
   *        containers such as Tensor forward their operator= here). The
   *        expression is evaluated into *this with sv::saveto.
   */
  template<typename E, int etype>
  inline Container &__assign(const Exp<E, DType, etype> &exp) {
    ExpEngine<sv::saveto, Container, DType>::Eval(this->ptrself(), exp.self());
    return *(this->ptrself());
  }
  /*!
   * \brief operator overload, assign from an rvalue of the same container
   *        type. Only declared here; its definition is not in this excerpt.
   */
  inline Container &__assign(const Exp<Container, DType, type::kRValue> &exp);
  /*! \brief implementation of operator+=: evaluates exp into *this with sv::plusto */
  template<typename E, int etype>
  inline Container &operator+=(const Exp<E, DType, etype> &exp) {
    ExpEngine<sv::plusto, Container, DType>::Eval(this->ptrself(), exp.self());
    return *(this->ptrself());
  }
  /*! \brief implementation of operator-= (body elided in this excerpt) */
  template<typename E, int etype>
  inline Container &operator-=(const Exp<E, DType, etype> &exp) {......}
  /*! \brief implementation of operator*= (body elided in this excerpt) */
  template<typename E, int etype>
  inline Container &operator*=(const Exp<E, DType, etype> &exp) {......}
  /*! \brief implementation of operator/= (body elided in this excerpt) */
  template<typename E, int etype>
  inline Container &operator/=(const Exp<E, DType, etype> &exp) {......}
};

RValueExp的各個賦值運算子都把計算委託給ExpEngine類完成。ExpEngine按表示式型別(kMapper、kChainer、kRValue、kComplex)過載了四個Eval函式:前三種情形都交給MapExp處理,kComplex(例如dot)則轉交給ExpComplexEngine:

template<typename SV, typename RV, typename DType>
struct ExpEngine {
  template<typename E>
  inline static void Eval(RV *dst,const Exp<E, DType, type::kMapper> &exp) {
    MapExp<SV>(dst, exp);
  }
  template<typename E>
  inline static void Eval(RV *dst,const Exp<E, DType, type::kChainer> &exp) {
    MapExp<SV>(dst, exp);
  }
  template<typename E>
  inline static void Eval(RV *dst,const Exp<E, DType, type::kRValue> &exp) {
    MapExp<SV>(dst, exp);
  }
  //用於dot
  template<typename E>
  inline static void Eval(RV *dst,const Exp<E, DType, type::kComplex> &exp) {
    ExpComplexEngine<SV, RV, E, DType>::Eval(dst->ptrself(), exp.self());
  }
};

TRValue是所有可能的tensor的超類。它本身沒有任何成員,只是在繼承鏈中額外帶上了Device和dimension兩個模板引數:

/*!
 * \brief common base of all tensor-like containers.
 *        It declares no members of its own; it only carries the Device and
 *        dimension template parameters into the inheritance chain and
 *        derives from expr::RValueExp.
 * \tparam Container concrete tensor type (forwarded as RValueExp's Container)
 * \tparam Device    device the data lives on
 * \tparam dimension number of dimensions
 * \tparam DType     element data type
 */
template<typename Container, typename Device, int dimension, typename DType>
struct TRValue: public expr::RValueExp<Container, DType> {
};

終於到了Tensor類的定義:

/*!
 * \brief dense tensor view: a data pointer plus shape, row stride and stream.
 *        Assignment from expressions is forwarded to RValueExp::__assign,
 *        so the expression is evaluated only when it is assigned here.
 * \tparam Device    device the data lives on
 * \tparam dimension number of dimensions (fixed at compile time)
 * \tparam DType     element data type
 */
template<typename Device, int dimension,
         typename DType MSHADOW_DEFAULT_DTYPE>
struct Tensor: public TRValue<Tensor<Device, dimension, DType>,
                              Device, dimension, DType> {
 public:
  /*! \brief whether the device is the CPU (taken from Device::kDevCPU) */
  static const bool kDevCPU = Device::kDevCPU;
  /*! \brief dimension of a sub-tensor, i.e. dimension - 1 */
  static const int  kSubdim = dimension - 1;

  /*! \brief pointer to the data */
  DType *dptr_;
  /*! \brief shape of the tensor */
  Shape<dimension> shape_;
  /*!
   * \brief storing the stride information in x dimension
   *    this is used to deal with pitch allocation in gpu or sse(align x dimension to 64bit) for efficiency
   *    (the x dimension is the last, innermost one; element [y][x] of a
   *    2D plan is read at dptr_[y * stride_ + x])
   */
  index_t stride_;

  /*! \brief stream associated with this tensor; copied into every view made by Slice and operator= */
  Stream<Device> *stream_;

  // various constructors (elided in this excerpt)
  ......
  /*! \brief construct a Tensor from a data pointer, shape, stride and stream (the data is not copied) */
  MSHADOW_XINLINE Tensor(DType *dptr,
                         const Shape<dimension> &shape,
                         index_t stride, Stream<Device> *stream)
      : dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {}
  ......
  ......

  /*!
   * \brief view of the index range [begin, end) along dimension 0 (shape_[0]).
   *        The result shares this tensor's data (no copy); its data pointer is
   *        advanced by begin * MemSize<1>() (MemSize is defined in the elided part).
   * NOTE(review): no check that begin <= end <= shape_[0] is visible here.
   */
  MSHADOW_XINLINE Tensor<Device, dimension, DType>
  Slice(index_t begin, index_t end) const {
    Shape<dimension> s = this->shape_;
    s[0] = end - begin;
    return Tensor<Device, dimension, DType>(dptr_ + this->MemSize<1>() * begin,
                                            s, stride_, stream_);
  }
  /*!
   * \brief implement the assignment of same type.
   *        This is a shallow copy: only the pointer, shape, stride and stream
   *        are copied, never the element data.
   */
  inline Tensor<Device, dimension, DType> &
  operator=(const Tensor<Device, dimension, DType> &exp) {
    dptr_ = exp.dptr_;
    shape_ = exp.shape_;
    stride_ = exp.stride_;
    stream_ = exp.stream_;
    return *this;
  }
  /*!\brief functions to fit expression template: evaluate exp into this tensor via __assign */
  template<typename E, int etype>
  inline Tensor<Device, dimension, DType> &
  operator=(const expr::Exp<E, DType, etype> &exp) {
    return this->__assign(exp);
  }
  /*!\brief functions to fit expression template: assign the scalar exp via __assign */
  inline Tensor<Device, dimension, DType> &operator=(const DType &exp) {
    return this->__assign(exp);
  }
};

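Tensor的shape與numpy.shape的約定一致:shape_[0]是最高維(Slice正是沿shape_[0]切片),最後一維shape_[dimension-1]是最低維、在記憶體中連續存放,stride_也是沿這一維計算的(後文get_with_shape即以shape[dim - 1]作為stride)。過載的operator=有三種用途:拷貝已有Tensor(這只是淺拷貝,僅複製dptr_、shape_、stride_、stream_,並不複製資料)、賦值中間運算結果表示式Exp,以及賦值標量。這裡對operator=的過載將運算操作延遲到了賦值階段,實現了Lazy Evaluation,避免了臨時記憶體分配。特別地,DotExp也是在operator=中才被惰性求值(lazily evaluated),矩陣乘法被重定向到了BLAS庫。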

mshadow用於表示式操作的類(DotExp、BinaryMapExp、UnaryMapExp)同樣繼承於Exp基類,其特點是該表示式操作類自身也作為模板引數傳遞給Exp,以BinaryMapExp為例:

/*!
 * \brief expression node for the binary operation OP applied to lhs and rhs.
 *        It derives from Exp with itself as SubType (CRTP).
 * \tparam OP    binary operator type
 * \tparam TA    left operand expression type
 * \tparam TB    right operand expression type
 * \tparam DType element data type
 * \tparam etype type flag of the resulting expression
 */
template<typename OP, typename TA, typename TB, typename DType, int etype>
struct BinaryMapExp: public Exp<BinaryMapExp<OP, TA, TB, DType, etype>,
                                DType, etype> {
  /*! \brief left operand (held by reference: it must outlive this node) */
  const TA &lhs_;
  /*! \brief right operand (held by reference: it must outlive this node) */
  const TB &rhs_;
  /*! \brief constructor; stores references to both operands, copies nothing */
  explicit BinaryMapExp(const TA &lhs, const TB &rhs)
      :lhs_(lhs), rhs_(rhs) {}
};

/*!
 * \brief builds the expression node OP(lhs, rhs).
 * \param lhs left operand expression
 * \param rhs right operand expression
 * \return a BinaryMapExp over the concrete operands; its type flag is the
 *         bitwise OR of both operand flags and type::kMapper
 */
template<typename OP, typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<OP, TA, TB, DType, (ta|tb|type::kMapper)>
MakeExp(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
  typedef BinaryMapExp<OP, TA, TB, DType, (ta|tb|type::kMapper)> ResultExp;
  // recover the concrete operand types from their Exp bases
  const TA &left = lhs.self();
  const TB &right = rhs.self();
  return ResultExp(left, right);
}

/*!
 * \brief user-facing spelling of MakeExp: F<OP>(lhs, rhs) builds a binary
 *        expression that applies the custom operator OP.
 * \param lhs left operand expression
 * \param rhs right operand expression
 * \return the BinaryMapExp produced by MakeExp<OP>(lhs, rhs)
 */
template<typename OP, typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<OP, TA, TB, DType, (ta|tb|type::kMapper)>
F(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
  return MakeExp<OP>(lhs, rhs);
}

/*!
 * \brief lhs + rhs: builds a BinaryMapExp with op::plus via MakeExp.
 *        This call only constructs the expression node; evaluation happens
 *        later, when the expression is assigned to a container.
 */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<op::plus, TA, TB, DType, (ta|tb|type::kMapper)>
operator+(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
  return MakeExp<op::plus>(lhs, rhs);
}
......
......

BinaryMapExp是雙目運算的表示式類。MakeExp是用來生成BinaryMapExp物件的函式,結果表示式的型別標誌為ta|tb|type::kMapper,即兩個運算元型別標誌與kMapper的按位或。F是供使用者自定義運算的函式,F<OP>(lhs, rhs)描述了一個新的雙目運算。除此以外,+、-、*、/等運算子的過載函式也都呼叫MakeExp來建立BinaryMapExp。
這些用於表示式操作的類(DotExp、BinaryMapExp、UnaryMapExp)表示一個運算操作的中間結果,並且可以遞迴巢狀(由Plan的Eval函式遞迴求值),從而能夠表達並求值很長的算式(lengthy equations)。

真正用於遞迴呼叫eval的是Plan類:

/*!
 * \brief primary template of the evaluation plan for an expression ExpType.
 *        Only the Eval declaration is given here; expression types provide
 *        specializations (e.g. for Tensor and BinaryMapExp) that implement it.
 */
template<typename ExpType, typename DType>
class Plan {
 public:
  /*!
   * \brief evaluate the expression at index [y][x]
   *  to be implemented by SubType, for RValue, the return type will be DType &
   *  (in the Tensor specialization, writable access is REval returning
   *  DType &, while Eval returns const DType &)
   */
  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const;
};

// Plan specialization for Tensor: reads and writes element [y][x] directly
// in the tensor's buffer, where consecutive rows are stride_ elements apart.
template <typename Device, int dim, typename DType>
class Plan<Tensor<Device, dim, DType>, DType> {
 public:
  /*! \brief captures the data pointer and row stride of t (no data copy) */
  explicit Plan(const Tensor<Device, dim, DType> &t)
      : dptr_(t.dptr_), stride_(t.stride_) {}
  /*! \brief writable reference to element [y][x] (rvalue access) */
  MSHADOW_XINLINE DType &REval(index_t y, index_t x) {
    return dptr_[Offset(y, x)];
  }
  /*! \brief read-only reference to element [y][x] */
  MSHADOW_XINLINE const DType &Eval(index_t y, index_t x) const {
    return dptr_[Offset(y, x)];
  }

 private:
  /*! \brief linear position of element [y][x] inside the buffer */
  MSHADOW_XINLINE index_t Offset(index_t y, index_t x) const {
    return y * stride_ + x;
  }
  DType  *dptr_;
  index_t stride_;
};
......
......
// Plan specialization for a binary expression: holds the plans of both
// operands by value and combines their values at [y][x] with OP::Map.
template<typename OP, typename TA, typename TB, int etype, typename DType>
class Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType> {
 public:
  /*!
   * \param left  plan of the left operand (copied into this plan)
   * \param right plan of the right operand (copied into this plan)
   */
  explicit Plan(const Plan<TA, DType> &left, const Plan<TB, DType> &right)
      : left_(left), right_(right) {}
  /*! \brief OP::Map applied to both operand values at index [y][x] */
  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
    return OP::Map(left_.Eval(y, x), right_.Eval(y, x));
  }

 private:
  /*! \brief plan of the left operand */
  Plan<TA, DType> left_;
  /*! \brief plan of the right operand */
  Plan<TB, DType> right_;
};

呼叫模板函式MakePlan從表示式生成Plan,再以BinaryMapExp表示式為例:

/*!
 * \brief makes the plan of an rvalue container (e.g. a Tensor).
 * \param e the container, seen through its RValueExp base
 * \return Plan<T, DType> constructed from the concrete container
 */
template<typename T, typename DType>
inline Plan<T, DType> MakePlan(const RValueExp<T, DType> &e) {
  const T &container = e.self();
  return Plan<T, DType>(container);
}

/*!
 * \brief makes the plan of a binary expression by recursively building the
 *        plans of both operands and combining them.
 * \param e the binary expression
 * \return Plan holding the left and right operand plans
 */
template<typename OP, typename TA, typename TB, typename DType, int etype>
inline Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType>
MakePlan(const BinaryMapExp<OP, TA, TB, DType, etype> &e) {
  typedef Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType> ResultPlan;
  return ResultPlan(MakePlan(e.lhs_), MakePlan(e.rhs_));
}

如前所述,ExpEngine中過載的Eval會呼叫MapExp,把Plan的計算結果寫入目標Exp。MapExp又直接或間接地(為處理SSE優化,中間加入了一個類MapExpCPUEngine)呼叫MapPlan來完成這一工作。

Tensor的維數在定義後就固定了,因此在圖模型中需要一個更為抽象靈活的資料結構,這就是TBlob:

/*!
 * \brief type- and dimension-erased view of a tensor buffer.
 *        It holds an untyped data pointer plus shape, stride, device and
 *        data-type metadata. Typed Tensor views are obtained on demand via
 *        get, get_with_shape, FlatTo2D and FlatTo3D.
 *        No allocation or deallocation is visible in this excerpt.
 */
class TBlob {
 public:
  /*! \brief pointer to the data */
  void *dptr_;
  /*! \brief shape of the tensor */
  TShape shape_;
  /*!
   * \brief storing the stride information in x dimension
   */
  index_t stride_;
  /*! \brief device mask of the corresponding device */
  int dev_mask_;
  /*! \brief type flag of the tensor blob */
  int type_flag_;
  ......
  /*! \brief typed Tensor view of this blob (body elided in this excerpt) */
  template<typename Device, int dim, typename DType>
  inline Tensor<Device, dim, DType> get(Stream<Device> *stream = NULL) const {...}

  /*!
   * \brief typed Tensor view of this blob, reinterpreted with a new shape.
   * \param shape  new shape; must contain the same total number of elements
   * \param stream stream to associate with the returned Tensor
   * \return Tensor sharing dptr_ (no copy), whose stride is shape[dim - 1]
   *
   * The device mask, data type, contiguity and element count are all
   * checked first (CHECK / CHECK_EQ macros, defined outside this excerpt).
   * NOTE(review): the first and third messages name TBlob.get /
   * TBlob.get_reshape although this function is get_with_shape; the
   * strings are left unchanged.
   */
  template<typename Device, int dim, typename DType>
  inline Tensor<Device, dim, DType> get_with_shape(const Shape<dim> &shape,
                                                   Stream<Device> *stream = NULL) const
  {
    CHECK(Device::kDevMask == dev_mask_)
      << "TBlob.get: device type do not match specified type";
    CHECK(DataType<DType>::kFlag == type_flag_)
      << "TBlob.get_with_shape: data type do not match specified type."
      << "Expected: " << type_flag_ << " v.s. given " << DataType<DType>::kFlag;
    // the returned view is treated as densely packed, hence contiguity is required
    CHECK_EQ(this->CheckContiguous(), true) << "TBlob.get_reshape: must be contiguous";
    CHECK_EQ(this->shape_.Size(), shape.Size())
      << "TBlob.get_with_shape: new and old shape do not match total elements";
    // stride equals the innermost extent: rows are packed back to back
    return Tensor<Device, dim, DType>(static_cast<DType*>(dptr_),
                                      shape,
                                      shape[dim - 1],
                                      stream);
  }
  ......
  /*! \brief blob flattened to a 2D Tensor (body elided in this excerpt) */
  template<typename Device, typename DType>
  inline Tensor<Device, 2, DType> FlatTo2D(Stream<Device> *stream = NULL) const {}
  ......
  /*! \brief blob flattened to a 3D Tensor around axis (body elided in this excerpt) */
  template<typename Device, typename DType>
  inline Tensor<Device, 3, DType> FlatTo3D(int axis, Stream<Device> *stream = NULL)
  const {}
  ......
};  // a class definition must be terminated by ';'

TBlob不涉及任何算術運算,也沒有隱式的記憶體分配與釋放。它就像一個指標類:在需要的時候呼叫get、get_with_shape、FlatTo2D、FlatTo3D等函式,獲得固定維數的Tensor來做更多的操作。TShape與TBlob類似,在需要的時候呼叫get、FlatTo2D等函式獲得Tensor對應的Shape。