
Learning TensorFlow distributed training

Reference article:

https://zhuanlan.zhihu.com/p/41473323

 

TensorFlow 1.4 introduced the tf.estimator.train_and_evaluate function to replace the functionality previously provided by the Experiment class. tf.estimator.train_and_evaluate simplifies training, evaluating, and exporting Estimator models, and abstracts away the details of distributed training and evaluation, so the same code behaves consistently both locally and on a distributed cluster.
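In distributed mode, train_and_evaluate reads the cluster layout and the current process's role from the TF_CONFIG environment variable, so the training script itself does not change between a laptop and a cluster. A minimal sketch of what one worker might set before launching the script (hostnames and ports below are placeholders):

import json
import os

# Hypothetical cluster: one chief, one extra worker, one parameter server.
# Every machine runs the same training script but sets its own "task" entry.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "chief": ["host0.example.com:2222"],
        "worker": ["host1.example.com:2222"],
        "ps": ["host2.example.com:2222"],
    },
    "task": {"type": "worker", "index": 0},
})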

 

This post briefly shows how to define a custom Estimator model and use tf.estimator.train_and_evaluate to train and evaluate it.

 

Main steps:

    1. Build your own Estimator model
    2. Define how data is fed to the model during training and evaluation
    3. Define the training, evaluation, and export specifications (TrainSpec and EvalSpec) passed to tf.estimator.train_and_evaluate
    4. Train and evaluate the model with tf.estimator.train_and_evaluate

 

The dataset we use is the DBPedia Ontology Classification Dataset (available for download). It was built from DBpedia 2014 by selecting 14 non-overlapping ontology classes (Company, EducationalInstitution, Artist, Athlete, OfficeHolder, MeanOfTransportation, Building, NaturalPlace, Village, Animal, Plant, Album, Film, WrittenWork). For each class, 40,000 training samples and 5,000 test samples were randomly selected, giving 560,000 training samples and 70,000 test samples in total.
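Each line of the dataset's CSV files carries a class index (1 to 14), a title, and an abstract, roughly like the following (an illustrative line, not quoted from the actual files):

1,"E. D. Abbott Ltd","Abbott of Farnham E D Abbott Limited was a British coachbuilding business ..."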

 

Model code

def my_model(features, labels, mode, params):
  sentence = features['sentence']
  # Get word embeddings for each token in the sentence
  embeddings = tf.get_variable(name="embeddings", dtype=tf.float32,
                               shape=[params["vocab_size"], FLAGS.embedding_size])
  sentence = tf.nn.embedding_lookup(embeddings, sentence)  # shape: (batch, sentence_len, embedding_size)

  # Add a channel dim, required by the conv2d and max_pooling2d methods
  sentence = tf.expand_dims(sentence, -1)  # shape: (batch, sentence_len/height, embedding_size/width, channels=1)

  pooled_outputs = []
  for filter_size in params["filter_sizes"]:
    conv = tf.layers.conv2d(
        sentence,
        filters=FLAGS.num_filters,
        kernel_size=[filter_size, FLAGS.embedding_size],
        strides=(1, 1),
        padding="VALID",
        activation=tf.nn.relu)
    pool = tf.layers.max_pooling2d(
        conv,
        pool_size=[FLAGS.sentence_max_len - filter_size + 1, 1],
        strides=(1, 1),
        padding="VALID")
    pooled_outputs.append(pool)
  h_pool = tf.concat(pooled_outputs, 3)  # shape: (batch, 1, 1, num_filters * len(filter_sizes))
  h_pool_flat = tf.reshape(h_pool, [-1, FLAGS.num_filters * len(params["filter_sizes"])])
  # shape: (batch, num_filters * len(filter_sizes))

  if 'dropout_rate' in params and params['dropout_rate'] > 0.0:
    h_pool_flat = tf.layers.dropout(h_pool_flat, params['dropout_rate'],
                                    training=(mode == tf.estimator.ModeKeys.TRAIN))

  logits = tf.layers.dense(h_pool_flat, FLAGS.num_classes, activation=None)

  optimizer = tf.train.AdagradOptimizer(learning_rate=params['learning_rate'])
  def _train_op_fn(loss):
    return optimizer.minimize(loss, global_step=tf.train.get_global_step())

  my_head = tf.contrib.estimator.multi_class_head(n_classes=FLAGS.num_classes)
  return my_head.create_estimator_spec(
      features=features,
      mode=mode,
      labels=labels,
      logits=logits,
      train_op_fn=_train_op_fn
  )
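Using multi_class_head saves writing the per-mode EstimatorSpec plumbing by hand. In TRAIN mode, what create_estimator_spec produces is roughly equivalent to the following simplified sketch (the real head also builds eval metrics and prediction/export outputs):

# Simplified view of the head's TRAIN-mode behavior: labels are sparse
# class indices, so the loss is a sparse softmax cross-entropy.
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
spec = tf.estimator.EstimatorSpec(
    mode=tf.estimator.ModeKeys.TRAIN,
    loss=loss,
    train_op=_train_op_fn(loss))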

 

Dataset code

def input_fn(path_csv, path_vocab, shuffle_buffer_size, num_oov_buckets):
  vocab = tf.contrib.lookup.index_table_from_file(path_vocab, num_oov_buckets=num_oov_buckets)
  # Load the csv file, one example per line
  dataset = tf.data.TextLineDataset(path_csv)
  # Convert each line into a list of tokens, splitting on whitespace, then map
  # each token to a unique id (parse_line is not shown here; a sketch follows below)
  dataset = dataset.map(lambda line: parse_line(line, vocab))
  if shuffle_buffer_size > 0:
    dataset = dataset.shuffle(shuffle_buffer_size).repeat()
  dataset = dataset.batch(FLAGS.batch_size).prefetch(1)
  return dataset
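parse_line is not included in the excerpt above. A minimal sketch of what it needs to do, assuming each CSV line is label,"title","content" with a 1-based class index (the '<pad>' token and the naive tokenization here are illustrative choices, not necessarily the original article's exact code):

def parse_line(line, vocab):
  # Assumed CSV layout: label,"title","content" (illustrative sketch).
  fields = tf.decode_csv(line, record_defaults=[[0], [''], ['']])
  label = fields[0] - 1                         # 1-based class index -> 0-based
  tokens = tf.string_split([fields[2]]).values  # naive whitespace tokenization
  tokens = tokens[:FLAGS.sentence_max_len]      # truncate long sentences
  # Pad short sentences to the fixed length the conv layers expect.
  pad_len = FLAGS.sentence_max_len - tf.size(tokens)
  tokens = tf.concat([tokens, tf.fill([pad_len], '<pad>')], axis=0)
  ids = vocab.lookup(tokens)                    # token string -> integer id
  return {'sentence': ids}, label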

 

Training

classifier = tf.estimator.Estimator(
    model_fn=my_model,
    params={
        'vocab_size': config["vocab_size"],
        # list(...) so the sizes can be iterated and len()-ed more than once
        # (in Python 3, map() returns a one-shot iterator)
        'filter_sizes': list(map(int, FLAGS.filter_sizes.split(','))),
        'learning_rate': FLAGS.learning_rate,
        'dropout_rate': FLAGS.dropout_rate
    },
    config=tf.estimator.RunConfig(model_dir=FLAGS.model_dir,
                                  save_checkpoints_steps=FLAGS.save_checkpoints_steps)
)

train_spec = tf.estimator.TrainSpec(
    input_fn=lambda: input_fn(path_train, path_words, FLAGS.shuffle_buffer_size,
                              config["num_oov_buckets"]),
    max_steps=FLAGS.train_steps
)
input_fn_for_eval = lambda: input_fn(path_eval, path_words, 0, config["num_oov_buckets"])
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_for_eval, throttle_secs=300)

tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
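The export step mentioned in the intro plugs into EvalSpec via exporters: after each evaluation, the latest checkpoint is written out as a SavedModel. A possible addition (the serving_input_receiver_fn below is a hypothetical sketch that feeds already id-encoded, padded sentences):

def serving_input_receiver_fn():
  # Serving-time input: a batch of tokenized, id-encoded, padded sentences.
  sentence = tf.placeholder(tf.int64, shape=[None, FLAGS.sentence_max_len],
                            name='sentence')
  return tf.estimator.export.ServingInputReceiver(
      features={'sentence': sentence},
      receiver_tensors={'sentence': sentence})

exporter = tf.estimator.LatestExporter('saved_model', serving_input_receiver_fn)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_for_eval,
                                  exporters=exporter,
                                  throttle_secs=300)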