Tensorflow 提供的 DeepLab API 训练得到的模型作预测.

<h2>1. Tensorflow - DeepLab Checkpoint</h2>

Tensorflow 提供的训练好的 DeepLab 模型,如 deeplabv3_pascal_trainval_2018_01_04.tar.gz 压缩包中包含三个文件:

  • frozen_inference_graph.pb
  • model.ckpt.data-00000-of-00001
  • model.ckpt.index

实际上, Tensorflow 采用 saver.save 保存训练断点checkpoint 时,产生了三个文件:
如:

import tensorflow as tf

sess=tf.Session()
...
saver = tf.train.Saver()
model_save_path = saver.save(sess, "output/model.ckpt", global_step=30000)
print("Model File: ", model_save_path)
  • model.ckpt-30000.data-00000-of-00001
  • model.ckpt-30000.index
  • model.ckpt-30000.meta

其中,训练的模型权重参数保存在 model.ckpt-30000.data-00000-of-00001.

<h2>2. Tensorflow - DeepLab pb 模型加载</h2>

在 Tensorflow 提供的 DeepLab API 中,模型加载的方式为:

import tensorflow as tf


class DeepLabModel(object):
    """
    加载 DeepLab 模型
    """
    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 513
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, tarball_path):
        """
        Creates and loads pretrained deeplab model.
        """
        self.graph = tf.Graph()

        graph_def = None
        # Extract frozen graph from tar archive.
        tar_file = tarfile.open(tarball_path)
        for tar_info in tar_file.getmembers():
            if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
                file_handle = tar_file.extractfile(tar_info)
                graph_def = tf.GraphDef.FromString(file_handle.read())
                break

        tar_file.close()

        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in tar archive.')

        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')

        self.sess = tf.Session(graph=self.graph)
其中,`tarball_path` 是模型压缩包,如 `deeplabv3_pascal_trainval_2018_01_04.tar.gz`.

真正加载的模型处理是:

pb_path = 'deeplabv3_pascal_trainval/frozen_inference_graph.pb'
graph_def = tf.GraphDef.FromString(open('pb_path', 'rb').read())

if graph_def is None:

raise RuntimeError('Cannot find inference graph in tar archive.')

with self.graph.as_default():

tf.import_graph_def(graph_def, name='')

self.sess = tf.Session(graph=self.graph)
</pre>

注: 如果采用 tf.GraphDef.FromString(open(pb_path).read()) 会出现错误: UnicodeDecodeError: 'utf-8' codec can't decode byte 0x90 in position 1: invalid start byte 需要由 open 默认的 'r' 改为 'rb'. 即: tf.GraphDef.FromString(open('pb_path', 'rb').read())

<h2>3. Tensorflow - DeepLab Checkpoint ckpt 转换为 pb</h2>

Tensorflow 提供了 Checkpoint ckpt 转换为 pb 的工具 - Freezing.

训练时权重往往不是保存在文件格式,而是分散保存为 Checkpoint 文件,在初始化时,通过模型文件中的 Variable 变量Op 来从读取 Checkpoint 中的数据进行初始化.
当产品部署时,这种模型和权重分散保存的方式,有点不方便.
因此,Tensorflow 提供了 freeze_graph.py 脚本来将模型(定义的图 graph definition) 和断点集Checkpoints 整合为一个文件.

freeze_graph.py 脚本所进行的操作有:

  • 加载 GraphDef;
  • 读取最近断点Checkpoint文件内所有的变量Variables;
  • 将每个权重 Variable Op 替换为权重 Const,权重常量可以和模型一起保存在一个文件;
  • 再将所有未用于 forward 推断计算的无用节点剥离去除掉;
  • 最后重新保存GradpDef 到指定输出文件.

具体可见:

Tensorflow DeepLab API 中提供了将训练的 DeepLab ckpt 转换为 frozen inference graph 的脚本 - export_model.py

正如上面所说,一般模型训练结束能够得到下面的断点 Checkpoint 文件:

  • model.ckpt-${CHECKPOINT_NUMBER}.data-00000-of-00001,
  • model.ckpt-${CHECKPOINT_NUMBER}.index
  • model.ckpt-${CHECKPOINT_NUMBER}.meta

对此,可以根据训练参数设置,和特定的断点 Checkpoint 文件,运行下面的命令行,得到导出的 frozen graph:

# From tensorflow/models/research/
# Assume all checkpoint files share the same path prefix ${CHECKPOINT_PATH}.
python deeplab/export_model.py \
    --checkpoint_path=${CHECKPOINT_PATH}/model.ckpt-${CHECKPOINT_NUMBER}\
    --export_path=${OUTPUT_DIR}/frozen_inference_graph.pb
    --model_variant="xception_65"
    --atrous_rates=6
    --atrous_rates=12
    --atrous_rates=18
    --output_stride=16
    --decoder_output_stride=4

其中,假设训练时 train.py 的参数设置为:

python train.py --logtostderr \
                        --training_number_of_steps=30000
                        --train_split="train"
                        --model_variant="xception_65"
                        --atrous_rates=6
                        --atrous_rates=12
                        --atrous_rates=18
                        --output_stride=16
                        --decoder_output_stride=4
                        --train_crop_size=513
                        --train_crop_size=513
                        --train_batch_size=1
                        --dataset="pascal_voc_seg"
                        --tf_initial_checkpoint=./output/model.ckpt
                        --train_logdir=./output/
                        --dataset_dir=./datasets/tfrecord/
Last modification:October 9th, 2018 at 09:31 am