原文 - Tensorflow框架实现中的“三”种图 - 知乎

图(Graph) 是 TensorFlow 用于表达计算任务的一个核心概念.

从前端(python) 描述神经网络的结构,到后端在多机和分布式系统上部署,到底层 Device(CPU、GPU、TPU)上运行,都是基于图来完成.

然而在实际使用过程中遇到了三对API,

[1] - tf.train.Saver()/saver.restore()

[2] - export_meta_graph/Import_meta_graph

[3] - tf.train.write_graph()/tf.Import_graph_def()

它们都是用于对图的保存和恢复.

同一个计算框架,为什么需要三对不同的API呢?他们保存/恢复的图在使用时又有什么区别呢?

初学的时候,常常闹不清楚他们的区别,以至常常写出了错误的程序,经过一番研究,本文中对Tensorflow中围绕Graph的核心概念进行了总结.

1. Graph

首先介绍一下关于 TensorFlow 中 Graph 和它的序列化表示 Graph_def.

在 TensorFlow 官方文档中,Graph 被定义为 “一些 Operation 和 Tensor 的集合”.

例如表达如下的一个计算的 python代码,

import tensorflow as tf

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
c = tf.placeholder(tf.float32)
d = a*b+c
e = d*2

就会生成相应的一张图,在Tensorboard中看到的图大概如图:

其中,每一个圆圈表示一个 Operation(输入处为Placeholder),椭圆到椭圆的边为Tensor,箭头的指向表示了这张图 Operation 输入输出 Tensor 的传递关系.

在真实的 TensorFlow 运行中,Python 构建的“图Graph” 并不是启动一个 Session 之后始终不变的. 因为 TensorFlow 在运行时,真实的计算会被分配到多CPUs,或 GPUs,或 ARM 等,以进行高性能/能效的计算. 单纯使用 Python 肯定是无法有效完成的.

实际上,TensorFlow 是首先将 python 代码所描绘的图转换(即“序列化”)成 Protocol Buffer,再通过 C/C++/CUDA 运行 Protocol Buffer 所定义的图. (Protocol Buffer 可参考:https://www.ibm.com/developerworks/cn/linux/l-cn-gpb/).

2. GraphDef

从 python Graph中序列化出来的图就叫做 GraphDef (这是一种不严格的说法,先这样进行理解).

GraphDef 又是由许多叫做 NodeDef 的 Protocol Buffer 组成. 在概念上 NodeDef 与(Python Graph 中的) Operation 相对应.

如下就是 GraphDef 的 ProtoBuf,由许多node 组成的图表示. 这是与上文 Python 图对应的 GraphDef:

node {
  name: "Placeholder"    # 注:这是一个叫做 "Placeholder" 的node
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "Placeholder_1" # 注:这是一个叫做 "Placeholder_1" 的node
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "mul"          # 注:一个 Mul(乘法)操作
  op: "Mul"
  input: "Placeholder" # 使用上面的node(即Placeholder和Placeholder_1)
  input: "Placeholder_1" # 作为这个Node的输入
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}

以上三个 NodeDef 定义了两个 Placeholde r和一个Multiply.

Placeholder 通过 attr(attribute的缩写)来定义数据类型和 Tensor 的形状.

Multiply 通过 input 属性定义了两个 placeholder 作为其输入.

无论是 Placeholder 还是 Multiply 都没有关于输出(output)的信息.

其实 Tensorflow 中都是通过 Input 来定义 Node 之间的连接信息.

那么既然 tf.Operation 的序列化 ProtoBuf 是 NodeDef那么 tf.Variable 呢?在这个 GraphDef 中只有网络的连接信息,却没有任何 Variables呀?

没错,Graphdef 中不保存任何 Variable 的信息,所以如果从 graph_def 来构建图并恢复训练的话,是不能成功的.

如,

with tf.Graph().as_default() as graph:
  tf.import_graph_def("graph_def_path")
  saver= tf.train.Saver()
  with tf.Session() as sess:
    tf.trainable_variables()

其中 tf.trainable_variables() 只会返回一个空的list. tf.train.Saver() 也会报告 no variables to save.

然而,在实际线上 inference 中,通常就是使用 GraphDef. 但,GraphDef 中连 Variable都没有,怎么存储 weight 呢?

原来, GraphDef 虽然不能保存 Variable,但可以保存 Constant. 通过 tf.constant 将 weight 直接存储在 NodeDef,tensorflow 1.3.0 版本也提供了一套叫做 freeze_graph 的工具来自动的将图中的 Variable 替换成 constant 存储在 GraphDef 里面,并将该图导出为 Proto.

https://www.tensorflow.org/extend/tool_developers/https://www.tensorflow.org/mobile/prepare_models

tf.train.write_graph()/tf.Import_graph_def() 就是用来进行 GraphDef 读写的API. 那么,我们怎么才能从序列化的图中,得到 Variables呢?这就要学习下一个重要概念,MetaGraph.

3. MetaGraph

Meta graph 的官方解释是:一个 Meta Graph 由一个计算图和其相关的元数据构成, 其包含了用于继续训练,实施评估和(在已训练好的的图上)做前向推断的信息.

A MetaGraph consists of both a computational graph and its associated metadata.

A MetaGraph contains the information required to continue training, perform evaluation, or run inference on a previously trained graph.

From https://www.tensorflow.org/versions/r1.1/programmers_guide/

这一段看的云里雾里,不过这篇文章(https://www.tensorflow.org/versions/r1.1/programmers_guide/meta_graph)进一步解释说,Meta Graph在具体实现上就是一个 MetaGraphDef (同样是由 Protocol Buffer来定义的). 其包含了四种主要的信息,根据Tensorflow官网,这四种 Protobuf 分别是:

[1] - MetaInfoDef,存一些元信息(比如版本和其他用户信息)

[2] - GraphDef, MetaGraph 的核心内容之一

[3] - SaverDef,图的Saver信息(比如最多同时保存的check-point数量,需保存的Tensor名字等,但并不保存Tensor中的实际内容)

[4] - CollectionDef,任何需要特殊注意的 Python 对象,需要特殊的标注以方便import_meta_graph 后取回(如 train_op, prediction 等等)

在以上四种 ProtoBuf 里面,[1] 和 [3] 都比较容易理解,[2] 刚刚总结过. 这里特别要讲一下 Collection(CollectionDef是对应的ProtoBuf).

TensorFlow 中并没有一个官方的定义说 collection 是什么. 简单的理解,它就是为了方别用户对图中的操作和变量进行管理,而创建的一个概念. 它可以说是一种“集合”,通过一个 key (string类型) 来对一组 Python 对象进行命名的集合. 这个 key 既可以是TensorFlow 在内部定义的一些 key,也可以是用户自己定义的名字(string).

TensorFlow 内部定义了许多标准 Key,全部定义在了 tf.GraphKeys 这个类中. 其中有一些常用的,tf.GraphKeys.TRAINABLE_VARIABLES, tf.GraphKeys.GLOBAL_VARIABLES 等等. tf.trainable_variables()tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 是等价的;tf.global_variables()tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 是等价的.

集合类型集合内容使用环境
tf.GraphKeys.VARIABLES神经网络参数
tf.GraphKeys.TRAINABLE_VARIABLES模型训练,生产模型可视化内容
tf.GraphKeys.SUMMARIES日志生成相关张量计算可视化
tf.GraphKeys.QUEUE_RUNNER处理输入的QueueRunner输入处理
tf.MOVING_AVERAGE_BARIABLES所有计算了滑动平均值的变量计算变量滑动平均值

对于用户定义的 key,举一个例子, 例如:

pred = model_network(X)
loss=tf.reduce_mean(…, pred, …)
train_op=tf.train.AdamOptimizer(lr).minimize(loss)

这样一段 Tensorflow程序,用户希望特别关注 pred, loss, train_op 这几个操作,那么就可以使用如下代码,将这几个变量加入到 collection 中去. (假设我们将其命名为 “training_collection”)

tf.add_to_collection("training_collection", pred)
tf.add_to_collection("training_collection", loss)
tf.add_to_collection("training_collection", train_op)

并且可以通过 Train_collect = tf.get_collection(“training_collection”) 得到一个python list,其中的内容就是pred, loss, train_op 的 Tensor. 这通常是为了在一个新的 session 中打开这张图时,方便我们获取想要的操作. 比如我们可以直接通过 get_collection() 得到 train_op,然后通过 sess.run(train_op) 来开启一段训练,而无需重新构建 lossoptimizer.

通过 export_meta_graph 保存图,并且通过 add_to_collectiontrain_op 加入到 collection 中:

with tf.Session() as sess:
  pred = model_network(X)
  loss=tf.reduce_mean(…,pred, …)
  train_op=tf.train.AdamOptimizer(lr).minimize(loss)
  tf.add_to_collection("training_collection", train_op)
  Meta_graph_def = 
      tf.train.export_meta_graph(tf.get_default_graph(), 'my_graph.meta')

通过 import_meta_graph 将图恢复(同时初始化为本 Session的 default 图),并且通过 get_collection 重新获得 train_op,以及通过 train_op 来开始一段训练(sess.run() ).

with tf.Session() as new_sess:
  tf.train.import_meta_graph('my_graph.meta')
  train_op = tf.get_collection("training_collection")[0]
  new_sess.run(train_op)

更多的代码例子可以在这篇文档(https://www.tensorflow.org/api_guides/python/meta_graph)中的 Import a MetaGraph 章节中看到.

那么,Meta Graph 中恢复构建的图可以被训练吗?是可以的. TensorFlow 的官方文档 https://www.tensorflow.org/api_guides/python/meta_graph 说明了使用方法. 这里要特殊的说明一下,Meta Graph 中虽然包含 Variable 的信息,却没有 Variable 的实际值. 所以, Meta Graph 中恢复的图,其训练是从随机初始化的值开始的. 训练中 Variable的实际值都保存在 checkpoint 中,如果要从之前训练的状态继续恢复训练,就要从checkpoint 中 restore. 进一步读一下 Export Meta Graph 的代码,可以看到,事实上variables 并没有被 export 到 meta_graph 中.

https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/training/saver.py (1872行)

https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/framework/meta_graph.py (829,845行)

export_meta_graph/Import_meta_graph 就是用来进行 Meta Graph 读写的API.

tf.train.saver.save() 在保存checkpoint的同时也会保存Meta Graph. 但是在恢复图时,tf.train.saver.restore() 只恢复 Variable,如果要从MetaGraph恢复图,需要使用 import_meta_graph. 这是其实为了方便用户,有时我们不需要从MetaGraph恢复的图,而是需要在 python 中构建神经网络图,并恢复对应的 Variable.

4. Checkpoint

Checkpoint 里全面保存了训练某时间截面的信息,包括参数,超参数,梯度等等. tf.train.Saver()/saver.restore() 则能够完完整整保存和恢复神经网络的训练.

Checkpoint 分为两个文件保存Variable的二进制信息. ckpt 文件保存了Variable的二进制信息,index 文件用于保存 ckpt 文件中对应 Variable 的偏移量信息.

5. 总结

TensorFlow 三种 API 所保存和恢复的图是不一样的.

这三种图是从 TensorFlow 框架设计的角度出发而定义的.

但是从用户的角度来看,TensorFlow 文档的写作难免有些云里雾里,弄不清他们的区别.需要读一读Tensorflow的代码,做一些实验来进行辨析.

简而言之,TensorFlow 在前端 Python 中构建图,并且通过将该图序列化到 ProtoBuf GraphDef,以方便在后端运行. 在这个过程中,图的保存、恢复和运行都通过 ProtoBuf 来实现. GraphDefMetaGraph,以及VariableCollectionSaver 等都有对应的 ProtoBuf 定义. ProtoBuf 的定义也决定了用户能对图进行的操作. 例如用户只能找到 Node的前一个Node,却无法得知自己的输出会由哪个Node接收.

Last modification:November 26th, 2018 at 11:24 am