tensorflow 9. 引數解析和經典入口函式tf.app.run
概述
本文總結兩種引數解析的介面,一種是python的引數解析包自帶的功能,使用時需要import argparse。另一類是tensorflow自帶的功能,解析時import tensorflow就行了。
python中的引數解析
parse_known_args的例子
tensorflow下的一個例子 tensorflow/examples/tutorials/mnist/mnist_deep.py,入口程式碼如下:
import tensorflow as tf FLAGS = None ...... if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', help='Directory for storing input data') FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
parse_known_args功能解釋
關於parse_args()的使用方法可以參考之前的部落格: python引數解析和日誌記錄
parse_known_args()與parse_args()功能類似,只是它只返回已知的選項,未知選項原樣返回。
parse_known_args()在接受到多餘的命令列引數時不報錯,而是返回一個tuple型別的名稱空間和一個儲存著餘下的命令列字元的list。
這樣做的好處可以分層解析自己的選項,把剩下的引數傳給tensorflow來解析。
另一個例子
下面的例子來自這個部落格:
import argparse parser = argparse.ArgumentParser() parser.add_argument( '--flag_int', type=float, default=0.01, help='flag_int.' ) FLAGS, unparsed = parser.parse_known_args() print(FLAGS) print(unparsed)
執行輸出:
$ python prog.py --flag_int 0.02 --double 0.03 a 1
Namespace(flag_int=0.02)
['--double', '0.03', 'a', '1']
argparse模組的FLAGS
上面第一個例子,FLAGS在檔案開頭定義,可以作為依據全域性的變數使用。解析出的引數都儲存在FLAGS裡面。
這個FLAGS儲存的是argparse模組解析出的引數。後面tensorflow也定義了自己的FLAGS,不要弄混了。
tf.app.run
第一個例子中,tf.app.run使用的方式是:
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
這句話的功能是將程式名(sys.argv[0])和未知引數交給tf.app.run解析,最終傳給main函式。
run函式的定義如下:
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
# Define help flags.
_define_help_flags()
# 解析已知引數
argv = flags.FLAGS(_sys.argv if argv is None else argv, known_only=True)
#如果沒有傳入main引數,則會預設執行main函式
main = main or _sys.modules['__main__'].main
# Call the main function, passing through any arguments
# to the final program.
_sys.exit(main(argv)) #將引數傳給main
呼叫tf.app.run如果不出入任何引數,實際會呼叫到main函式。 這種使用tf.app.run作為入口的方式非常常見。
tensorflow的FLAGS:tf.app.flags.FLAGS
上面run函式有一行解析引數的程式碼呼叫:
# Parse known flags.
argv = flags.FLAGS(_sys.argv if argv is None else argv, known_only=True)
意思就是,如果傳入了argv則解析argv,否則解析程序的入口引數。
例子
例子1:沒有指定引數格式
import argparse
import sys
import os
import tensorflow as tf
def main(args):
print("I'm main")
print('args{}'.format(args))
if __name__ == '__main__':
tf.app.run()
執行結果:
> python .\app_run.py --user=test
I'm main
args['.\\app_run.py', '--user=test']
預設情況下main會被呼叫,由於沒有指定引數格式,所以user不能不正確解析
例子2:指定引數格式
上一小節沒有定義引數格式,這樣tf是解析不了引數的。定義引數的格式如下:
# 字串
tf.app.flags.DEFINE_string("param_name", "default_val", "description")
# 布林型別
tf.app.flags.DEFINE_boolean("param_name", "default_val", "description")
tf.app.flags.DEFINE_bool("param_name", "default_val", "description")
# 浮點型
tf.app.flags.DEFINE_float("param_name", "default_val", "description")
# 整型
tf.app.flags.DEFINE_integer("param_name", "default_val", "description")
可以看到,如果沒有傳入引數,則會使用預設引數。
完整的程式碼如下:
# coding:utf-8
# 學習使用 tf.app.flags 使用,全域性變數
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
# tf.app.flags.DEFINE_string("param_name", "default_val", "description")
tf.app.flags.DEFINE_string( "train_data_path", "/home/user", "training data dir" )
tf.app.flags.DEFINE_string( "log_dir", "./logs", " the log dir" )
tf.app.flags.DEFINE_integer( "max_sentence", 80, "max num of tokens per query" )
tf.app.flags.DEFINE_integer( "embedding_size", 50, "embedding size" )
tf.app.flags.DEFINE_float( "learning_rate", 0.001, "learning rate" )
def main( unused_argv ):
train_data_path = FLAGS.train_data_path
print( "train_data_path", train_data_path )
log_dir = FLAGS.log_dir
print( 'log_dir:{}'.format(log_dir) )
max_sentence = FLAGS.max_sentence
print( "max_sentence", max_sentence )
embdeeing_size = FLAGS.embedding_size
print( "embedding_size", embdeeing_size )
learning_rate = FLAGS.learning_rate
print( 'learning_rate:{}'.format(learning_rate))
# 使用這種方式保證了,如果此檔案被其他檔案 import的時候,不會執行main 函式
if __name__ == '__main__':
tf.app.run() # 解析命令列引數,呼叫main 函式 main(sys.argv)
我們只看到引數格式宣告的程式碼,沒看到呼叫解析的程式碼,這是因為這部分程式碼在tf.app.run()中。這樣程式碼看起來會簡潔很多。
輸出
在不指定引數的情況下,引數被賦予預設值:
> python .\tf_argument.py
train_data_path /home/user
log_dir:./logs
max_sentence 80
embedding_size 50
learning_rate:0.001
在指定引數的情況下,引數被賦予傳入值。以下兩種傳參方式等效。
# 方式1
> python .\tf_argument.py --train_data_path=./ --max_sentence=100 --embedding_size=100 --learning_rate=0.05
train_data_path ./
log_dir:./logs
max_sentence 100
embedding_size 100
learning_rate:0.05
# 方式2
> python .\tf_argument.py --train_data_path ./ --max_sentence 100 --embedding_size 100 --learning_rate 0.05
train_data_path ./
log_dir:./logs
max_sentence 100
embedding_size 100
learning_rate:0.05
總結
本文討論了python的引數解析和tensorflow的引數解析。
在tensorflow專案程式碼中,推薦使用tensorflow的tf.app.flags.FLAGS來解析引數,並作為全域性的配置資訊。
不知為何,tensorflow的官方例程混用了這兩種引數解析,目前我還沒想明白。
注意:tf.app.flags.FLAGS是可以作為全域性變數來用的。不同的檔案不用相互import就可以使用。