Defining command-line arguments in Python and in TensorFlow

<h2>1. Defining command-line arguments in Python</h2>

Use the argparse library, for example:

import argparse

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser()
    # type=int converts the string from the command line into an integer
    parser.add_argument('-b', dest='batchsize', type=int, default=1)
    parser.add_argument('-g', dest='gpuid', type=int, default=0)
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_args()
    print(args.batchsize)
    print(args.gpuid)
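Because parse_args() also accepts an explicit list of strings instead of reading sys.argv, the parser can be exercised without a real command line. The following is a minimal sketch that keeps the same -b/-g options; the long aliases --batchsize and --gpuid are hypothetical additions, not part of the original script.

```python
import argparse


def build_parser():
    """Same -b/-g options as above; --batchsize/--gpuid are hypothetical aliases."""
    parser = argparse.ArgumentParser()
    # type=int turns the command-line string into an integer
    parser.add_argument('-b', '--batchsize', dest='batchsize', type=int, default=1)
    parser.add_argument('-g', '--gpuid', dest='gpuid', type=int, default=0)
    return parser


parser = build_parser()
defaults = parser.parse_args([])                         # nothing given: defaults apply
custom = parser.parse_args(['-b', '8', '--gpuid', '1'])  # short and long forms can be mixed
print(defaults.batchsize, defaults.gpuid)  # 1 0
print(custom.batchsize, custom.gpuid)      # 8 1
```

Passing a list like this is also a convenient way to unit-test argument handling.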

<h2>2. Defining command-line arguments in TensorFlow</h2>

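TensorFlow 1.x uses tf.app.flags, a thin wrapper around the absl.flags module, to pass command-line arguments. In TensorFlow 2.x the same module is available as tf.compat.v1.app.flags.
For example, flags_test.py: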

import tensorflow as tf 

flags = tf.app.flags
FLAGS = flags.FLAGS

# Settings for some training parameters.
flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'],  
                  'Learning rate policy for training.')
flags.DEFINE_float('base_learning_rate', .0001,  
                   'The base learning rate for model training.')
flags.DEFINE_integer('learning_rate_decay_step', 2000, 
                     'Decay the base learning rate at a fixed step.')
flags.DEFINE_integer('train_batch_size', 12,
                     'The number of images in each batch during training.')
flags.DEFINE_multi_integer('train_crop_size', [513, 513],
                           'Image crop size [height, width] during training.')
flags.DEFINE_boolean('upsample_logits', True,
                     'Upsample logits during training.')
flags.DEFINE_string('dataset', 'dataset_name',
                    'Name of the test dataset.')

def main(_):  # tf.app.run() parses the flags, then calls main() with the leftover argv
    print(FLAGS.learning_policy)
    print(FLAGS.base_learning_rate)
    print(FLAGS.learning_rate_decay_step)
    print(FLAGS.train_batch_size)
    print(FLAGS.train_crop_size)
    print(FLAGS.upsample_logits)
    print(FLAGS.dataset)

if __name__ == '__main__':
    tf.app.run()

Run:

python flags_test.py

Output:

poly
0.0001
2000
12
[513, 513]
True
dataset_name
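For comparison, the same settings can be expressed with plain argparse, which needs no TensorFlow install. This is a sketch of an analogue, not a description of how tf.app.flags works internally: choices stands in for DEFINE_enum, type=float/int for the numeric flags, nargs=2 for DEFINE_multi_integer (absl multi flags are instead given by repeating the flag), and a --noupsample_logits switch for the boolean (mirroring absl, which accepts --noupsample_logits for a boolean flag named upsample_logits).

```python
import argparse


def build_parser():
    """argparse analogue of the flags in flags_test.py (a sketch, not the TF API)."""
    parser = argparse.ArgumentParser()
    # DEFINE_enum -> choices restricts the accepted values
    parser.add_argument('--learning_policy', default='poly', choices=['poly', 'step'],
                        help='Learning rate policy for training.')
    # DEFINE_float / DEFINE_integer -> type converts the string value
    parser.add_argument('--base_learning_rate', type=float, default=0.0001,
                        help='The base learning rate for model training.')
    parser.add_argument('--learning_rate_decay_step', type=int, default=2000,
                        help='Decay the base learning rate at a fixed step.')
    parser.add_argument('--train_batch_size', type=int, default=12,
                        help='The number of images in each batch during training.')
    # DEFINE_multi_integer -> here, two integers after a single flag (nargs=2)
    parser.add_argument('--train_crop_size', type=int, nargs=2, default=[513, 513],
                        help='Image crop size [height, width] during training.')
    # DEFINE_boolean -> store_false makes the default True; the switch turns it off
    parser.add_argument('--noupsample_logits', dest='upsample_logits',
                        action='store_false', help='Do not upsample logits.')
    parser.add_argument('--dataset', default='dataset_name',
                        help='Name of the test dataset.')
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args(
        ['--learning_policy', 'step', '--train_crop_size', '321', '321',
         '--noupsample_logits'])
    print(args.learning_policy)  # step
    print(args.train_crop_size)  # [321, 321]
    print(args.upsample_logits)  # False
```

Calling build_parser().parse_args([]) with no arguments yields the same default values as the TensorFlow output above.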
Last modification:October 9th, 2018 at 09:31 am