Tensorflow 提供的 DeepLab API 训练得到的模型作预测.
<h2>1. Tensorflow - DeepLab Checkpoint</h2>
Tensorflow 提供的训练好的 DeepLab 模型,如 deeplabv3_pascal_trainval_2018_01_04.tar.gz
压缩包中包含三个文件:
- frozen_inference_graph.pb
- model.ckpt.data-00000-of-00001
- model.ckpt.index
实际上, Tensorflow 采用 saver.save 保存训练断点checkpoint 时,产生了三个文件:
如:
import tensorflow as tf
sess=tf.Session()
...
saver = tf.train.Saver()
model_save_path = saver.save(sess, "output/model.ckpt", global_step=30000)
print("Model File: ", model_save_path)
- model.ckpt-30000.data-00000-of-00001
- model.ckpt-30000.index
- model.ckpt-30000.meta
其中,训练的模型权重参数保存在 model.ckpt-30000.data-00000-of-00001
.
<h2>2. Tensorflow - DeepLab pb 模型加载</h2>
在 Tensorflow 提供的 DeepLab API 中,模型加载的方式为:
import tensorflow as tf class DeepLabModel(object): """ 加载 DeepLab 模型 """ INPUT_TENSOR_NAME = 'ImageTensor:0' OUTPUT_TENSOR_NAME = 'SemanticPredictions:0' INPUT_SIZE = 513 FROZEN_GRAPH_NAME = 'frozen_inference_graph' def __init__(self, tarball_path): """ Creates and loads pretrained deeplab model. """ self.graph = tf.Graph() graph_def = None # Extract frozen graph from tar archive. tar_file = tarfile.open(tarball_path) for tar_info in tar_file.getmembers(): if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name): file_handle = tar_file.extractfile(tar_info) graph_def = tf.GraphDef.FromString(file_handle.read()) break tar_file.close() if graph_def is None: raise RuntimeError('Cannot find inference graph in tar archive.') with self.graph.as_default(): tf.import_graph_def(graph_def, name='') self.sess = tf.Session(graph=self.graph)
</pre>其中,`tarball_path` 是模型压缩包,如 `deeplabv3_pascal_trainval_2018_01_04.tar.gz`. 真正加载的模型处理是:
pb_path = 'deeplabv3_pascal_trainval/frozen_inference_graph.pb'
graph_def = tf.GraphDef.FromString(open('pb_path', 'rb').read())if graph_def is None:
raise RuntimeError('Cannot find inference graph in tar archive.')
with self.graph.as_default():
tf.import_graph_def(graph_def, name='')
self.sess = tf.Session(graph=self.graph)
注: 如果采用 tf.GraphDef.FromString(open(pb_path).read()) 会出现错误: UnicodeDecodeError: 'utf-8' codec can't decode byte 0x90 in position 1: invalid start byte 需要由 open 默认的 'r' 改为 'rb'. 即: tf.GraphDef.FromString(open('pb_path', 'rb').read())<h2>3. Tensorflow - DeepLab Checkpoint ckpt 转换为 pb</h2>
Tensorflow 提供了 Checkpoint ckpt 转换为 pb 的工具 - Freezing.
训练时权重往往不是保存在文件格式,而是分散保存为 Checkpoint 文件,在初始化时,通过模型文件中的 Variable 变量Op 来从读取 Checkpoint 中的数据进行初始化.
当产品部署时,这种模型和权重分散保存的方式,有点不方便.
因此,Tensorflow 提供了 freeze_graph.py 脚本来将模型(定义的图 graph definition) 和断点集Checkpoints 整合为一个文件.freeze_graph.py 脚本所进行的操作有:
- 加载 GraphDef;
- 读取最近断点Checkpoint文件内所有的变量Variables;
- 将每个权重 Variable Op 替换为权重 Const,权重常量可以和模型一起保存在一个文件;
- 再将所有未用于 forward 推断计算的无用节点剥离去除掉;
- 最后重新保存GradpDef 到指定输出文件.
具体可见:
Tensorflow DeepLab API 中提供了将训练的 DeepLab ckpt 转换为 frozen inference graph 的脚本 - export_model.py
正如上面所说,一般模型训练结束能够得到下面的断点 Checkpoint 文件:
- model.ckpt-${CHECKPOINT_NUMBER}.data-00000-of-00001,
- model.ckpt-${CHECKPOINT_NUMBER}.index
- model.ckpt-${CHECKPOINT_NUMBER}.meta
对此,可以根据训练参数设置,和特定的断点 Checkpoint 文件,运行下面的命令行,得到导出的 frozen graph:
# From tensorflow/models/research/
# Assume all checkpoint files share the same path prefix ${CHECKPOINT_PATH}
.
python deeplab/export_model.py \
--checkpoint_path=${CHECKPOINT_PATH}/model.ckpt-${CHECKPOINT_NUMBER}\
--export_path=${OUTPUT_DIR}/frozen_inference_graph.pb
--model_variant="xception_65"
--atrous_rates=6
--atrous_rates=12
--atrous_rates=18
--output_stride=16
--decoder_output_stride=4
其中,假设训练时 train.py
的参数设置为:
python train.py --logtostderr \
--training_number_of_steps=30000
--train_split="train"
--model_variant="xception_65"
--atrous_rates=6
--atrous_rates=12
--atrous_rates=18
--output_stride=16
--decoder_output_stride=4
--train_crop_size=513
--train_crop_size=513
--train_batch_size=1
--dataset="pascal_voc_seg"
--tf_initial_checkpoint=./output/model.ckpt
--train_logdir=./output/
--dataset_dir=./datasets/tfrecord/