[slim.learning.py]()
TF-Slim model-training code. It contains several functions related to model training, such as:
[1] - gradient manipulation;
[2] - creation of the train_op, an operation that computes the loss and applies the gradients;
[3] - the training loop function.
The training loop allows the user to pass in the train_op and runs the optimization according to user-specified arguments.
1. A simple training workflow
# Load the data / create the model
images, labels = LoadData(...)
predictions = MyModel(images)
# Define the loss
slim.losses.log_loss(predictions, labels)
total_loss = slim.losses.get_total_loss()
# Define the optimizer
optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum)
# Create the train_op
train_op = slim.learning.create_train_op(total_loss, optimizer)
# Run training.
slim.learning.train(train_op, my_log_dir)
1.1 Creating the train_op
To train a model with TF-Slim's training loop, a train_op must be defined.
The train_op is responsible for:
[1] - computing the loss;
[2] - applying the gradients to update the weights;
[3] - returning the value of the loss.
slim.learning.create_train_op is used to create the train_op, e.g.:
# Create the train_op and clip the gradient norms:
train_op = slim.learning.create_train_op(
    total_loss,
    optimizer,
    clip_gradient_norm=4)

# Create the train_op and scale the gradients by providing a map from
# variable name (or variable) to a scaling coefficient:
gradient_multipliers = {'conv0/weights': 1.2,
                        'fc8/weights': 3.4}
train_op = slim.learning.create_train_op(
    total_loss,
    optimizer,
    gradient_multipliers=gradient_multipliers)
Note on gradient clipping:
Gradient clipping is used to avoid exploding gradients. It works by bounding the maximum norm of the gradients, e.g. with tf.clip_by_global_norm:
tf.clip_by_global_norm(
    t_list,
    clip_norm,
    use_norm=None,
    name=None
)
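For comparison, the following minimal sketch clips gradients manually with tf.clip_by_global_norm before applying them; it assumes the total_loss and optimizer from the workflow above, and the threshold 4.0 is illustrative. Note that create_train_op's clip_gradient_norm clips each gradient individually via tf.clip_by_norm (see the clip_gradient_norms helper in section 2.2), whereas this sketch bounds the global norm of all gradients together:
# Assumes total_loss and optimizer have been built as in the workflow above.
grads_and_vars = optimizer.compute_gradients(total_loss)
grads, variables = zip(*grads_and_vars)

# Bound the global norm of all gradients to 4.0, then apply them.
clipped_grads, global_norm = tf.clip_by_global_norm(grads, clip_norm=4.0)
train_op = optimizer.apply_gradients(list(zip(clipped_grads, variables)))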
1.2 Other (non-gradient) updates during training
Many networks use modules such as BatchNorm that require a set of non-gradient updates during training.
slim.learning.create_train_op also accepts a list of additional update_ops, which are run together with the gradient updates:
train_op = slim.learning.create_train_op(total_loss, optimizer, update_ops)
By default, slim.learning.create_train_op includes all of the update ops that are part of the tf.GraphKeys.UPDATE_OPS collection.
In addition, TF-Slim's slim.batch_norm function adds the moving mean and moving variance updates to this collection. Consequently, if you use slim.batch_norm, no extra handling is needed to update the moving mean and moving variance.
However, you can also override or extend the default set of update ops via the tf.GraphKeys.UPDATE_OPS collection:
# Force TF-Slim not to use ANY update_ops:
train_op = slim.learning.create_train_op(
    total_loss,
    optimizer,
    update_ops=[])

# Replace the set of update ops:
train_op = slim.learning.create_train_op(
    total_loss,
    optimizer,
    update_ops=my_other_update_ops)

# Add update ops to the default set of updates:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update0)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update1)
train_op = slim.learning.create_train_op(
    total_loss,
    optimizer)

# Which is equivalent to:
train_op = slim.learning.create_train_op(
    total_loss,
    optimizer,
    update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS))
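To see where these default update ops come from, here is a minimal sketch (layer name, shape, and scope are illustrative) that builds a batch-normalized conv layer and inspects the UPDATE_OPS collection, which should then contain the moving mean/variance update ops registered by slim.batch_norm:
import tensorflow as tf
import tensorflow.contrib.slim as slim

images = tf.placeholder(tf.float32, [None, 224, 224, 3])

# slim.batch_norm registers its moving mean/variance updates
# in tf.GraphKeys.UPDATE_OPS when used as a normalizer_fn.
net = slim.conv2d(images, 64, [3, 3],
                  normalizer_fn=slim.batch_norm,
                  normalizer_params={'is_training': True},
                  scope='conv1')

print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
# Expect the moving-average assign ops created under conv1/BatchNorm.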
1.3 Initializing a model from a checkpoint
It is common to want to warm-start training from a pre-trained model checkpoint.
...
# Create the train_op
train_op = slim.learning.create_train_op(total_loss, optimizer)

# Create the initial assignment op
checkpoint_path = '/path/to/old_model_checkpoint'
variables_to_restore = slim.get_model_variables()
init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
    checkpoint_path, variables_to_restore)

# Create an initial assignment function
def InitAssignFn(sess):
  sess.run(init_assign_op, init_feed_dict)

# Run training
slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn)
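When fine-tuning, it is often useful to restore only part of the pre-trained weights, for example everything except the final logits layer. A minimal sketch, assuming the re-initialized layer lives under the (hypothetical) scope 'fc8':
# Restore all model variables except those under 'fc8',
# which are assumed to be re-initialized for the new task.
variables_to_restore = slim.get_variables_to_restore(exclude=['fc8'])
init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
    checkpoint_path, variables_to_restore)

def InitAssignFn(sess):
  sess.run(init_assign_op, init_feed_dict)

slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn)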
1.4 Initializing model variables from values in memory
You may also want to initialize model weights from values of arbitrary origin (a text file, a MATLAB file, etc.). While this is technically possible with plain TensorFlow, it requires the values to be stored as part of the graph, which for large models can produce very large files. TF-Slim provides a way to perform this initial assignment without storing the initial values in the graph, using placeholders and a feed dictionary:
# Create the train_op
train_op = slim.learning.create_train_op(total_loss, optimizer)

# Create the mapping from variable names to values:
var0_initial_value = ReadFromDisk(...)
var1_initial_value = ReadFromDisk(...)
var_names_to_values = {'var0': var0_initial_value,
                       'var1': var1_initial_value}
init_assign_op, init_feed_dict = slim.assign_from_values(var_names_to_values)

# Create an initial assignment function
def InitAssignFn(sess):
  sess.run(init_assign_op, init_feed_dict)

# Run training
slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn)
2. slim.learning.train()
Runs the training loop using a TensorFlow Supervisor:
# Returns: the value of the loss function after training.
_USE_DEFAULT = 0
def train(train_op,
          logdir,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_steps=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=_USE_DEFAULT,
          init_fn=None,
          ready_op=_USE_DEFAULT,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          summary_writer=_USE_DEFAULT,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None,
          session_config=None,
          session_wrapper=None,
          trace_every_n_steps=None,
          ignore_live_threads=False)
[1] - train_op - A Tensor that, when executed, computes the gradients and returns the loss value.
[2] - logdir - The directory where training logs are written. If None, model checkpoints and summaries are not written.
[3] - train_step_fn - The function called to perform a single gradient step. It must take four arguments: the current session, the train_op Tensor, the global step Tensor, and a dictionary.
[4] - train_step_kwargs - A dictionary passed to train_step_fn. By default, two Boolean scalar ops, should_stop and should_log, are provided.
[5] - log_every_n_steps - The frequency, in terms of global steps, at which the loss and the global step are logged.
[6] - graph - The graph passed to the supervisor. If None, the default graph is used.
[7] - master - The address of the tensorflow master.
[8] - is_chief - Whether, during replica training, this training process runs as the primary replica.
[9] - global_step - A Tensor representing the global step. If None, training_util.get_or_create_global_step(), i.e. tf.contrib.framework.global_step(), is used.
[10] - number_of_steps - The maximum number of gradient steps to take during training, as measured by global_step: training stops once global_step exceeds number_of_steps. If None, training runs indefinitely.
[11] - init_op - The initialization op. If left at the default, the session is initialized by calling tf.global_variables_initializer().
[12] - init_feed_dict - The feed dictionary used when executing init_op.
[13] - local_init_op - The local initialization op. If left at the default, the session is initialized by calling tf.local_variables_initializer() and tf.tables_initializer().
[14] - init_fn - A callable executed after init_op is called. It must accept one argument, the session being initialized.
[15] - ready_op - An op that checks whether the model is ready. If left at the default, readiness is checked by calling tf.report_uninitialized_variables().
[16] - summary_op - The summary op.
[17] - save_summaries_secs - How often, in seconds, summaries are saved.
[18] - summary_writer - The SummaryWriter to use. A value of None means no summaries are written. If unset, a SummaryWriter is created.
[19] - startup_delay_steps - The number of steps to wait before starting training. Must be 0 if sync_optimizer is supplied.
[20] - saver - The Saver used for saving checkpoints. If None, a default one is created and used.
[21] - save_interval_secs - How often, in seconds, the model is saved to logdir.
[22] - sync_optimizer - A tf.train.SyncReplicasOptimizer instance, or a list of such instances. If supplied, gradient updates are synchronous; if None, gradient updates are asynchronous.
[23] - session_config - A tf.ConfigProto instance used to configure the Session. If None, the default settings are used.
[24] - session_wrapper - A function that takes a tf.Session object as its only argument and returns a wrapped session object exposing the same methods, or None. If not None, training uses the wrapped session.
[25] - trace_every_n_steps - Produces and saves a Timeline in Chrome trace format and adds it to the summaries every trace_every_n_steps. If None, no trace information is produced or saved.
[26] - ignore_live_threads - If True, when stopping the supervisor, ignore threads that remain running after the grace period instead of raising a RuntimeError.
Raises ValueError if:
train_op is None;
or sync_optimizer is supplied while startup_delay_steps is non-zero;
or number_of_steps is negative;
or trace_every_n_steps is not None but no logdir is supplied.
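As a usage sketch (all argument values below are illustrative), a typical call that bounds the number of steps and controls how often summaries and checkpoints are written might look like:
slim.learning.train(
    train_op,
    logdir='/tmp/my_model',    # checkpoints and summaries go here
    number_of_steps=10000,     # stop after 10000 global steps
    log_every_n_steps=100,     # log the loss every 100 steps
    save_summaries_secs=300,   # write summaries every 5 minutes
    save_interval_secs=600)    # save a checkpoint every 10 minutes
The full source of slim.learning.train() is reproduced below.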
_USE_DEFAULT = 0

def train(train_op,
          logdir,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_steps=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=_USE_DEFAULT,
          init_fn=None,
          ready_op=_USE_DEFAULT,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          summary_writer=_USE_DEFAULT,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None,
          session_config=None,
          session_wrapper=None,
          trace_every_n_steps=None,
          ignore_live_threads=False):
  if train_op is None:
    raise ValueError('train_op cannot be None.')

  if logdir is None:
    if summary_op != _USE_DEFAULT:
      raise ValueError('Cannot provide summary_op because logdir=None')
    if saver is not None:
      raise ValueError('Cannot provide saver because logdir=None')
    if trace_every_n_steps is not None:
      raise ValueError('Cannot provide trace_every_n_steps because '
                       'logdir=None')

  if isinstance(sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):
    sync_optimizer = [sync_optimizer]
  if sync_optimizer is not None and startup_delay_steps > 0:
    raise ValueError(
        'startup_delay_steps must be zero when sync_optimizer is supplied.')

  if number_of_steps is not None and number_of_steps <= 0:
    raise ValueError(
        '`number_of_steps` must be either None or a positive number.')

  graph = graph or ops.get_default_graph()
  with graph.as_default():
    if global_step is None:
      global_step = training_util.get_or_create_global_step()
    saver = saver or tf_saver.Saver()

    if sync_optimizer is not None:
      for opt in sync_optimizer:
        if not isinstance(opt, sync_replicas_optimizer.SyncReplicasOptimizer):
          raise ValueError(
              '`sync_optimizer` must be a tf.train.SyncReplicasOptimizer.')

    with ops.name_scope('init_ops'):
      if init_op == _USE_DEFAULT:
        init_op = variables.global_variables_initializer()

      if ready_op == _USE_DEFAULT:
        ready_op = variables.report_uninitialized_variables()

      if local_init_op == _USE_DEFAULT:
        local_init_op = control_flow_ops.group(
            variables.local_variables_initializer(),
            lookup_ops.tables_initializer())

      if sync_optimizer is not None and isinstance(sync_optimizer, list):
        with ops.control_dependencies([local_init_op] if local_init_op is
                                      not None else []):
          if is_chief:
            local_init_op = control_flow_ops.group(
                *[opt.chief_init_op for opt in sync_optimizer])
          else:
            local_init_op = control_flow_ops.group(
                *[opt.local_step_init_op for opt in sync_optimizer])
        ready_for_local_init_op = control_flow_ops.group(
            *[opt.ready_for_local_init_op for opt in sync_optimizer])
      else:
        ready_for_local_init_op = None

    if summary_op == _USE_DEFAULT:
      summary_op = summary.merge_all()

    if summary_writer == _USE_DEFAULT:
      summary_writer = supervisor.Supervisor.USE_DEFAULT

    if is_chief and sync_optimizer is not None:
      # Need to create these BEFORE the supervisor finalizes the graph:
      init_tokens_op = [opt.get_init_tokens_op() for opt in sync_optimizer]
      chief_queue_runner = [
          opt.get_chief_queue_runner() for opt in sync_optimizer]

    if train_step_kwargs == _USE_DEFAULT:
      with ops.name_scope('train_step'):
        train_step_kwargs = {}

        if number_of_steps:
          should_stop_op = math_ops.greater_equal(global_step, number_of_steps)
        else:
          should_stop_op = constant_op.constant(False)
        train_step_kwargs['should_stop'] = should_stop_op
        if log_every_n_steps > 0:
          train_step_kwargs['should_log'] = math_ops.equal(
              math_ops.mod(global_step, log_every_n_steps), 0)
        if is_chief and trace_every_n_steps is not None:
          train_step_kwargs['should_trace'] = math_ops.equal(
              math_ops.mod(global_step, trace_every_n_steps), 0)
          train_step_kwargs['logdir'] = logdir

    sv = supervisor.Supervisor(
        graph=graph,
        is_chief=is_chief,
        logdir=logdir,
        init_op=init_op,
        init_feed_dict=init_feed_dict,
        local_init_op=local_init_op,
        ready_for_local_init_op=ready_for_local_init_op,
        ready_op=ready_op,
        summary_op=summary_op,
        summary_writer=summary_writer,
        global_step=global_step,
        saver=saver,
        save_summaries_secs=save_summaries_secs,
        save_model_secs=save_interval_secs,
        init_fn=init_fn)

    if summary_writer is not None:
      train_step_kwargs['summary_writer'] = sv.summary_writer

    total_loss = None
    should_retry = True
    while should_retry:
      try:
        should_retry = False
        with sv.managed_session(
            master, start_standard_services=False, config=session_config) as sess:
          logging.info('Starting Session.')
          if session_wrapper is not None:
            logging.info(
                'Wrapping session with wrapper function: %s', session_wrapper)
            sess = session_wrapper(sess)
          if is_chief:
            if logdir:
              sv.start_standard_services(sess)
          elif startup_delay_steps > 0:
            # (use sys.maxsize because sys.maxint doesn't exist in Python 3)
            _wait_for_step(sess, global_step,
                           min(startup_delay_steps, number_of_steps or
                               sys.maxsize))
          threads = sv.start_queue_runners(sess)
          logging.info('Starting Queues.')
          if is_chief and sync_optimizer is not None:
            sv.start_queue_runners(sess, chief_queue_runner)
            sess.run(init_tokens_op)
          try:
            while not sv.should_stop():
              # Run a single training step.
              total_loss, should_stop = train_step_fn(sess,
                                                      train_op,
                                                      global_step,
                                                      train_step_kwargs)
              if should_stop:
                logging.info('Stopping Training.')
                sv.request_stop()
                break
          except errors.OutOfRangeError as e:
            # OutOfRangeError is thrown when epoch limit per
            # tf.train.limit_epochs is reached.
            logging.info('Caught OutOfRangeError. Stopping Training. %s', e)
          if logdir and sv.is_chief:
            logging.info('Finished training! Saving model to disk.')
            sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
            sv.stop(
                threads,
                close_summary_writer=True,
                ignore_live_threads=ignore_live_threads)
      except errors.AbortedError:
        # Always re-run on AbortedError as it indicates a restart of one of the
        # distributed tensorflow servers.
        logging.info('Retrying training!')
        should_retry = True

    return total_loss
2.1 slim.learning.train_step()
def train_step(sess, train_op, global_step, train_step_kwargs):
  """Runs a single gradient step and decides whether or not to stop training.

  Args:
    sess: The current session.
    train_op: An `Operation` that computes the gradients and returns the
      total loss.
    global_step: A `Tensor` representing the global training step.
    train_step_kwargs: A dictionary of keyword arguments.

  Returns:
    The total loss and a boolean indicating whether or not to stop training.

  Raises:
    ValueError: if 'should_trace' is in `train_step_kwargs` but `logdir` is not.
  """
  start_time = time.time()

  trace_run_options = None
  run_metadata = None
  if 'should_trace' in train_step_kwargs:
    if 'logdir' not in train_step_kwargs:
      raise ValueError('logdir must be present in train_step_kwargs when '
                       'should_trace is present')
    if sess.run(train_step_kwargs['should_trace']):
      trace_run_options = config_pb2.RunOptions(
          trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()

  total_loss, np_global_step = sess.run([train_op, global_step],
                                        options=trace_run_options,
                                        run_metadata=run_metadata)
  time_elapsed = time.time() - start_time

  if run_metadata is not None:
    tl = timeline.Timeline(run_metadata.step_stats)
    trace = tl.generate_chrome_trace_format()
    trace_filename = os.path.join(train_step_kwargs['logdir'],
                                  'tf_trace-%d.json' % np_global_step)
    logging.info('Writing trace to %s', trace_filename)
    file_io.write_string_to_file(trace_filename, trace)
    if 'summary_writer' in train_step_kwargs:
      train_step_kwargs['summary_writer'].add_run_metadata(
          run_metadata, 'run_metadata-%d' % np_global_step)

  if 'should_log' in train_step_kwargs:
    if sess.run(train_step_kwargs['should_log']):
      logging.info('global step %d: loss = %.4f (%.3f sec/step)',
                   np_global_step, total_loss, time_elapsed)

  if 'should_stop' in train_step_kwargs:
    should_stop = sess.run(train_step_kwargs['should_stop'])
  else:
    should_stop = False

  return total_loss, should_stop
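slim.learning.train accepts a custom train_step_fn with this same signature, which makes it possible to hook extra logic into each step. A minimal sketch that delegates the gradient step to the default train_step; the validation op and the 1000-step interval are hypothetical:
def my_train_step(sess, train_op, global_step, train_step_kwargs):
  # Delegate the actual gradient step to the default implementation.
  total_loss, should_stop = slim.learning.train_step(
      sess, train_op, global_step, train_step_kwargs)

  # Hypothetical extra hook: run a custom op every 1000 steps.
  if sess.run(global_step) % 1000 == 0:
    sess.run(my_validation_op)  # assumed to be defined elsewhere

  return total_loss, should_stop

slim.learning.train(train_op, my_log_dir, train_step_fn=my_train_step)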
2.2 slim.learning.create_train_op()
Creates an Operation that computes the gradients and returns the loss, e.g.:
total_loss = slim.losses.get_total_loss()
optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
# Create the train_op
train_op = slim.learning.create_train_op(total_loss, optimizer)
def clip_gradient_norms(gradients_to_variables, max_norm):
  """Clips the gradients by the given value.

  Args:
    gradients_to_variables: A list of gradient to variable pairs (tuples).
    max_norm: the maximum norm value.

  Returns:
    A list of clipped gradient to variable pairs.
  """
  clipped_grads_and_vars = []
  for grad, var in gradients_to_variables:
    if grad is not None:
      if isinstance(grad, ops.IndexedSlices):
        tmp = clip_ops.clip_by_norm(grad.values, max_norm)
        grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
      else:
        grad = clip_ops.clip_by_norm(grad, max_norm)
    clipped_grads_and_vars.append((grad, var))
  return clipped_grads_and_vars
def multiply_gradients(grads_and_vars, gradient_multipliers):
  """Multiplies the specified gradients.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    gradient_multipliers: A map from either `Variables` or `Variable` op names
      to the coefficient by which the associated gradient should be scaled.

  Returns:
    The updated list of gradient to variable pairs.

  Raises:
    ValueError: If `grads_and_vars` is not a list or if `gradient_multipliers`
      is empty or None or if `gradient_multipliers` is not a dictionary.
  """
  if not isinstance(grads_and_vars, list):
    raise ValueError('`grads_and_vars` must be a list.')
  if not gradient_multipliers:
    raise ValueError('`gradient_multipliers` is empty.')
  if not isinstance(gradient_multipliers, dict):
    raise ValueError('`gradient_multipliers` must be a dict.')

  multiplied_grads_and_vars = []
  for grad, var in grads_and_vars:
    if var in gradient_multipliers or var.op.name in gradient_multipliers:
      key = var if var in gradient_multipliers else var.op.name
      if grad is None:
        raise ValueError('Requested multiple of `None` gradient.')

      multiplier = gradient_multipliers[key]
      if not isinstance(multiplier, ops.Tensor):
        multiplier = constant_op.constant(multiplier, dtype=grad.dtype)

      if isinstance(grad, ops.IndexedSlices):
        tmp = grad.values * multiplier
        grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
      else:
        grad *= multiplier
    multiplied_grads_and_vars.append((grad, var))
  return multiplied_grads_and_vars
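Both helpers operate on the (gradient, variable) pair lists produced by an optimizer, so they can also be used outside of create_train_op. A minimal sketch, assuming total_loss and optimizer from earlier; the scaling factor, variable name, and clipping threshold are illustrative:
grads_and_vars = optimizer.compute_gradients(total_loss)

# Scale the gradient of a selected variable, then clip every gradient norm to 4.
grads_and_vars = slim.learning.multiply_gradients(
    grads_and_vars, {'fc8/weights': 10.0})
grads_and_vars = slim.learning.clip_gradient_norms(grads_and_vars, max_norm=4)

# Apply the transformed gradients manually.
manual_train_op = optimizer.apply_gradients(grads_and_vars)
Internally, slim's create_train_op simply forwards to tf.contrib.training's create_train_op with a transform_grads_fn built from these two helpers: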
_USE_GLOBAL_STEP = 0

def create_train_op(total_loss,
                    optimizer,
                    global_step=_USE_GLOBAL_STEP,
                    update_ops=None,
                    variables_to_train=None,
                    clip_gradient_norm=0,
                    summarize_gradients=False,
                    gate_gradients=tf_optimizer.Optimizer.GATE_OP,
                    aggregation_method=None,
                    colocate_gradients_with_ops=False,
                    gradient_multipliers=None,
                    check_numerics=True):
  def transform_grads_fn(grads):
    if gradient_multipliers:
      with ops.name_scope('multiply_grads'):
        grads = multiply_gradients(grads, gradient_multipliers)

    # Clip gradients.
    if clip_gradient_norm > 0:
      with ops.name_scope('clip_grads'):
        grads = clip_gradient_norms(grads, clip_gradient_norm)
    return grads

  return training.create_train_op(
      total_loss=total_loss,
      optimizer=optimizer,
      global_step=global_step,
      update_ops=update_ops,
      variables_to_train=variables_to_train,
      transform_grads_fn=transform_grads_fn,
      summarize_gradients=summarize_gradients,
      gate_gradients=gate_gradients,
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops,
      check_numerics=check_numerics)
[1] - total_loss - A Tensor representing the total loss.
[2] - optimizer - The tf.Optimizer used to compute the gradients.
[3] - global_step - A Tensor representing the global step variable. If left at its default value _USE_GLOBAL_STEP, tf.contrib.framework.global_step() is used.
[4] - update_ops - An optional list of update ops to execute. If None, the update ops are set to the contents of the tf.GraphKeys.UPDATE_OPS collection; otherwise a warning may be displayed if it does not cover that collection.
[5] - variables_to_train - An optional list of variables to train. If None, it defaults to all tf.trainable_variables().
[6] - clip_gradient_norm - If greater than 0, the gradients are clipped by this value.
[7] - summarize_gradients - Whether or not to add summaries for each gradient.
[8] - gate_gradients - How to gate the computation of gradients. See tf.Optimizer.
[9] - aggregation_method - Specifies the method used to combine gradient terms. Valid values are defined in the AggregationMethod class.
[10] - colocate_gradients_with_ops - Whether or not to try colocating the gradients with the ops that generated them.
[11] - gradient_multipliers - A dictionary of either Variables or Variable op names to the coefficient by which the associated gradient should be scaled.
[12] - check_numerics - Whether or not to check the gradients for numerical issues (check_numerics).
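For example, variables_to_train can be used to fine-tune only the last layers of a network. A minimal sketch, assuming the final layers live under the (hypothetical) scopes 'fc7' and 'fc8':
# Collect only the variables of the layers we want to fine-tune.
trainable_scopes = ['fc7', 'fc8']
variables_to_train = []
for scope in trainable_scopes:
  variables_to_train.extend(
      tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope))

train_op = slim.learning.create_train_op(
    total_loss,
    optimizer,
    variables_to_train=variables_to_train)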