1. 程式人生 > 其它 >調優哪家強——tensorflow命令列引數

調優哪家強——tensorflow命令列引數

深度學習神經網路往往有過多的Hyperparameter需要調優,優化演算法、學習率、卷積核尺寸等很多引數都需要不斷調整,使用命令列引數是非常方便的。有兩種實現方式,一是利用python的argparse包,二是呼叫tensorflow自帶的app.flags實現。

利用python的argparse包

argparse介紹及基本使用:

http://www.jianshu.com/p/b8b09084bd1a

下面程式碼用argparse實現了命令列引數的輸入。

import argparse
import sys
parser = argparse.ArgumentParser()
parser.add_argument('--fake_data', nargs='?', const=True, type=bool,                       
default=False,                       
help='If true, uses fake data for unit testing.')
parser.add_argument('--max_steps', type=int, default=1000,                       
help='Number of steps to run trainer.')
parser.add_argument('--learning_rate', type=float, default=0.001,                       
help='Initial learning rate')
parser.add_argument('--dropout', type=float, default=0.9,                       
help='Keep probability for training dropout.')
parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',                       help='Directory for storing input data') parser.add_argument('--log_dir', type=str, default='/tmp/tensorflow/mnist/logs/mnist_with_summaries',                       
help='Summaries log directory') FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

通過呼叫python的argparse包,呼叫函式parser.parse_known_args()解析命令列引數。程式碼執行後得到的FLAGS是一個結構體,內部引數分別為:

FLAGS.data_dir
Out[5]: '/tmp/tensorflow/mnist/input_data'
 FLAGS.fake_data Out[6]: False  FLAGS.max_steps
Out[7]: 1000
 FLAGS.learning_rate
Out[8]: 0.001
 FLAGS.dropout
Out[9]: 0.9
 FLAGS.data_dir
Out[10]: '/tmp/tensorflow/mnist/input_data'
 FLAGS.log_dir
Out[11]: '/tmp/tensorflow/mnist/logs/mnist_with_summaries'

利用tf.app.flags元件

首先需要定義一個tf.app.flags物件,呼叫自帶的DEFINE_string, DEFINE_boolean, DEFINE_integer, DEFINE_float設定不同型別的命令列引數及其預設值。當然,也可以在終端用命令列引數修改這些預設值。

# Define hyperparameters
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_boolean("enable_colored_log", False, "Enable colored log")                     
"The glob pattern of train TFRecords files")
flags.DEFINE_string("validate_tfrecords_file",                     
"./data/a8a/a8a_test.libsvm.tfrecords",     
"The glob pattern of validate TFRecords files")
flags.DEFINE_integer("label_size", 2, "Number of label size")
flags.DEFINE_float("learning_rate", 0.01, "The learning rate")
 def main():    
 # Get hyperparameters     
if FLAGS.enable_colored_log:         
import coloredlogs         
coloredlogs.install()     
logging.basicConfig(level=logging.INFO)     
FEATURE_SIZE = FLAGS.feature_size     
LABEL_SIZE = FLAGS.label_size       
...   
return 0
if __name__ == ‘__main__’:     main()

這段程式碼採用的是tensorflow庫中自帶的tf.app.flags模組實現命令列引數的解析。如果用終端執行tf程式,用上述兩種方式都可以,如果用spyder之類的工具,那麼只有第一種方式有用,第二種方式會報錯。

其中有個tf.app.flags元件,還有個tf.app.run()函式。官網幫助檔案是這麼說的:

flags module: Implementation of the flags interface.
run(...): Runs the program with an optional 'main' function and 'argv' list.

tf.app.run的原始碼:

1."""Generic entry point script."""   
2.from __future__ import absolute_import   
3.from __future__ import division   
4.from __future__ import print_function   
5.   
6.import sys   
7.   
8.from tensorflow.python.platform import flags   
9.   
10.   
11.def run(main=None):   
12.  f = flags.FLAGS   
13.  f._parse_flags()   
14.  main = main or sys.modules['__main__'].main   
15.  sys.exit(main(sys.argv))

也就是處理flag解析,然後執行main函式。

用shell指令碼實現訓練程式碼的執行

在終端執行python程式碼,首先需要在程式碼檔案開頭寫入shebang,告訴系統環境變數如何設定,用python2還是用python3來編譯這段程式碼。然後修改程式碼許可權為可執行,用 ./python_code.py 就可以執行。同理,這段程式碼也可以用shell指令碼來實現。建立.sh檔案,執行python_code.py並設定引數max_steps=100

python python_code.py --max_steps 100