tf.app.run()使用方法的解釋
阿新 • • 發佈:2018-12-13
在一些github上公開的程式碼中,我們經常會看到這樣的程式
if __name__ == '__main__':
tf.app.run()
像網上的大多數文章一樣,先貼上一下run()的原始碼:
def run(main=None, argv=None): """Runs the program with an optional 'main' function and 'argv' list.""" f = flags.FLAGS # Extract the args from the optional `argv` list. args = argv[1:] if argv else None # Parse the known flags from that list, or from the command # line otherwise. # pylint: disable=protected-access flags_passthrough = f._parse_flags(args=args) # pylint: enable=protected-access main = main or _sys.modules['__main__'].main # Call the main function, passing through any arguments # to the final program. _sys.exit(main(_sys.argv[:1] + flags_passthrough)) _allowed_symbols = [ 'run', # Allowed submodule. 'flags', ] remove_undocumented(__name__, _allowed_symbols)
原始碼中寫的很清楚,首先載入flags的引數項,然後執行main()函式,其中引數使用tf.app.flags.FLAGS定義的。
tf.app.flags.FLAGS
# fila_name: temp.py import tensorflow as tf FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('string', 'train', 'This is a string') tf.app.flags.DEFINE_float('learning_rate', 0.001, 'This is the rate in training') tf.app.flags.DEFINE_boolean('flag', True, 'This is a flag') print('string: ', FLAGS.string) print('learning_rate: ', FLAGS.learning_rate) print('flag: ', FLAGS.flag)
輸出:
string: train
learning_rate: 0.001
flag: True
如果在命令列中執行python3 temp.py --help,得到輸出:
usage: temp.py [-h] [--string STRING] [--learning_rate LEARNING_RATE] [--flag [FLAG]] [--noflag] optional arguments: -h, --help show this help message and exit --string STRING This is a string --learning_rate LEARNING_RATE This is the rate in training --flag [FLAG] This is a flag --noflag
如果要對FLAGS的預設值進行修改,只要輸入命令:
python3 temp.py --string 'test' --learning_rate 0.2 --flag False
聯合使用
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('string', 'train', 'This is a string')
tf.app.flags.DEFINE_float('learning_rate', 0.001, 'This is the rate in training')
tf.app.flags.DEFINE_boolean('flag', True, 'This is a flag')
def main(unuse_args):
print('string: ', FLAGS.string)
print('learning_rate: ', FLAGS.learning_rate)
print('flag: ', FLAGS.flag)
if __name__ == '__main__':
tf.app.run()
主函式中的tf.app.run()會呼叫main,並傳遞引數,因此必須在main函式中設定一個引數位置。如果要更換main名字,只需要在tf.app.run()中傳入一個指定的函式名即可。
def test(args):
# test
...
if __name__ == '__main__':
tf.app.run(test)