Defining Command-Line Arguments in Python and in TensorFlow

<h2>1. Command-Line Arguments in Python</h2>

Python uses the argparse library, for example:

```python
import argparse


def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser()
    # type=int converts the command-line string into an integer
    parser.add_argument('-b', dest='batchsize', type=int, default=1)
    parser.add_argument('-g', dest='gpuid', type=int, default=0)
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    print(args.batchsize)
    print(args.gpuid)
```
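The script above reads its options from `sys.argv`. To see how `-b` and `-g` end up as attributes without launching it from a shell, a parser with the same two options can be given an explicit argument list. The list below stands in for a hypothetical command line `python script.py -b 4 -g 1`. Note that without `type=int`, argparse would keep `-b 4` as the string `'4'`, so this sketch sets `type=int` explicitly.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-b', dest='batchsize', type=int, default=1)
parser.add_argument('-g', dest='gpuid', type=int, default=0)

# Stands in for the command line: python script.py -b 4 -g 1
args = parser.parse_args(['-b', '4', '-g', '1'])
print(args.batchsize, args.gpuid)  # 4 1

# An empty list means no options were given, so the defaults apply
defaults = parser.parse_args([])
print(defaults.batchsize, defaults.gpuid)  # 1 0
```

Passing a list to `parse_args()` is also a convenient way to unit-test a parser; calling it with no argument falls back to `sys.argv[1:]`.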

<h2>2. Command-Line Arguments in TensorFlow</h2>

TensorFlow uses tf.app.flags to pass command-line arguments.
For example, flags_test.py:

```python
import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS

# Settings for some training parameters.
flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'],
                  'Learning rate policy for training.')
flags.DEFINE_float('base_learning_rate', .0001,
                   'The base learning rate for model training.')
flags.DEFINE_integer('learning_rate_decay_step', 2000,
                     'Decay the base learning rate at a fixed step.')
flags.DEFINE_integer('train_batch_size', 12,
                     'The number of images in each batch during training.')
flags.DEFINE_multi_integer('train_crop_size', [513, 513],
                           'Image crop size [height, width] during training.')
flags.DEFINE_boolean('upsample_logits', True,
                     'Upsample logits during training.')
flags.DEFINE_string('dataset', 'dataset_name',
                    'Name of the test dataset.')


def main(_):
    # tf.app.run() parses the command line into FLAGS, then calls main(argv)
    print(FLAGS.learning_policy)
    print(FLAGS.base_learning_rate)
    print(FLAGS.learning_rate_decay_step)
    print(FLAGS.train_batch_size)
    print(FLAGS.train_crop_size)
    print(FLAGS.upsample_logits)
    print(FLAGS.dataset)


if __name__ == '__main__':
    tf.app.run()
```

Run it:

```shell
python flags_test.py
```

Output:

```
poly
0.0001
2000
12
[513, 513]
True
dataset_name
```
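For comparison with the first section, the same seven settings can be expressed with argparse. This is a swapped-in, standard-library sketch, not how TensorFlow itself parses flags. It assumes Python 3.9+ for `argparse.BooleanOptionalAction`, and the helper names `build_parser` and `parse_flags` are hypothetical.

```python
import argparse


def build_parser():
    """Hypothetical argparse mirror of the flags defined in flags_test.py."""
    parser = argparse.ArgumentParser()
    # DEFINE_enum -> choices restricts the accepted values
    parser.add_argument('--learning_policy', default='poly',
                        choices=['poly', 'step'],
                        help='Learning rate policy for training.')
    parser.add_argument('--base_learning_rate', type=float, default=0.0001,
                        help='The base learning rate for model training.')
    parser.add_argument('--learning_rate_decay_step', type=int, default=2000,
                        help='Decay the base learning rate at a fixed step.')
    parser.add_argument('--train_batch_size', type=int, default=12,
                        help='The number of images in each batch during training.')
    # DEFINE_multi_integer -> repeatable option; default is None because
    # action='append' would otherwise append onto the default list
    parser.add_argument('--train_crop_size', type=int, action='append',
                        default=None,
                        help='Image crop size [height, width] during training.')
    # DEFINE_boolean -> adds both --upsample_logits and --no-upsample_logits
    parser.add_argument('--upsample_logits', default=True,
                        action=argparse.BooleanOptionalAction,
                        help='Upsample logits during training.')
    parser.add_argument('--dataset', default='dataset_name',
                        help='Name of the test dataset.')
    return parser


def parse_flags(argv=None):
    """Parse argv (sys.argv[1:] when None) and fill in the crop-size default."""
    args = build_parser().parse_args(argv)
    if args.train_crop_size is None:
        args.train_crop_size = [513, 513]
    return args


if __name__ == '__main__':
    flags = parse_flags()
    for name, value in vars(flags).items():
        print(name, value)
```

One difference to keep in mind: argparse negates the boolean as `--no-upsample_logits`, whereas the absl-backed tf.app.flags of later TF 1.x releases uses `--noupsample_logits`. With those flags, overriding the defaults would look like `python flags_test.py --train_crop_size=321 --train_crop_size=321 --noupsample_logits`.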
Last modification: October 9th, 2018 at 09:31 am