TensorFlow (6) tf.app.flags

1. tf.app.flags:

tf.app.flags 類似 argpath 可以用來解析傳入的參數。

  1. tf.app.DEFINE_string: 定義字串參數
  2. tf.app.DEFINE_boolean: 定義布林參數
  3. tf.app.DEFINE_float: 定義浮點數參數
  4. tf.app.DEFINE_integer: 定義整數參數

建立一個 tf_app.py:

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string(
    "data_path", "data/", "path of data")

tf.app.flags.DEFINE_boolean(
    "optimize", False, "enable optimization")

def main(unused_argv):
    data_path = FLAGS.data_path
    optimize = FLAGS.optimize

    print("data_path: {}".format(data_path))
    print("optimize: {}".format(optimize))

if __name__ == "__main__":
    tf.app.run()

tf.app.flags.DEFINE_string() 的格式是 tf.app.flags.DEFINE_string(flag_name, default, help), 取得參數的方式是 tf.app.flags.FLAGS.flag_name

$ python3 test.py --data_path newdir --optimize

data_path: newdir
optimize: True
$ python3 tf_app.py

data_path: data/
optimize: True
$ python3 tf_app.py -h

留言

熱門文章