TensorFlow 的 Variables 提供了用于从断点(checkpoint)文件恢复模型变量的函数,如 get_variables_to_restore 函数。(参考:TensorFlow - TF-Slim 之 variables 函数)
<h2>D1 - 转自 Tensorflow 部分恢复模型</h2>
# 创建变量
v1 = slim.variable(name="v1", ...)
v2 = slim.variable(name="nested/v2", ...)
...
# 获取待恢复的变量列表,有以下四种等价方法:
variables_to_restore = slim.get_variables_by_name("v2") #1. 根据名字获取变量
variables_to_restore = slim.get_variables_by_suffix("2") #2. 根据后缀获取变量
variables_to_restore = slim.get_variables(scope="nested") #3. 根据作用域获取变量
variables_to_restore = slim.get_variables_to_restore(include=["nested"]) #4. 根据 include 正则表达式获取变量
variables_to_restore = slim.get_variables_to_restore(exclude=["v1"]) #4. 根据 exclude 正则表达式获取变量
# 如 vgg/conv6, vgg 都可以作为exclude的参数传入
# 创建 saver,以恢复变量. 模拟断点文件.
restorer = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
# 从磁盘恢复变量.
restorer.restore(sess, "/tmp/model.ckpt")
print("Model restored.")
# Do some work with the model
<h2>D2 - DeepLab 断点初始化</h2>
From train_utils.py.
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):
  """Gets the function initializing model variables from a checkpoint.

  Args:
    train_logdir: Log directory for training.
    tf_initial_checkpoint: TensorFlow checkpoint for initialization.
    initialize_last_layer: Initialize last layer or not.
    last_layers: Last layers of the model.
    ignore_missing_vars: Ignore missing variables in the checkpoint.

  Returns:
    Initialization function to restore variables from the checkpoint, or
    None when no restore should happen.
  """
  if tf_initial_checkpoint is None:
    tf.logging.info('Not initializing the model from a checkpoint.')
    return None

  # If training already produced a checkpoint in train_logdir, training is
  # being resumed — do not clobber it with the initial checkpoint.
  if tf.train.latest_checkpoint(train_logdir):
    tf.logging.info('Ignoring initialization; other checkpoint exists')
    return None

  tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

  # Variables that will not be restored.
  exclude_list = ['global_step']
  if not initialize_last_layer:
    # Re-train the last layers from scratch (e.g. when fine-tuning on a
    # dataset with a different number of classes).
    exclude_list.extend(last_layers)
  variables_to_restore = slim.get_variables_to_restore(exclude=exclude_list)

  return slim.assign_from_checkpoint_fn(
      tf_initial_checkpoint,
      variables_to_restore,
      ignore_missing_vars=ignore_missing_vars)
From model.py.
# Scope names of the extra (task-specific) layers built on top of the
# feature-extractor backbone.
_LOGITS_SCOPE_NAME = 'logits'
_MERGED_LOGITS_SCOPE = 'merged_logits'
_IMAGE_POOLING_SCOPE = 'image_pooling'
_ASPP_SCOPE = 'aspp'
_CONCAT_PROJECTION_SCOPE = 'concat_projection'
_DECODER_SCOPE = 'decoder'


def get_extra_layer_scopes(last_layers_contain_logits_only=False):
  """Gets the scopes for extra layers.

  Args:
    last_layers_contain_logits_only: Boolean, True if only consider logits as
      the last layer (i.e., exclude ASPP module, decoder module and so on).

  Returns:
    A list of scopes for extra layers.
  """
  if last_layers_contain_logits_only:
    return [_LOGITS_SCOPE_NAME]
  else:
    return [
        _LOGITS_SCOPE_NAME,
        _IMAGE_POOLING_SCOPE,
        _ASPP_SCOPE,
        _CONCAT_PROJECTION_SCOPE,
        _DECODER_SCOPE,
    ]