1. 程式人生 > >18、使用 tf.app.flags 介面定義命令列引數

18、使用 tf.app.flags 介面定義命令列引數

一、使用 tf.app.flags 介面定義命令列引數

  • 眾所周知,深度學習有很多的 Hyperparameter 需要調優,TensorFlow 底層使用了python-gflags專案,然後封裝成tf.app.flags介面
  • 使用tf.app.flags介面可以非常方便的呼叫自帶的DEFINE_string, DEFINE_boolean, DEFINE_integer, DEFINE_float設定不同型別的命令列引數及其預設值,在實際專案中一般會提前定義命令列引數,如下所示:
# coding: utf-8
# filename: flags.py
import tensorflow as
tf # 定義一個全域性物件來獲取引數的值,在程式中使用(eg:FLAGS.iteration)來引用引數 FLAGS = tf.app.flags.FLAGS # 定義命令列引數,第一個是:引數名稱,第二個是:引數預設值,第三個是:引數描述 tf.app.flags.DEFINE_integer("iteration", 200000, "Iterations to train [2e5]") tf.app.flags.DEFINE_integer("disp_freq", 1000, "Display the current results every display_freq iterations [1e3]"
) tf.app.flags.DEFINE_integer("save_freq", 2000, "Save the checkpoints every save_freq iterations [2e3]") tf.app.flags.DEFINE_float("learning_rate", 0.001, "Learning rate of for adam [0.001]") tf.app.flags.DEFINE_integer("train_batch_size", 64, "The size of batch images [64]") tf.app.flags.DEFINE_integer("val_batch_size"
, 100, "The size of batch images [100]") tf.app.flags.DEFINE_integer("height", 48, "The height of image to use. [48]") tf.app.flags.DEFINE_integer("width", 160, "The width of image to use. [160]") tf.app.flags.DEFINE_integer("depth", 3, "Dimension of image color. [3]") tf.app.flags.DEFINE_string("data_dir", "/path/to/data_sets/", "Directory of dataset in the form of TFRecords.") tf.app.flags.DEFINE_string("checkpoint_dir", "/path/to/checkpoint_save_dir/", "Directory name to save the checkpoints [checkpoint]") tf.app.flags.DEFINE_string("model_name", "40w_grtr", "Model name. [40w_grtr]") tf.app.flags.DEFINE_string("gpu_id", "0", "Which GPU to be used. [0]") tf.app.flags.DEFINE_boolean("continue_train", False, "True for continue training.[False]") tf.app.flags.DEFINE_boolean("per_image_standardization", True, "True for per_image_standardization.[True]") # 定義主函式 def main(argv=None): print(FLAGS.iteration) print(FLAGS.learning_rate) print(FLAGS.data_dir) print(FLAGS.continue_train) # 執行main函式 if __name__ == '__main__': tf.app.run()

二、執行程式的方法

1、使用程式中的預設引數

  • python flags.py

這裡寫圖片描述

2、在命令列更改程式中的預設引數

  • python flags.py --iteration=500000 --learning_rate=0.01 --data_dir='/home/test/' --continue_train=True

這裡寫圖片描述