1. 程式人生 > >tensorflow 9. 引數解析和經典入口函式tf.app.run

tensorflow 9. 引數解析和經典入口函式tf.app.run

概述

本文總結兩種引數解析的介面,一種是python的引數解析包自帶的功能,使用時需要import argparse。另一類是tensorflow自帶的功能,解析時import tensorflow就行了。

python中的引數解析

parse_known_args的例子

tensorflow下的一個例子 tensorflow/examples/tutorials/mnist/mnist_deep.py,入口程式碼如下:

import tensorflow as tf

FLAGS = None

......

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--data_dir', type=str,
                      default='/tmp/tensorflow/mnist/input_data',
                      help='Directory for storing input data')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

parse_known_args功能解釋

關於parse_args()的使用方法可以參考之前的部落格: python引數解析和日誌記錄

parse_known_args()與parse_args()功能類似,只是它只返回已知的選項,未知選項原樣返回。

parse_known_args()在接受到多餘的命令列引數時不報錯,而是返回一個tuple型別的名稱空間和一個儲存著餘下的命令列字元的list。

這樣做的好處可以分層解析自己的選項,把剩下的引數傳給tensorflow來解析。

另一個例子

下面的例子來自這個部落格

import argparse 
parser = argparse.ArgumentParser() 
parser.add_argument( 
    '--flag_int', 
    type=float, 
    default=0.01, 
    help='flag_int.' 
) 
FLAGS, unparsed = parser.parse_known_args() 
print(FLAGS) 
print(unparsed)

執行輸出:

$ python prog.py --flag_int 0.02 --double 0.03 a 1
Namespace(flag_int=0.02)
['--double', '0.03', 'a', '1']

argparse模組的FLAGS

上面第一個例子,FLAGS在檔案開頭定義,可以作為依據全域性的變數使用。解析出的引數都儲存在FLAGS裡面。

這個FLAGS儲存的是argparse模組解析出的引數。後面tensorflow也定義了自己的FLAGS,不要弄混了。

tf.app.run

第一個例子中,tf.app.run使用的方式是:

tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

這句話的功能是將程式名(sys.argv[0])和未知引數交給tf.app.run解析,最終傳給main函式。

run函式的定義如下:

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""

  # Define help flags.
  _define_help_flags()

  # 解析已知引數
  argv = flags.FLAGS(_sys.argv if argv is None else argv, known_only=True)

  #如果沒有傳入main引數,則會預設執行main函式
  main = main or _sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  _sys.exit(main(argv)) #將引數傳給main

呼叫tf.app.run如果不出入任何引數,實際會呼叫到main函式。 這種使用tf.app.run作為入口的方式非常常見。

tensorflow的FLAGS:tf.app.flags.FLAGS

上面run函式有一行解析引數的程式碼呼叫:

# Parse known flags.
argv = flags.FLAGS(_sys.argv if argv is None else argv, known_only=True)

意思就是,如果傳入了argv則解析argv,否則解析程序的入口引數。

例子

例子1:沒有指定引數格式

import argparse
import sys
import os
import tensorflow as tf


def main(args):
    print("I'm main")
    print('args{}'.format(args))

if __name__ == '__main__':
    tf.app.run()

執行結果:

> python .\app_run.py --user=test
I'm main
args['.\\app_run.py', '--user=test']

預設情況下main會被呼叫,由於沒有指定引數格式,所以user不能不正確解析

例子2:指定引數格式

上一小節沒有定義引數格式,這樣tf是解析不了引數的。定義引數的格式如下:

# 字串
tf.app.flags.DEFINE_string("param_name", "default_val", "description")
# 布林型別
tf.app.flags.DEFINE_boolean("param_name", "default_val", "description")
tf.app.flags.DEFINE_bool("param_name", "default_val", "description")
# 浮點型
tf.app.flags.DEFINE_float("param_name", "default_val", "description")
# 整型
tf.app.flags.DEFINE_integer("param_name", "default_val", "description")

可以看到,如果沒有傳入引數,則會使用預設引數。

完整的程式碼如下:

# coding:utf-8

# 學習使用 tf.app.flags 使用,全域性變數
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

# tf.app.flags.DEFINE_string("param_name", "default_val", "description")
tf.app.flags.DEFINE_string( "train_data_path", "/home/user", "training data dir" )
tf.app.flags.DEFINE_string( "log_dir", "./logs", " the log dir" )
tf.app.flags.DEFINE_integer( "max_sentence", 80, "max num of tokens per query" )
tf.app.flags.DEFINE_integer( "embedding_size", 50, "embedding size" )

tf.app.flags.DEFINE_float( "learning_rate", 0.001, "learning rate" )


def main( unused_argv ):
    train_data_path = FLAGS.train_data_path
    print( "train_data_path", train_data_path )
    log_dir = FLAGS.log_dir
    print( 'log_dir:{}'.format(log_dir) )
    max_sentence = FLAGS.max_sentence
    print( "max_sentence", max_sentence )
    embdeeing_size = FLAGS.embedding_size
    print( "embedding_size", embdeeing_size )
    learning_rate = FLAGS.learning_rate
    print( 'learning_rate:{}'.format(learning_rate))


# 使用這種方式保證了,如果此檔案被其他檔案 import的時候,不會執行main 函式
if __name__ == '__main__':
    tf.app.run()  # 解析命令列引數,呼叫main 函式 main(sys.argv)

我們只看到引數格式宣告的程式碼,沒看到呼叫解析的程式碼,這是因為這部分程式碼在tf.app.run()中。這樣程式碼看起來會簡潔很多。

輸出

在不指定引數的情況下,引數被賦予預設值:

> python .\tf_argument.py
train_data_path /home/user
log_dir:./logs
max_sentence 80
embedding_size 50
learning_rate:0.001

在指定引數的情況下,引數被賦予傳入值。以下兩種傳參方式等效。

# 方式1
> python .\tf_argument.py --train_data_path=./  --max_sentence=100  --embedding_size=100 --learning_rate=0.05
train_data_path ./
log_dir:./logs
max_sentence 100
embedding_size 100
learning_rate:0.05

# 方式2
> python .\tf_argument.py --train_data_path ./  --max_sentence 100  --embedding_size 100 --learning_rate 0.05
train_data_path ./
log_dir:./logs
max_sentence 100
embedding_size 100
learning_rate:0.05

總結

本文討論了python的引數解析和tensorflow的引數解析。

在tensorflow專案程式碼中,推薦使用tensorflow的tf.app.flags.FLAGS來解析引數,並作為全域性的配置資訊。

不知為何,tensorflow的官方例程混用了這兩種引數解析,目前我還沒想明白。

注意:tf.app.flags.FLAGS是可以作為全域性變數來用的。不同的檔案不用相互import就可以使用。

參考資料