Tensorflow使用flags定義命令列引數詳解
阿新 • • 發佈:2019-01-02
TensorFlow定義了tf.app.flags,用於支援接受命令列傳遞引數,相當於接受argv,詳細用法請看程式碼中的註釋
例一:
import tensorflow as tf
##第一個是引數名稱,第二個引數是預設值,第三個是引數描述
tf.app.flags.DEFINE_string('str_name', 'def_v_1',"descrip1")
tf.app.flags.DEFINE_integer('int_name', 10,"descript2")
tf.app.flags.DEFINE_boolean('bool_name', False, "descript3" )
FLAGS = tf.app.flags.FLAGS
##必須帶引數,否則:'TypeError: main() takes no arguments (1 given)'; ##main的引數名隨意定義,無要求
def main(_):
print(FLAGS.str_name)
print(FLAGS.int_name)
print(FLAGS.bool_name)
if __name__ == '__main__':
tf.app.run() #tf.app.run()的作用:先處理flag解析,然後執行main函式,
例二:
import tensorflow as tf
flags = tf.flags #flags是一個檔案:flags.py,用於處理命令列引數的解析工作
logging = tf.logging
#呼叫flags內部的DEFINE_string函式來制定解析規則
flags.DEFINE_string("para_name_1","default_val", "description")
flags.DEFINE_bool("para_name_2","default_val", "description")
#FLAGS是一個物件,儲存瞭解析後的命令列引數
FLAGS = flags.FLAGS
def main(_):
FLAGS.para_name #呼叫命令列輸入的引數
if __name__ == "__main__": #使用這種方式保證了,如果此檔案被其它檔案import的時候,不會執行main中的程式碼
tf.app.run() #解析命令列引數,呼叫main函式 main(sys.argv)
'''
呼叫方法,在命令列視窗中輸入:
~/ python script.py --para_name_1=name --para_name_2=name2
# 不傳的話,會使用預設值
'''
例三:
#coding:utf-8
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
# tf.app.flags.DEFINE_string("param_name", "default_val", "description")
tf.app.flags.DEFINE_string("train_data_path", "/home/feige", "training data dir")
tf.app.flags.DEFINE_string("log_dir", "./logs", " the log dir")
tf.app.flags.DEFINE_integer("train_batch_size", 128, "batch size of train data")
tf.app.flags.DEFINE_integer("test_batch_size", 64, "batch size of test data")
tf.app.flags.DEFINE_float("learning_rate", 0.001, "learning rate")
def main(unused_argv):
train_data_path = FLAGS.train_data_path
print("train_data_path", train_data_path)
train_batch_size = FLAGS.train_batch_size
print("train_batch_size", train_batch_size)
test_batch_size = FLAGS.test_batch_size
print("test_batch_size", test_batch_size)
size_sum = tf.add(train_batch_size, test_batch_size)
with tf.Session() as sess:
sum_result = sess.run(size_sum)
print("sum_result", sum_result)
# 使用這種方式保證了,如果此檔案被其他檔案 import的時候,不會執行main 函式
if __name__ == '__main__':
tf.app.run() # 解析命令列引數,呼叫main 函式 main(sys.argv)
如果需要修改預設引數的值,則在命令列傳入自定義引數值即可,若全部使用預設引數值,則可直接在命令列執行該 python 檔案。
tf.app.run() 真正執行原理,還需查閱其原始碼:
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
f = flags.FLAGS
# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None
# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access
main = main or sys.modules['__main__'].main
# Call the main function, passing through any arguments
# to the final program.
sys.exit(main(sys.argv[:1] + flags_passthrough))
flags_passthrough=f._parse_flags(args=args)這裡的_parse_flags就是我們tf.app.flags原始碼中用來解析命令列引數的函式。所以這一行就是解析引數的功能;
下面兩行程式碼也就是 tf.app.run 的核心意思:執行程式中 main 函式,並解析命令列引數!