
Dataset Conversion for AI Frameworks in C++ (Part 1): TFRecord Datasets

Lately I have been studying how to convert various datasets into the native formats of different AI frameworks, including Caffe, MXNet, and TensorFlow. C++ is a general-purpose and powerful language, yet C++ users face a real pain point in the AI era: the major frameworks are implemented in C/C++ under the hood, but the interfaces they expose are mostly Python, and those Python interfaces are very well packaged. MXNet is the friendliest, shipping C/C++ sources such as im2rec.cc; with Caffe, and especially with TensorFlow, converting data in C++ takes some extra work. This article therefore starts with dataset conversion for TensorFlow.

1. What is each framework's native dataset format?

MXNet's native dataset: the rec format

Caffe's native dataset: the LMDB format

TensorFlow's native dataset: the TFRecord format

2. What is the TFRecord format?

For reading data, the TensorFlow documentation describes three approaches:
1. Feeding: Python code supplies the data at each step of the running program.
2. Reading from files: an input pipeline reads the data from files.
3. Preloaded data: if the dataset is small enough, constants or variables can hold all of it.

TFRecord is the standard format officially recommended by TensorFlow. A TFRecord file is a binary file that stores image data and labels together; it uses memory efficiently and can be copied, moved, read, and stored quickly within TensorFlow.

The dataset is defined by an example.proto file:

syntax = "proto3";

// Containers to hold repeated fundamental values.
message BytesList {
  repeated bytes value = 1;
}
message FloatList {
  repeated float value = 1 [packed = true];
}
message Int64List {
  repeated int64 value = 1 [packed = true];
}

message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
}

message Features {
  map<string, Feature> feature = 1;
}

message Example {
  Features features = 1;
}

This is a proto3 definition. Generate the header example.pb.h and source example.pb.cc from it with:

protoc -I=. --cpp_out=./ example.proto

3. How should your dataset be prepared?

This section uses the VOC2007 dataset as the detection example and the LFW dataset as the classification example.

For a classification task, build a list like the one below (the separator between file name and label is a tab, \t, not a space); its construction mirrors Caffe's classification lists:

/output/oldFile/1000015_10/wKgB5Fr6WwWAJb7iAAABKohu5Nw109.png   0
/output/oldFile/1000015_10/wKgB5Fr6WwWAEbg6AAABC_mxdD8880.png   0
/output/oldFile/1000015_10/wKgB5Fr6WwWAUGTdAAAA8wVERrQ677.png   0
/output/oldFile/1000015_10/wKgB5Fr6WwWAPJ-lAAABPYAoeuY242.png   0
/output/oldFile/1000015_10/wKgB5Fr6WwWARVIWAAABCK2alGs331.png   0
/output/oldFile/1000015_10/wKgB5Fr6WwWAV3R5AAAA5573dko147.png   0
/output/oldFile/1000015_10/wKgB5Fr6WwaAUjQRAAABIkYxqoY008.png   0
...
/output/oldFile/1000015_10/wKgB5Vr6YF-AALG-AAAA-qStI_Q208.png   1
/output/oldFile/1000015_10/wKgB5Vr6YGCAe1VYAAABN5fz53Y240.png   1
/output/oldFile/1000015_10/wKgB5Vr6YGCAQo7fAAABVFasXJ4223.png   1
/output/oldFile/1000015_10/wKgB5Vr6YGCAL00yAAABJdrU4U0508.png   1
/output/oldFile/1000015_10/wKgB5Vr6YGCAFjTyAAABJVgoCrU242.png   1
/output/oldFile/1000015_10/wKgB5Vr6YGCAKmMMAAABMd1_pJg240.png   1
/output/oldFile/1000015_10/wKgB5Vr6YGCAR2FqAAABFCQ7LRY651.png   1

For the VOC2007 dataset, the list looks like this (again, the separator between the two file names is a tab, \t, not a space):

/home/test/data/VOC2007/JPEGImages/004379.jpg /home/xbx/data/VOC2007/Annotations/004379.xml
/home/test/data/VOC2007/JPEGImages/001488.jpg /home/xbx/data/VOC2007/Annotations/001488.xml
/home/test/data/VOC2007/JPEGImages/004105.jpg /home/xbx/data/VOC2007/Annotations/004105.xml
/home/test/data/VOC2007/JPEGImages/006146.jpg /home/xbx/data/VOC2007/Annotations/006146.xml
/home/test/data/VOC2007/JPEGImages/004295.jpg /home/xbx/data/VOC2007/Annotations/004295.xml
/home/test/data/VOC2007/JPEGImages/001360.jpg /home/xbx/data/VOC2007/Annotations/001360.xml
/home/test/data/VOC2007/JPEGImages/003468.jpg /home/xbx/data/VOC2007/Annotations/003468.xml
...

4. What does the conversion pipeline look like?

With the data list prepared, the conversion flow can be laid out. For a classification task:
1. Initialize a RecordWriter.
2. Process the list line by line; each line (an image path plus its label) becomes one Example.
3. Read the image into a cv::Mat with OpenCV, convert it to a std::string (not a char*, because image data may contain \0), and store it in the Example's feature map under the key image_raw.
4. Store the image width, height, channel count, and label in the feature map under the keys width, height, depth, and label.
5. Serialize each Example with SerializeToString and write the resulting string with writer_->WriteRecord.
For a detection task, the only difference is parsing the XML annotation file and storing the bbox information as well.

The required headers are:

#include <fcntl.h>
#include <stdio.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <boost/foreach.hpp>
#include <boost/property_tree/ptree.hpp>
#include <boost/property_tree/xml_parser.hpp>
#include <fstream>
#include <iostream>
#include <map>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <vector>

#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/record_writer.h"

#include <boost/lexical_cast.hpp>
#include "rng.hpp"
using namespace tensorflow::io;
using namespace tensorflow;

The dispatch logic in the main function:

  if ((dataset_type == "object_detect") && (label_map_file.length() > 0)) {
    // Detection task: datalist_file is the list file, label_map_file maps class names to labels,
    // output_dir is the tfrecord output directory, output_name the output file name,
    // samples_pre the number of records per tfrecord file, and Shuffle whether to shuffle
    if (!detecteddata_to_tfrecords(datalist_file, label_map_file, output_dir, output_name,
                                   samples_pre, Shuffle)) {
      printf("convert wrong!!!\n");
      return false;
    }
  } else if ((dataset_type == "classification") && (label_width > 0)) {
    // Classification task: datalist_file is the list file, output_dir the tfrecord output directory,
    // output_name the output file name, samples_pre the number of records per tfrecord file,
    // label_width the number of labels (single- vs multi-label), and Shuffle whether to shuffle
    if (!clsdata_to_tfrecords(datalist_file, output_dir, output_name, samples_pre, label_width,
                              Shuffle)) {
      printf("convert wrong!!!\n");
      return false;
    }
  } else {
    printf(
        "dataset type is not object_detect or classification, or label_width [%lu], label_map_file "
        "[%s] is wrong!!!\n",
        label_width, label_map_file.c_str());
    return false;
  }

  // Optional: delete all global objects allocated by libprotobuf,
  // cleaning up the protobuf resources opened in the helper functions
  google::protobuf::ShutdownProtobufLibrary();

For the classification task, the code is as follows:

bool clsdata_to_tfrecords(string datalist_file, string output_dir, string output_name,
                          int samples_pre, size_t label_width, int Shuffle) {
  std::ifstream infile(datalist_file.c_str());
  std::string line;
  std::vector<std::pair<string, std::vector<int> > > dataset;

  // Read the list file and store each (filename, labels) pair in dataset
  while (getline(infile, line)) {
    vector<string> tmp_str = param_split(line, "\t");
    std::string filename;
    std::vector<int> label_v;
    if (tmp_str.size() != (label_width + 1)) {
      std::cout << "line " << line << " has the wrong number of fields!!!" << std::endl;
      return false;
    }
    for (size_t i = 0; i < (label_width + 1); ++i) {
      if (i == 0) {
        filename = tmp_str[0];
      } else {
        try {
          int label = boost::lexical_cast<int>(tmp_str[i]);
          label_v.push_back(label);
        } catch (boost::bad_lexical_cast& e) {
          printf("%s\n", e.what());
          return false;
        }
      }
    }
    if (filename.size() > 0) dataset.push_back(std::make_pair(filename, label_v));
  }
  
  // Shuffle the dataset (borrows the rng.hpp code from Caffe)
  if (Shuffle) {
    printf("tensorflow task will be shuffled!!!\n");
    caffe::shuffle(dataset.begin(), dataset.end());
  }

  printf("A total of %lu images.\n", dataset.size());

  // create recordwriter
  std::unique_ptr<WritableFile> file;

  RecordWriterOptions options = RecordWriterOptions::CreateRecordWriterOptions("ZLIB");

  RecordWriter* writer_ = NULL;

  int j = 0, fidx = 0;
  size_t line_id = 0;
  for (line_id = 0; line_id < dataset.size(); ++line_id) {
    if (line_id == 0 || j > samples_pre) {
      // On the first record, or once the current file has reached the
      // samples_pre limit, rotate to a fresh RecordWriter
      if (writer_ != NULL) {
        delete writer_;
        writer_ = NULL;
      }

      char output_file[1024];
      memset(output_file, 0, 1024);

      sprintf(output_file, "%s/%s_%03d.tfrecord", output_dir.c_str(), output_name.c_str(), fidx);
      printf("create new tfrecord file: [%s] \n", output_file);

      Status s = Env::Default()->NewWritableFile((string)output_file, &file);
      if (!s.ok()) {
        printf("create write record file [%s] wrong!!!\n", output_file);
        return false;
      }

      writer_ = new RecordWriter(file.get(), options);
      j = 0;
      fidx += 1;
    }
    // Read the image; skip the line if it cannot be decoded
    cv::Mat image = ReadImageToCVMat(dataset[line_id].first);
    if (!image.data) {
      printf("image: [%s] could not be read, skipped\n", dataset[line_id].first.c_str());
      continue;
    }
    // Convert the Mat to a string (not char*: pixel data may contain \0)
    std::string image_b = matToBytes(image);
    int height = image.rows;
    int width = image.cols;
    int depth = image.channels();

    // Each line of the list becomes one Example
    Example example1;
    Features* features1 = example1.mutable_features();
    ::google::protobuf::Map<string, Feature>* feature1 = features1->mutable_feature();
    Feature feature_tmp;

    feature_tmp.Clear();
    if (!bytes_feature(feature_tmp, image_b)) {
      printf("image: [%s] wrong\n", dataset[line_id].first.c_str());
      continue;
    }

    (*feature1)["image_raw"] = feature_tmp;
    feature_tmp.Clear();
    if (!int64_feature(feature_tmp, height)) {
      printf("image: [%s] , height [%d] wrong\n", dataset[line_id].first.c_str(), height);
      continue;
    }
    (*feature1)["height"] = feature_tmp;

    feature_tmp.Clear();
    if (!int64_feature(feature_tmp, width)) {
      printf("image: [%s] , width [%d] wrong\n", dataset[line_id].first.c_str(), width);
      continue;
    }

    (*feature1)["width"] = feature_tmp;

    feature_tmp.Clear();
    if (!int64_feature(feature_tmp, depth)) {
      printf("image: [%s] , depth [%d] wrong\n", dataset[line_id].first.c_str(), depth);
      continue;
    }
    (*feature1)["depth"] = feature_tmp;

    // Assumes classification labels are already numeric (0, 1, 2, ...);
    // otherwise add name-to-label conversion here
    feature_tmp.Clear();
    if (!int64_feature(feature_tmp, dataset[line_id].second)) {
      printf("image: [%s]  wrong\n", dataset[line_id].first.c_str());
      continue;
    }
    (*feature1)["label"] = feature_tmp;
    // Serialize the Example to a string and write it through writer_
    std::string str;
    example1.SerializeToString(&str);
    writer_->WriteRecord(str);
    ++j;

    if (line_id % 1000 == 0) {
      printf("Processed %lu files.\n", line_id);
    }
  }

  printf("Processed %lu files. Finished.\n", line_id);

  if (writer_ != NULL) {
    delete writer_;
    writer_ = NULL;
  }

  return true;
}

  

The matToBytes function is defined as follows:

std::string matToBytes(cv::Mat image) {
  // Build the string straight from the pixel buffer: the explicit length
  // preserves embedded '\0' bytes, and no temporary allocation can leak.
  // (Assumes a continuous Mat, which cv::imread produces.)
  size_t size = image.total() * image.elemSize();
  return std::string(reinterpret_cast<const char*>(image.data), size);
}

The helpers that convert a string, an int, or a vector<int> into a Feature are defined as follows:

// Overloads so that both int and vector<int> can be converted to a Feature
bool int64_feature(Feature& feature, int value) {
  Int64List* i_list1 = feature.mutable_int64_list();
  i_list1->add_value(value);
  return true;
}
bool int64_feature(Feature& feature, std::vector<int> value) {
  if (value.size() < 1) {
    printf("value int64 is wrong!!!\n");
    return false;
  }
  Int64List* i_list1 = feature.mutable_int64_list();
  for (size_t i = 0; i < value.size(); ++i) i_list1->add_value(value[i]);
  return true;
}

bool float_feature(Feature& feature, std::vector<double> value) {
  if (value.size() < 1) {
    printf("value float is wrong!!!\n");
    return false;
  }
  FloatList* f_list1 = feature.mutable_float_list();
  for (size_t i = 0; i < value.size(); ++i) f_list1->add_value(value[i]);
  return true;
}

// Convert image bytes to a Feature
bool bytes_feature(Feature& feature, std::string value) {
  BytesList* b_list1 = feature.mutable_bytes_list();
  // value carries an explicit length, so embedded 0 bytes survive intact
  b_list1->add_value(value);
  return true;
}

For the detection task the overall flow is the same; the list-reading code differs slightly, and the XML annotation files must additionally be parsed, which can be done with Boost's property_tree XML parser. The code is roughly as follows:

bool ReadXMLToExapmle(const string& image_file, const string& xmlfile, const int img_height,
                      const int img_width, const std::map<string, int>& name_to_label,
                      RecordWriter* writer_) {
  // Read the image
  cv::Mat image = ReadImageToCVMat(image_file);
  if (!image.data) {
    cout << "Could not open or find file " << image_file;
    return false;
  }
  // Convert the Mat to a string
  std::string image_b = matToBytes(image);

  Example example1;
  Features* features1 = example1.mutable_features();
  ::google::protobuf::Map<string, Feature>* feature1 = features1->mutable_feature();
  Feature feature_tmp;

  feature_tmp.Clear();
  if (!bytes_feature(feature_tmp, image_b)) {
    printf("image: [%s] wrong\n", image_file.c_str());
    return false;
  }

  (*feature1)["image/encoded"] = feature_tmp;

  ptree pt;
  read_xml(xmlfile, pt);

  // Parse annotation.
  int width = 0, height = 0, depth = 0;
  try {
    height = pt.get<int>("annotation.size.height");
    width = pt.get<int>("annotation.size.width");
    depth = pt.get<int>("annotation.size.depth");
  } catch (const ptree_error& e) {
    std::cout << "when parsing " << xmlfile << ": " << e.what() << std::endl;
    return false;
  }

  feature_tmp.Clear();
  if (!int64_feature(feature_tmp, height)) {
    printf("xml : [%s] 's height wrong\n", xmlfile.c_str());
    return false;
  }

  (*feature1)["image/height"] = feature_tmp;

  feature_tmp.Clear();
  if (!int64_feature(feature_tmp, width)) {
    printf("xml : [%s] 's width wrong\n", xmlfile.c_str());
    return false;
  }

  (*feature1)["image/width"] = feature_tmp;

  feature_tmp.Clear();
  if (!int64_feature(feature_tmp, depth)) {
    printf("xml : [%s] 's depth wrong\n", xmlfile.c_str());
    return false;
  }

  (*feature1)["image/depth"] = feature_tmp;

  std::vector<int> v_label;
  std::vector<int> v_difficult;
  std::vector<double> v_xmin;
  std::vector<double> v_ymin;
  std::vector<double> v_xmax;
  std::vector<double> v_ymax;

  BOOST_FOREACH (ptree::value_type& v1, pt.get_child("annotation")) {
    ptree pt1 = v1.second;
    if (v1.first == "object") {
      bool difficult = false;
      ptree object = v1.second;
      BOOST_FOREACH (ptree::value_type& v2, object.get_child("")) {
        ptree pt2 = v2.second;
        if (v2.first == "name") {
          string name = pt2.data();
          if (name_to_label.find(name) == name_to_label.end()) {
            std::cout << "file : [" << xmlfile << "] Unknown name: " << name << std::endl;
            return true;
          }
          int label = name_to_label.find(name)->second;
          v_label.push_back(label);
        } else if (v2.first == "difficult") {
          difficult = pt2.data() == "1";
          v_difficult.push_back(difficult);
        } else if (v2.first == "bndbox") {
          int xmin = pt2.get("xmin", 0);
          int ymin = pt2.get("ymin", 0);
          int xmax = pt2.get("xmax", 0);
          int ymax = pt2.get("ymax", 0);

          if ((xmin > width) || (ymin > height) || (xmax > width) || (ymax > height) ||
              (xmin < 0) || (ymin < 0) || (xmax < 0) || (ymax < 0)) {
            std::cout << "bounding box exceeds image boundary." << std::endl;
            return false;
          }
          v_xmin.push_back(xmin);
          v_ymin.push_back(ymin);
          v_xmax.push_back(xmax);
          v_ymax.push_back(ymax);
        }
      }
    }
  }
  feature_tmp.Clear();
  if (!int64_feature(feature_tmp, v_label)) {
    printf("xml : [%s]'s label wrong\n", xmlfile.c_str());
    return false; 
  }       
          
  (*feature1)["image/object/bbox/label"] = feature_tmp;
  feature_tmp.Clear();
  if (!int64_feature(feature_tmp, v_difficult)) {
    printf("xml : [%s]'s difficult wrong\n", xmlfile.c_str());
    return false;
  }         
            
  (*feature1)["image/object/bbox/difficult"] = feature_tmp;
  feature_tmp.Clear();
  if (!float_feature(feature_tmp, v_xmin)) {
    printf("xml : [%s]'s v_xmin wrong\n", xmlfile.c_str());
    return false;
  }

  (*feature1)["image/object/bbox/xmin"] = feature_tmp;
  feature_tmp.Clear();
  if (!float_feature(feature_tmp, v_ymin)) {
    printf("xml : [%s]'s v_ymin wrong\n", xmlfile.c_str());
    return false;
  }

  (*feature1)["image/object/bbox/ymin"] = feature_tmp;
  feature_tmp.Clear();
  if (!float_feature(feature_tmp, v_xmax)) {
    printf("xml : [%s]'s v_xmax wrong\n", xmlfile.c_str());
    return false;
  }

  (*feature1)["image/object/bbox/xmax"] = feature_tmp;
  feature_tmp.Clear();
  if (!float_feature(feature_tmp, v_ymax)) {
    printf("xml : [%s]'s v_ymax wrong\n", xmlfile.c_str());
    return false;
  }

  (*feature1)["image/object/bbox/ymax"] = feature_tmp;
  // Serialize the Example and write it through the RecordWriter
  std::string str;
  example1.SerializeToString(&str);
  
  writer_->WriteRecord(str);
  return true;
}
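For reference, the ptree paths accessed above (annotation.size.height, and the name, difficult, and bndbox children of each object) assume the standard VOC annotation layout, abbreviated here (values are illustrative):

```xml
<annotation>
  <size>
    <width>500</width>
    <height>375</height>
    <depth>3</depth>
  </size>
  <object>
    <name>dog</name>
    <difficult>0</difficult>
    <bndbox>
      <xmin>48</xmin>
      <ymin>240</ymin>
      <xmax>195</xmax>
      <ymax>371</ymax>
    </bndbox>
  </object>
</annotation>
```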

Finally, the Makefile for building everything:

all:
	rm -rf example.pb*
	${PROTOBUF_HOME}/bin/protoc -I=. --cpp_out=./ example.proto
	${PROTOBUF_HOME}/bin/protoc -I=. --cpp_out=./ label.proto
	g++ -std=c++11 -o dataset_to_tfrecord dataset_to_tfrecord.cc example.pb.cc common.cpp \
	    -I/usr/local/opencv2/include -L/usr/local/opencv2/lib -L. \
	    -lopencv_core -lopencv_highgui -lopencv_imgproc \
	    -I${TENSORFLOW_HOME} -I${TENSORFLOW_HOME}/bazel-genfiles \
	    -I${PROTOBUF_HOME}/include -I/usr/local/include/eigen3 \
	    -L${PROTOBUF_HOME}/lib -L${TENSORFLOW_HOME}/bazel-bin/tensorflow/ \
	    -lprotobuf -ltensorflow_framework \
	    -I${JSONCPP_HOME}/include -L${JSONCPP_HOME}/lib -ljson_linux-gcc-5.4.0_libmt