TensorFlow provides several ways to read image data. This post converts the Flowers dataset into TFRecord files and uses them to fine-tune an InceptionV1 model.
The main data-reading approaches in TensorFlow are:
- Reading directly from memory/disk - feed data into the graph with feed_dict when running train_op in a Session (a minimal sketch follows this list). This does not suit large datasets, because it requires enough GPU memory to hold the training data.
- Reading from CSV files - not suitable for images.
- Reading from TFRecord files - convert the images once into TFRecord files that TensorFlow can read directly, so the original image files are not re-read during training, which gives much better read efficiency. This post uses TFRecords to handle a large dataset.
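A minimal sketch of the feed_dict approach (illustrative only; the toy model, shapes and dummy arrays below are not from the original post):
import numpy as np
import tensorflow as tf

# Placeholders stand in for one batch of images/labels; real values are supplied at run time.
images_ph = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])
labels_ph = tf.placeholder(tf.int64, shape=[None])

logits = tf.layers.dense(tf.layers.flatten(images_ph), 5)  # toy 5-class classifier
loss = tf.losses.sparse_softmax_cross_entropy(labels_ph, logits)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch_images = np.zeros([8, 224, 224, 3], np.float32)  # dummy in-memory batch
    batch_labels = np.zeros([8], np.int64)
    sess.run(train_op, feed_dict={images_ph: batch_images, labels_ph: batch_labels})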
Although creating TFRecord files is less direct than reading data from HDF5 (as Keras does), the format is much easier to use with TensorFlow's data pipeline tools, such as queue runners, coordinators and supervisors, which help manage the flow of training data.
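As a rough illustration of the queue-runner / coordinator machinery mentioned above, here is a minimal TF 1.x sketch (not from the original post; the filenames are placeholders):
import tensorflow as tf

# A filename queue is filled by background threads; a Coordinator manages their lifetime.
filename_queue = tf.train.string_input_producer(
    ['fake_0.tfrecord', 'fake_1.tfrecord'], num_epochs=1, shuffle=False)
next_filename = filename_queue.dequeue()

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # num_epochs is tracked by a local variable
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            print(sess.run(next_filename))  # consume items until the queue is exhausted
    except tf.errors.OutOfRangeError:
        pass
    finally:
        coord.request_stop()
        coord.join(threads)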
TensorFlow provides TF-Slim wrappers for creating and reading TFRecord files - slim/datasets.
- TensorFlow - TF-Slim 使用总览 - covers training on the TFRecord Flowers dataset (but not the creation of the TFRecord files)
- TensorFlow - TF-Slim Data
- TensorFlow - TF-Slim 封装模块
<h2>1. Creating the TFRecord Flowers Dataset</h2>
Download the Flowers dataset - Flowers Dataset.
After extraction, the directory structure is:
flowers_photos/
|----daisy
| -------- *.jpg (633 images)
|----dandelion
| -------- *.jpg (898 images)
|----roses
| -------- *.jpg (641 images)
|----sunflowers
| -------- *.jpg (699 images)
|----tulips
| -------- *.jpg (799 images)
TensorFlow's TF-Slim repository provides scripts that convert the Flowers dataset to TFRecords. The shared helper module, datasets/dataset_utils.py:
"""
Utilities for downloading datasets and converting them to TFRecords.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import tarfile
from six.moves import urllib
import tensorflow as tf
LABELS_FILENAME = 'labels.txt'
def int64_feature(values):
"""Returns a TF-Feature of int64s.
Args:
values: A scalar or list of values.
Returns:
A TF-Feature.
"""
if not isinstance(values, (tuple, list)):
values = [values]
return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
def bytes_feature(values):
"""Returns a TF-Feature of bytes.
Args:
values: A string.
Returns:
A TF-Feature.
"""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
def float_feature(values):
"""Returns a TF-Feature of floats.
Args:
    values: A scalar or list of values.
Returns:
A TF-Feature.
"""
if not isinstance(values, (tuple, list)):
values = [values]
return tf.train.Feature(float_list=tf.train.FloatList(value=values))
def image_to_tfexample(image_data, image_format, height, width, class_id):
return tf.train.Example(features=tf.train.Features(feature={
'image/encoded': bytes_feature(image_data),
'image/format': bytes_feature(image_format),
'image/class/label': int64_feature(class_id),
'image/height': int64_feature(height),
'image/width': int64_feature(width),
}))
def download_and_uncompress_tarball(tarball_url, dataset_dir):
"""Downloads the tarball_url
and uncompresses it locally.
Args:
tarball_url: The URL of a tarball file.
dataset_dir: The directory where the temporary files are stored.
"""
filename = tarball_url.split('/')[-1]
filepath = os.path.join(dataset_dir, filename)
def _progress(count, block_size, total_size):
    sys.stdout.write('\r>> Downloading %s %.1f%%' % (
        filename, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
print()
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(dataset_dir)
def write_label_file(labels_to_class_names, dataset_dir,
filename=LABELS_FILENAME):
"""
Writes a file with the list of class names.
Args:
labels_to_class_names: A map of (integer) labels to class names.
dataset_dir: The directory in which the labels file should be written.
filename: The filename where the class names are written.
"""
labels_filename = os.path.join(dataset_dir, filename)
with tf.gfile.Open(labels_filename, 'w') as f:
for label in labels_to_class_names:
class_name = labels_to_class_names[label]
      f.write('%d:%s\n' % (label, class_name))
def has_labels(dataset_dir, filename=LABELS_FILENAME):
"""Specifies whether or not the dataset directory contains a label map file.
Args:
dataset_dir: The directory in which the labels file is found.
filename: The filename where the class names are written.
Returns:
    True if the labels file exists and False otherwise.
"""
return tf.gfile.Exists(os.path.join(dataset_dir, filename))
def read_label_file(dataset_dir, filename=LABELS_FILENAME):
"""Reads the labels file and returns a mapping from ID to class name.
Args:
dataset_dir: The directory in which the labels file is found.
filename: The filename where the class names are written.
Returns:
A map from a label (integer) to class name.
"""
labels_filename = os.path.join(dataset_dir, filename)
with tf.gfile.Open(labels_filename, 'rb') as f:
lines = f.read().decode()
  lines = lines.split('\n')
lines = filter(None, lines)
labels_to_class_names = {}
for line in lines:
index = line.index(':')
labels_to_class_names[int(line[:index])] = line[index+1:]
return labels_to_class_names
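As a quick illustration of how these helpers fit together, a minimal sketch (not part of the original scripts; the paths and image dimensions are placeholders):
import tensorflow as tf
import dataset_utils  # assumes the helpers above are importable under this name

# Serialize one JPEG image plus its metadata into a TF-Example and write it to a TFRecord file.
image_path = '/path/to/flower_photos/daisy/some_image.jpg'
image_data = tf.gfile.FastGFile(image_path, 'rb').read()  # raw encoded JPEG bytes
# Height/width are hard-coded here for brevity; the conversion script below decodes each
# image with an ImageReader to obtain the real dimensions.
example = dataset_utils.image_to_tfexample(image_data, b'jpg', 240, 320, class_id=0)
with tf.python_io.TFRecordWriter('/tmp/one_image.tfrecord') as writer:
    writer.write(example.SerializeToString())

The conversion driver itself, datasets/download_and_convert_flowers.py: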
"""
Flowers 数据集下载和转化为TFRecords 格式(TF-Example protos).
Flowers 数据集的下载,解压,读取数据,创建两个 TFRecord 数据集:训练数据集和测试数据集.
每个数据集是由 TF-Example protocol buffers 构成,每个 TF-Example protocol buffer 包含一张图片和对应的标签.
该脚本大概需要耗时一分钟.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
import random
import sys
import tensorflow as tf
from datasets import dataset_utils
# URL of the Flowers dataset.
_DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'
# The number of images in the validation set.
_NUM_VALIDATION = 350
# Seed for repeatability.
_RANDOM_SEED = 0
# The number of shards per dataset split.
_NUM_SHARDS = 5
class ImageReader(object):
"""
用于 TensorFlow 图片编码的辅助类
"""
def __init__(self):
    # Initializes the function that decodes RGB JPEG data.
self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
def read_image_dims(self, sess, image_data):
image = self.decode_jpeg(sess, image_data)
return image.shape[0], image.shape[1]
def decode_jpeg(self, sess, image_data):
image = sess.run(self._decode_jpeg,
feed_dict={self._decode_jpeg_data: image_data})
assert len(image.shape) == 3
assert image.shape[2] == 3
return image
def _get_filenames_and_classes(dataset_dir):
"""
返回文件名和类别名列表.
Args:
dataset_dir: 包含多个图片子路径的路径.
class names. 每个图片子路径包含 PNG 或 JPG 编码的图片.
Returns:
图片文件列表,相对于 dataset_dir
;
图片子路经列表,表示类比名字.
"""
flower_root = os.path.join(dataset_dir, 'flower_photos')
directories = []
class_names = []
for filename in os.listdir(flower_root):
path = os.path.join(flower_root, filename)
if os.path.isdir(path):
directories.append(path)
class_names.append(filename)
photo_filenames = []
for directory in directories:
for filename in os.listdir(directory):
path = os.path.join(directory, filename)
photo_filenames.append(path)
return photo_filenames, sorted(class_names)
def _get_dataset_filename(dataset_dir, split_name, shard_id):
output_filename = 'flowers_%s_%05d-of-%05d.tfrecord' % (
split_name, shard_id, _NUM_SHARDS)
return os.path.join(dataset_dir, output_filename)
def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
"""
将给定文件名转换为 TFRecord 格式数据集.
Args:
split_name: 数据集的名字,train 或 validation.
filenames: png 或 jpg 图片的绝对路径列表.
class_names_to_ids: 类别名字(字符串strings) 到类别 ids(整数integers ) 映射的字典.
dataset_dir: 转换后的 TFRecord 数据集所保存的路径.
"""
assert split_name in ['train', 'validation']
num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))
with tf.Graph().as_default():
image_reader = ImageReader()
with tf.Session('') as sess:
for shard_id in range(_NUM_SHARDS):
output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
start_ndx = shard_id * num_per_shard
end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
for i in range(start_ndx, end_ndx):
            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                i+1, len(filenames), shard_id))
sys.stdout.flush()
            # Read the image file:
image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
height, width = image_reader.read_image_dims(sess, image_data)
class_name = os.path.basename(os.path.dirname(filenames[i]))
class_id = class_names_to_ids[class_name]
example = dataset_utils.image_to_tfexample(
image_data, b'jpg', height, width, class_id)
tfrecord_writer.write(example.SerializeToString())
  sys.stdout.write('\n')
sys.stdout.flush()
def _clean_up_temporary_files(dataset_dir):
"""
删除创建数据集时产生的临时文件.
Args:
dataset_dir: 临时文件的路径.
"""
filename = _DATA_URL.split('/')[-1]
filepath = os.path.join(dataset_dir, filename)
tf.gfile.Remove(filepath)
tmp_dir = os.path.join(dataset_dir, 'flower_photos')
tf.gfile.DeleteRecursively(tmp_dir)
def _dataset_exists(dataset_dir):
for split_name in ['train', 'validation']:
for shard_id in range(_NUM_SHARDS):
output_filename = _get_dataset_filename(
dataset_dir, split_name, shard_id)
if not tf.gfile.Exists(output_filename):
return False
return True
def run(dataset_dir):
"""
运行数据集下载和转换.
Args:
dataset_dir: 数据集所在的路径.
"""
if not tf.gfile.Exists(dataset_dir):
tf.gfile.MakeDirs(dataset_dir)
if _dataset_exists(dataset_dir):
print('Dataset files already exist. Exiting without re-creating them.')
return
  # If the Flowers archive has already been downloaded and unpacked, this step can be skipped.
  # dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)
photo_filenames, class_names = _get_filenames_and_classes(dataset_dir)
class_names_to_ids = dict(zip(class_names, range(len(class_names))))
  # Divide into train and validation sets:
random.seed(_RANDOM_SEED)
random.shuffle(photo_filenames)
training_filenames = photo_filenames[_NUM_VALIDATION:]
validation_filenames = photo_filenames[:_NUM_VALIDATION]
  # First, convert the training and validation sets.
_convert_dataset('train', training_filenames, class_names_to_ids, dataset_dir)
_convert_dataset('validation', validation_filenames, class_names_to_ids, dataset_dir)
  # Finally, write the labels file:
labels_to_class_names = dict(zip(range(len(class_names)), class_names))
dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
  # This automatically removes flower_photos.tgz and the flower_photos directory.
_clean_up_temporary_files(dataset_dir)
  print('\nFinished converting the Flowers dataset!')

if __name__ == '__main__':
dataset_dir = '/path/to/flower_photos/'
run(dataset_dir)
print('Done.')
Running the script prints per-image conversion progress for each shard and ends with 'Finished converting the Flowers dataset!'. Afterwards the dataset directory contains the flowers_train_*.tfrecord and flowers_validation_*.tfrecord shards plus a labels.txt file.
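To sanity-check the conversion, a record can be read back and its features inspected (a minimal sketch; the shard filename below is just an example under the dataset_dir used above):
import tensorflow as tf

tfrecord_path = '/path/to/flower_photos/flowers_validation_00000-of-00005.tfrecord'
for record in tf.python_io.tf_record_iterator(tfrecord_path):
    example = tf.train.Example.FromString(record)
    feature = example.features.feature
    print('format:', feature['image/format'].bytes_list.value[0])
    print('label :', feature['image/class/label'].int64_list.value[0])
    print('size  :', feature['image/height'].int64_list.value[0],
          'x', feature['image/width'].int64_list.value[0])
    break  # only inspect the first record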
<h2>2. Fine-tuning the InceptionV1 Model</h2>
The post TensorFlow - TF-Slim 使用总览 covers the related background. The dataset specification, flowers.py (imported by train.py below):
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
Provides data for the flowers dataset.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
slim = tf.contrib.slim
import dataset_utils
_FILE_PATTERN = 'flowers_%s_*.tfrecord'
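# 5 classes, 3,670 images in total (633 + 898 + 641 + 699 + 799); holding out 350 for
# validation leaves 3,320 for training.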
SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}
_NUM_CLASSES = 5
_ITEMS_TO_DESCRIPTIONS = {
'image': 'A color image of varying size.',
'label': 'A single integer between 0 and 4',
}
def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
"""
获取数据集元组,以读取 flowers 数据.
Gets a dataset tuple with instructions for reading flowers.
Args:
split_name: A train/validation split name.
    dataset_dir: The base directory of the dataset sources.
file_pattern: The file pattern to use when matching the dataset sources.
It is assumed that the pattern contains a '%s' string so that the split
name can be inserted.
reader: The TensorFlow reader type.
Returns:
    A Dataset namedtuple.
Raises:
    ValueError: if split_name is not a valid train/validation split.
"""
if split_name not in SPLITS_TO_SIZES:
raise ValueError('split name %s was not recognized.' % split_name)
if not file_pattern:
file_pattern = _FILE_PATTERN
file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
# Allowing None in the signature so that dataset_factory can use the default.
if reader is None:
reader = tf.TFRecordReader
keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
'image/class/label': tf.FixedLenFeature(
[], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
}
items_to_handlers = {
'image': slim.tfexample_decoder.Image(),
'label': slim.tfexample_decoder.Tensor('image/class/label'),
}
decoder = slim.tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers)
labels_to_names = None
if dataset_utils.has_labels(dataset_dir):
labels_to_names = dataset_utils.read_label_file(dataset_dir)
return slim.dataset.Dataset(
data_sources=file_pattern,
reader=reader,
decoder=decoder,
num_samples=SPLITS_TO_SIZES[split_name],
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
num_classes=_NUM_CLASSES,
labels_to_names=labels_to_names)
- train.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
import os
import flowers
from nets import inception
from preprocessing import inception_preprocessing
import tensorflow as tf
import tensorflow.contrib.slim as slim
image_size = inception.inception_v1.default_image_size
flowers_data_dir = '/path/to/flower/tfrecords'
checkpoints_dir = '/path/to/flower/checkpoints'
train_dir = '/path/to/flower/outputs'
def load_batch(dataset, batch_size=32, height=299, width=299, is_training=False):
"""
加载单个 bacth 的数据.
Args:
dataset: 待加载数据.
batch_size: batch 内图片数量.
height: 预处理后的每张图片的 height.
width: 预处理后的每张图片的 width.
is_training: 当前数据是否处于 training 还是 evaluating.
Returns:
images: [batch_size, height, width, 3] 大小的 Tensor, 预处理后的图片样本.
images_raw: [batch_size, height, width, 3] 大小的 Tensor, 用于可视化的图片样本.
labels: [batch_size] 大小的 Tensor, 其值范围为 [0,dataset.num_classes].
"""
data_provider = slim.dataset_data_provider.DatasetDataProvider(
dataset, common_queue_capacity=32, common_queue_min=8)
image_raw, label = data_provider.get(['image', 'label'])
    # Preprocess the image for Inception.
image = inception_preprocessing.preprocess_image(image_raw, height, width, is_training=is_training)
    # Keep a resized copy of the raw image for visualization.
image_raw = tf.expand_dims(image_raw, 0)
image_raw = tf.image.resize_images(image_raw, [height, width])
image_raw = tf.squeeze(image_raw)
    # Batch it up.
images, images_raw, labels = tf.train.batch(
[image, image_raw, label],batch_size=batch_size,
num_threads=1, capacity=2 * batch_size)
return images, images_raw, labels
def get_init_fn():
"""
训练热身函数.
权重参数初始化.
"""
checkpoint_exclude_scopes=["InceptionV1/Logits", "InceptionV1/AuxLogits"] #原输出层
# finetune 时更改原输出层,初始化权重时,不更新输出层的权重参数
exclusions = [scope.strip() for scope in checkpoint_exclude_scopes]
variables_to_restore = []
for var in slim.get_model_variables():
for exclusion in exclusions:
if var.op.name.startswith(exclusion):
break
else:
variables_to_restore.append(var)
return slim.assign_from_checkpoint_fn(
os.path.join(checkpoints_dir, 'inception_v1.ckpt'),
variables_to_restore)
with tf.Graph().as_default():
tf.logging.set_verbosity(tf.logging.INFO)
dataset = flowers.get_split('train', flowers_data_dir)
images, _, labels = load_batch(dataset, height=image_size, width=image_size)
    # Create the model, using the default arg scope to configure the batch norm parameters.
with slim.arg_scope(inception.inception_v1_arg_scope()):
logits, _ = inception.inception_v1(images, num_classes=dataset.num_classes, is_training=True)
    # Specify the loss function:
one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)
slim.losses.softmax_cross_entropy(logits, one_hot_labels)
total_loss = slim.losses.get_total_loss()
    # Create summaries to visualize training progress:
tf.summary.scalar('losses/Total Loss', total_loss)
    # Specify the optimizer and create the train op:
optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
train_op = slim.learning.create_train_op(total_loss, optimizer)
    # Run the training:
final_loss = slim.learning.train(train_op,
logdir=train_dir,
log_every_n_steps=10,
init_fn=get_init_fn(),
number_of_steps=3000,
save_summaries_secs=600,
save_interval_secs=1200)
print('Finished training. Last batch loss %f' % final_loss)
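While training runs, the loss summaries written to train_dir can be monitored with TensorBoard, e.g. tensorboard --logdir=/path/to/flower/outputs.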
- test.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.contrib import slim
from nets import inception
import flowers
from preprocessing import inception_preprocessing
image_size = inception.inception_v1.default_image_size
batch_size = 30
flowers_data_dir = '/path/to/flower/tfrecords'
train_dir = '/path/to/flower/outputs'
with tf.Graph().as_default():
tf.logging.set_verbosity(tf.logging.INFO)
dataset = flowers.get_split('validation', flowers_data_dir)
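    # load_batch is the helper defined in train.py above; it must be in scope (or imported) here.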
images, images_raw, labels = load_batch(dataset, height=image_size, width=image_size)
# Create the model, use the default arg scope to configure the batch norm parameters.
with slim.arg_scope(inception.inception_v1_arg_scope()):
logits, _ = inception.inception_v1(images, num_classes=dataset.num_classes, is_training=True)
probabilities = tf.nn.softmax(logits)
checkpoint_path = tf.train.latest_checkpoint(train_dir)
init_fn = slim.assign_from_checkpoint_fn(checkpoint_path,
slim.get_variables_to_restore())
with tf.Session() as sess:
with slim.queues.QueueRunners(sess):
sess.run(tf.initialize_local_variables())
init_fn(sess)
np_probabilities, np_images_raw, np_labels = sess.run([probabilities, images_raw, labels])
for i in range(batch_size):
image = np_images_raw[i, :, :, :]
true_label = np_labels[i]
predicted_label = np.argmax(np_probabilities[i, :])
predicted_name = dataset.labels_to_names[predicted_label]
true_name = dataset.labels_to_names[true_label]
plt.figure()
plt.imshow(image.astype(np.uint8))
plt.title('Ground Truth: [%s], Prediction [%s]' % (true_name, predicted_name))
plt.axis('off')
plt.show()
print('Done.')
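Beyond eyeballing individual predictions, a quantitative check can help. Below is a minimal sketch (not from the original post) that computes streaming accuracy over a few validation batches; it assumes flowers.py and the load_batch helper from train.py are in scope, and reuses the placeholder paths from above:
import tensorflow as tf
import tensorflow.contrib.slim as slim
from nets import inception
import flowers

image_size = inception.inception_v1.default_image_size
flowers_data_dir = '/path/to/flower/tfrecords'
train_dir = '/path/to/flower/outputs'
num_batches = 10  # hypothetical: how many validation batches to average over

with tf.Graph().as_default():
    dataset = flowers.get_split('validation', flowers_data_dir)
    # load_batch: the helper defined in train.py (assumed to be in scope here).
    images, _, labels = load_batch(dataset, height=image_size, width=image_size)

    with slim.arg_scope(inception.inception_v1_arg_scope()):
        logits, _ = inception.inception_v1(images, num_classes=dataset.num_classes, is_training=False)
    predictions = tf.argmax(logits, 1)

    # Streaming accuracy: update_op accumulates correct/total counts across batches.
    accuracy, update_op = tf.metrics.accuracy(labels, predictions)

    init_fn = slim.assign_from_checkpoint_fn(
        tf.train.latest_checkpoint(train_dir), slim.get_variables_to_restore())

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())  # the metric counters are local variables
        init_fn(sess)
        with slim.queues.QueueRunners(sess):
            for _ in range(num_batches):
                sess.run(update_op)
            print('Validation accuracy: %.3f' % sess.run(accuracy))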
<h2>3. Related</h2>
[1] - tensorflow之从文件中读取数据(适用场景:大规模数据集,亲测有效~)
[2] - tensorflowxun训练自己的数据集之从tfrecords读取数据
[3] - TensorFlow高效读取数据的方法