調優哪家強——tensorflow命令列引數
利用python的argparse包
argparse介紹及基本使用:
http://www.jianshu.com/p/b8b09084bd1a
下面程式碼用argparse實現了命令列引數的輸入。
import argparse import sys parser = argparse.ArgumentParser() parser.add_argument('--fake_data', nargs='?', const=True, type=bool, default=False, help='If true, uses fake data for unit testing.') parser.add_argument('--max_steps', type=int, default=1000, help='Number of steps to run trainer.') parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate') parser.add_argument('--dropout', type=float, default=0.9, help='Keep probability for training dropout.') parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', help='Directory for storing input data') parser.add_argument('--log_dir', type=str, default='/tmp/tensorflow/mnist/logs/mnist_with_summaries', help='Summaries log directory') FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
通過呼叫python的argparse包,呼叫函式parser.parse_known_args()解析命令列引數。程式碼執行後得到的FLAGS是一個結構體,內部引數分別為:
FLAGS.data_dir Out[5]: '/tmp/tensorflow/mnist/input_data' FLAGS.fake_data Out[6]: False FLAGS.max_steps Out[7]: 1000 FLAGS.learning_rate Out[8]: 0.001 FLAGS.dropout Out[9]: 0.9 FLAGS.data_dir Out[10]: '/tmp/tensorflow/mnist/input_data' FLAGS.log_dir Out[11]: '/tmp/tensorflow/mnist/logs/mnist_with_summaries'
利用tf.app.flags元件
首先需要定義一個tf.app.flags物件,呼叫自帶的DEFINE_string, DEFINE_boolean, DEFINE_integer, DEFINE_float設定不同型別的命令列引數及其預設值。當然,也可以在終端用命令列引數修改這些預設值。
# Define hyperparameters flags = tf.app.flags FLAGS = flags.FLAGS flags.DEFINE_boolean("enable_colored_log", False, "Enable colored log") "The glob pattern of train TFRecords files") flags.DEFINE_string("validate_tfrecords_file", "./data/a8a/a8a_test.libsvm.tfrecords", "The glob pattern of validate TFRecords files") flags.DEFINE_integer("label_size", 2, "Number of label size") flags.DEFINE_float("learning_rate", 0.01, "The learning rate") def main(): # Get hyperparameters if FLAGS.enable_colored_log: import coloredlogs coloredlogs.install() logging.basicConfig(level=logging.INFO) FEATURE_SIZE = FLAGS.feature_size LABEL_SIZE = FLAGS.label_size ... return 0 if __name__ == ‘__main__’: main()
這段程式碼採用的是tensorflow庫中自帶的tf.app.flags模組實現命令列引數的解析。如果用終端執行tf程式,用上述兩種方式都可以,如果用spyder之類的工具,那麼只有第一種方式有用,第二種方式會報錯。
其中有個tf.app.flags元件,還有個tf.app.run()函式。官網幫助檔案是這麼說的:
flags module: Implementation of the flags interface.
run(...): Runs the program with an optional 'main' function and 'argv' list.
tf.app.run的原始碼:
1."""Generic entry point script."""
2.from __future__ import absolute_import
3.from __future__ import division
4.from __future__ import print_function
5.
6.import sys
7.
8.from tensorflow.python.platform import flags
9.
10.
11.def run(main=None):
12. f = flags.FLAGS
13. f._parse_flags()
14. main = main or sys.modules['__main__'].main
15. sys.exit(main(sys.argv))
也就是處理flag解析,然後執行main函式。
用shell指令碼實現訓練程式碼的執行
在終端執行python程式碼,首先需要在程式碼檔案開頭寫入shebang,告訴系統環境變數如何設定,用python2還是用python3來編譯這段程式碼。然後修改程式碼許可權為可執行,用 ./python_code.py 就可以執行。同理,這段程式碼也可以用shell指令碼來實現。建立.sh檔案,執行python_code.py並設定引數max_steps=100
python python_code.py --max_steps 100