TensorFlow 提供了 inspect_checkpoint.py 脚本（位于 tensorflow/python/tools/ 目录下），用于查看 Checkpoint 文件中保存的变量名及其对应的变量值。

# inspect_checkpoint.py
"""
A simple script for inspecting checkpoint files.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import tensorflow as tf

# Command-line flags consumed by main(). Both default to the empty string;
# an empty --tensor_name means "summarize the whole checkpoint".
tf.app.flags.DEFINE_string("file_name", "", "Checkpoint filename")
tf.app.flags.DEFINE_string("tensor_name", "", "Name of the tensor to inspect")
FLAGS = tf.app.flags.FLAGS

def print_tensors_in_checkpoint_file(file_name, tensor_name):
  """Print the contents of a checkpoint file.

  With an empty `tensor_name`, print the reader's `debug_string()` summary
  of the checkpoint (the stored tensor names and shapes). Otherwise print
  the requested name followed by that tensor's value. Errors are reported
  on stdout rather than raised; if the error text mentions corrupted
  compressed blocks, an extra hint about SNAPPY compression is printed.

  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print; an
      empty string selects the whole-checkpoint summary.
  """
  try:
    ckpt_reader = tf.train.NewCheckpointReader(file_name)
    if tensor_name:
      print("tensor_name: ", tensor_name)
      print(ckpt_reader.get_tensor(tensor_name))
      return
    # debug_string() returns bytes, hence the explicit decode.
    summary = ckpt_reader.debug_string()
    print(summary.decode("utf-8"))
  except Exception as err:  # pylint: disable=broad-except
    # Best-effort CLI helper: report the failure instead of propagating it.
    message = str(err)
    print(message)
    looks_snappy = "corrupted compressed block contents" in message
    if looks_snappy:
      print("It's likely that your checkpoint file has been compressed with SNAPPY.")


def main(unused_argv):
  """Validate the command-line flags and inspect the requested checkpoint.

  Args:
    unused_argv: Leftover command-line arguments handed over by
      `tf.app.run`; ignored because all input is read from FLAGS.

  Prints a usage line to stderr and exits with status 1 when
  --file_name is not given.
  """
  if not FLAGS.file_name:
    # Usage errors go to stderr so they never mix with checkpoint dumps,
    # which are written to stdout.
    print("Usage: inspect_checkpoint "
          "--file_name=checkpoint_file_name "
          "[--tensor_name=tensor_to_print]", file=sys.stderr)
    sys.exit(1)
  print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name)

if __name__ == "__main__":
  # tf.app.run parses the flags defined above, then calls main(argv).
  tf.app.run(main=main)
import tensorflow as tf

# Checkpoint to dump; edit this path to point at your own checkpoint.
checkpoint_file = './outputs/model.ckpt-20000'

reader = tf.train.NewCheckpointReader(checkpoint_file)
# For a names/shapes summary without the (possibly huge) values, use instead:
#   print(reader.debug_string().decode("utf-8"))
var_to_shape_map = reader.get_variable_to_shape_map()
# Iterate names in sorted order so the dump is stable and easy to scan.
for key in sorted(var_to_shape_map):
    print("tensor_name: ", key)    # variable name
    print(reader.get_tensor(key))  # variable value
Last modification: November 26th, 2018 at 05:24 pm