
Implementing a Custom Op in TensorFlow

Preface

Using the CTC beam search decoder as an example, this article briefly walks through the workflow for implementing a custom Op in TensorFlow.

Basic workflow

1. Define the Op interface

#include "tensorflow/core/framework/op.h"
 
REGISTER_OP("Custom")  
  .Input("custom_input: int32")
  .Output("custom_output: int32");

2. Implement the Op's Compute method (CPU) or kernel (GPU)

#include "tensorflow/core/framework/op_kernel.h"
 
using namespace tensorflow;
 
class CustomOp : public OpKernel {
 public:
  explicit CustomOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {
   // Grab the input tensor.
   const Tensor& input_tensor = context->input(0);
   auto input = input_tensor.flat<int32>();
   // Create an output tensor.
   Tensor* output_tensor = NULL;
   OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
   auto output = output_tensor->flat<int32>();
   // Do the concrete computation here, reading from input and writing to output.
   //……
  }
};
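
For illustration only (this is not part of the original skeleton), a hypothetical filled-in version of the kernel above could make the concrete computation a simple element-wise copy:

#include "tensorflow/core/framework/op_kernel.h"
 
using namespace tensorflow;
 
// Hypothetical example: the "concrete computation" copies the input to the output.
class CustomCopyOp : public OpKernel {
 public:
  explicit CustomCopyOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {
   const Tensor& input_tensor = context->input(0);
   auto input = input_tensor.flat<int32>();
   Tensor* output_tensor = nullptr;
   OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
   auto output = output_tensor->flat<int32>();
   // Copy every element of the input to the output.
   for (int i = 0; i < input.size(); ++i) {
    output(i) = input(i);
   }
  }
};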

3. Register the implemented kernel with the TensorFlow system

REGISTER_KERNEL_BUILDER(Name("Custom").Device(DEVICE_CPU),CustomOp);

Customizing CTCBeamSearchDecoder

The parts of the TensorFlow source tree that this Op corresponds to:

Definition of the Op interface:

tensorflow-master/tensorflow/core/ops/ctc_ops.cc

Definition of CTCBeamSearchDecoder itself:

tensorflow-master/tensorflow/core/util/ctc/ctc_beam_search.cc

The Op class wrapper and Op registration:

tensorflow-master/tensorflow/core/kernels/ctc_decoder_ops.cc

The Op, modified from the original source

#include <algorithm>
#include <vector>
#include <cmath>
 
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
 
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/kernels/bounds_check.h"
 
namespace tf = tensorflow;
using tf::shape_inference::DimensionHandle;
using tf::shape_inference::InferenceContext;
using tf::shape_inference::ShapeHandle;
 
using namespace tensorflow;
 
REGISTER_OP("CTCBeamSearchDecoderWithParam")
  .Input("inputs: float")
  .Input("sequence_length: int32")
  .Attr("beam_width: int >= 1")
  .Attr("top_paths: int >= 1")
  .Attr("merge_repeated: bool = true")
  // Two newly added attributes
  .Attr("label_selection_size: int >= 0 = 0") 
  .Attr("label_selection_margin: float") 
  .Output("decoded_indices: top_paths * int64")
  .Output("decoded_values: top_paths * int64")
  .Output("decoded_shape: top_paths * int64")
  .Output("log_probability: float")
  .SetShapeFn([](InferenceContext* c) {
   ShapeHandle inputs;
   ShapeHandle sequence_length;
 
   TF_RETURN_IF_ERROR(c->WithRank(c->input(0),3,&inputs));
   TF_RETURN_IF_ERROR(c->WithRank(c->input(1),1,&sequence_length));
 
   // Get batch size from inputs and sequence_length.
   DimensionHandle batch_size;
   TF_RETURN_IF_ERROR(
     c->Merge(c->Dim(inputs,1),c->Dim(sequence_length,0),&batch_size));
 
   int32 top_paths;
   TF_RETURN_IF_ERROR(c->GetAttr("top_paths",&top_paths));
 
   // Outputs.
   int out_idx = 0;
   for (int i = 0; i < top_paths; ++i) { // decoded_indices
    c->set_output(out_idx++,c->Matrix(InferenceContext::kUnknownDim,2));
   }
   for (int i = 0; i < top_paths; ++i) { // decoded_values
    c->set_output(out_idx++,c->Vector(InferenceContext::kUnknownDim));
   }
   ShapeHandle shape_v = c->Vector(2);
   for (int i = 0; i < top_paths; ++i) { // decoded_shape
    c->set_output(out_idx++,shape_v);
   }
   c->set_output(out_idx++,c->Matrix(batch_size,top_paths));
   return Status::OK();
  });
 
typedef Eigen::ThreadPoolDevice CPUDevice;
 
inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m,int r,int* c) {
 *c = 0;
 CHECK_LT(0,m.dimension(1));
 float p = m(r,0);
 for (int i = 1; i < m.dimension(1); ++i) {
  if (m(r,i) > p) {
   p = m(r,i);
   *c = i;
  }
 }
 return p;
}
 
class CTCDecodeHelper {
 public:
 CTCDecodeHelper() : top_paths_(1) {}
 
 inline int GetTopPaths() const { return top_paths_; }
 void SetTopPaths(int tp) { top_paths_ = tp; }
 
 Status ValidateInputsGenerateOutputs(
   OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
   Tensor** log_prob, OpOutputList* decoded_indices,
   OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
  Status status = ctx->input("inputs",inputs);
  if (!status.ok()) return status;
  status = ctx->input("sequence_length",seq_len);
  if (!status.ok()) return status;
 
  const TensorShape& inputs_shape = (*inputs)->shape();
 
  if (inputs_shape.dims() != 3) {
   return errors::InvalidArgument("inputs is not a 3-Tensor");
  }
 
  const int64 max_time = inputs_shape.dim_size(0);
  const int64 batch_size = inputs_shape.dim_size(1);
 
  if (max_time == 0) {
   return errors::InvalidArgument("max_time is 0");
  }
  if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
   return errors::InvalidArgument("sequence_length is not a vector");
  }
 
  if (!(batch_size == (*seq_len)->dim_size(0))) {
   return errors::FailedPrecondition(
     "len(sequence_length) != batch_size. ","len(sequence_length): ",(*seq_len)->dim_size(0)," batch_size: ",batch_size);
  }
 
  auto seq_len_t = (*seq_len)->vec<int32>();
 
  for (int b = 0; b < batch_size; ++b) {
   if (!(seq_len_t(b) <= max_time)) {
    return errors::FailedPrecondition("sequence_length(",b,") <= ",max_time);
   }
  }
 
  Status s = ctx->allocate_output(
    "log_probability",TensorShape({batch_size,top_paths_}),log_prob);
  if (!s.ok()) return s;
 
  s = ctx->output_list("decoded_indices",decoded_indices);
  if (!s.ok()) return s;
  s = ctx->output_list("decoded_values",decoded_values);
  if (!s.ok()) return s;
  s = ctx->output_list("decoded_shape",decoded_shape);
  if (!s.ok()) return s;
 
  return Status::OK();
 }
 
 // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
 Status StoreAllDecodedSequences(
    const std::vector<std::vector<std::vector<int> > >& sequences,
    OpOutputList* decoded_indices, OpOutputList* decoded_values,
    OpOutputList* decoded_shape) const {
  // Calculate the total number of entries for each path
  const int64 batch_size = sequences.size();
  std::vector<int64> num_entries(top_paths_,0);
 
  // Calculate num_entries per path
  for (const auto& batch_s : sequences) {
   CHECK_EQ(batch_s.size(),top_paths_);
   for (int p = 0; p < top_paths_; ++p) {
    num_entries[p] += batch_s[p].size();
   }
  }
 
  for (int p = 0; p < top_paths_; ++p) {
   Tensor* p_indices = nullptr;
   Tensor* p_values = nullptr;
   Tensor* p_shape = nullptr;
 
   const int64 p_num = num_entries[p];
 
   Status s =
     decoded_indices->allocate(p,TensorShape({p_num,2}),&p_indices);
   if (!s.ok()) return s;
   s = decoded_values->allocate(p,TensorShape({p_num}),&p_values);
   if (!s.ok()) return s;
   s = decoded_shape->allocate(p,TensorShape({2}),&p_shape);
   if (!s.ok()) return s;
 
   auto indices_t = p_indices->matrix<int64>();
   auto values_t = p_values->vec<int64>();
   auto shape_t = p_shape->vec<int64>();
 
   int64 max_decoded = 0;
   int64 offset = 0;
 
   for (int64 b = 0; b < batch_size; ++b) {
    auto& p_batch = sequences[b][p];
    int64 num_decoded = p_batch.size();
    max_decoded = std::max(max_decoded,num_decoded);
    std::copy_n(p_batch.begin(),num_decoded,&values_t(offset));
    for (int64 t = 0; t < num_decoded; ++t,++offset) {
     indices_t(offset,0) = b;
     indices_t(offset,1) = t;
    }
   }
 
   shape_t(0) = batch_size;
   shape_t(1) = max_decoded;
  }
  return Status::OK();
 }
 
 private:
 int top_paths_;
 TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};
 
// CTC beam search
class CTCBeamSearchDecoderWithParamOp : public OpKernel {
 public:
 explicit CTCBeamSearchDecoderWithParamOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx,ctx->GetAttr("merge_repeated",&merge_repeated_));
  OP_REQUIRES_OK(ctx,ctx->GetAttr("beam_width",&beam_width_));
  // Read the two newly added attributes from the attr list
  OP_REQUIRES_OK(ctx,ctx->GetAttr("label_selection_size",&label_selection_size));
  OP_REQUIRES_OK(ctx,ctx->GetAttr("label_selection_margin",&label_selection_margin));
  int top_paths;
  OP_REQUIRES_OK(ctx,ctx->GetAttr("top_paths",&top_paths));
  decode_helper_.SetTopPaths(top_paths);
 }
 
 void Compute(OpKernelContext* ctx) override {
  const Tensor* inputs;
  const Tensor* seq_len;
  Tensor* log_prob = nullptr;
  OpOutputList decoded_indices;
  OpOutputList decoded_values;
  OpOutputList decoded_shape;
   OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
               ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
               &decoded_values, &decoded_shape));
 
  auto inputs_t = inputs->tensor<float,3>();
  auto seq_len_t = seq_len->vec<int32>();
  auto log_prob_t = log_prob->matrix<float>();
 
  const TensorShape& inputs_shape = inputs->shape();
 
  const int64 max_time = inputs_shape.dim_size(0);
  const int64 batch_size = inputs_shape.dim_size(1);
  const int64 num_classes_raw = inputs_shape.dim_size(2);
  OP_REQUIRES(
     ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
     errors::InvalidArgument("num_classes cannot exceed max int"));
  const int num_classes = static_cast<const int>(num_classes_raw);
 
  log_prob_t.setZero();
 
  std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
 
  for (std::size_t t = 0; t < max_time; ++t) {
   input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,batch_size,num_classes);
  }
 
  ctc::CTCBeamSearchDecoder<> beam_search(num_classes,beam_width_,&beam_scorer_,1 /* batch_size */,merge_repeated_);
   // Configure the decoder with the two passed-in parameters
  beam_search.SetLabelSelectionParameters(label_selection_size,label_selection_margin);
  Tensor input_chip(DT_FLOAT,TensorShape({num_classes}));
  auto input_chip_t = input_chip.flat<float>();
 
  std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
  std::vector<float> log_probs;
 
  // Assumption: the blank index is num_classes - 1
  for (int b = 0; b < batch_size; ++b) {
   auto& best_paths_b = best_paths[b];
   best_paths_b.resize(decode_helper_.GetTopPaths());
   for (int t = 0; t < seq_len_t(b); ++t) {
    input_chip_t = input_list_t[t].chip(b,0);
    auto input_bi =
      Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(),num_classes);
    beam_search.Step(input_bi);
   }
   OP_REQUIRES_OK(
     ctx,beam_search.TopPaths(decode_helper_.GetTopPaths(),&best_paths_b,&log_probs,merge_repeated_));
 
   beam_search.Reset();
 
   for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
    log_prob_t(b,bp) = log_probs[bp];
   }
  }
 
  OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
              best_paths, &decoded_indices, &decoded_values, &decoded_shape));
 }
 
 private:
 CTCDecodeHelper decode_helper_;
 ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_;
 bool merge_repeated_;
 int beam_width_;
 // Two new data members that store the newly added attribute values
 int label_selection_size;
 float label_selection_margin;
 TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderWithParamOp);
};
 
REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoderWithParam").Device(DEVICE_CPU),CTCBeamSearchDecoderWithParamOp);

Compiling the custom Op into a .so file

Create a new folder custom_op under the tensorflow-master directory.

cd custom_op

Create a BUILD file inside it and add the following:

cc_library(
    name = "ctc_decoder_with_param",
    srcs = [
        "new_beamsearch.cc"
    ] + glob(["boost_locale/**/*.hpp"]),
    includes = ["boost_locale"],
    copts = ["-std=c++11"],
    deps = [
        "//tensorflow/core:core",
        "//tensorflow/core/util/ctc",
        "//third_party/eigen3",
    ],
)

Build process:

1. cd into the tensorflow-master directory

2. bazel build -c opt --copt=-O3 //tensorflow:libtensorflow_cc.so //custom_op:ctc_decoder_with_param

3. libctc_decoder_with_param.so is generated under bazel-bin/custom_op

Using the custom Op in a training (or inference) program

Define the following function in your program:

decode_param_op_module = tf.load_op_library('libctc_decoder_with_param.so')

def decode_with_param(inputs, sequence_length, beam_width=100,
                      top_paths=1, merge_repeated=True):
  decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
    decode_param_op_module.ctc_beam_search_decoder_with_param(
      inputs, sequence_length, beam_width=beam_width, top_paths=top_paths,
      merge_repeated=merge_repeated, label_selection_size=40,
      label_selection_margin=0.99))
  return (
    [tf.SparseTensor(ix, val, shape) for (ix, val, shape)
     in zip(decoded_ixs, decoded_vals, decoded_shapes)],
    log_probabilities)

The Op can then be used just like tf.nn.ctc_beam_search_decoder.
