使用 TensorFlow 和 TF-Slim 时,对于图片数据集往往需要将数据集转换为 TFRecords 文件.
这里根据 TF-Slim 里的 flowers
的 TFRecords 创建,学习 TFRecords 的创建与读取.
以阿里天池竞赛中的服装属性识别的 coat_length_labels
数据集为例.
coat_length_labels = ['Invisible',
'High Waist Length',
'Regular Length',
'Long Length',
'Micro Length',
'Knee Length',
'Midi Length',
'Ankle&Floor Length']
1. 创建 TFRecords 文件
#!--*-- coding:utf-8 --*--
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
import math
from sklearn.model_selection import train_test_split
import tensorflow as tf
# Session configuration: per-process GPU memory fraction set to 0.2.
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction=0.2
# config.gpu_options.allow_growth = True
# Each dataset split is written out as _NUM_SHARDS tfrecords files.
_NUM_SHARDS = 5
def write_label_file(labels_to_class_names, dataset_dir, filename='labels.txt'):
    """Write the label -> class-name mapping to a text file.

    One line of the form ``<label>:<class_name>`` is written per entry.

    Args:
        labels_to_class_names: Mapping from integer label to class name.
        dataset_dir: Directory in which the label file is created.
        filename: Name of the label file inside ``dataset_dir``.
    """
    target_path = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(target_path, 'w') as label_file:
        for label_id, name in labels_to_class_names.items():
            label_file.write('%d:%s\n' % (label_id, name))
def int64_feature(values):
    """Wrap an int, or a list/tuple of ints, in an int64 ``tf.train.Feature``."""
    # A scalar is promoted to a one-element list; lists and tuples pass through.
    as_sequence = values if isinstance(values, (tuple, list)) else [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=as_sequence))
def bytes_feature(values):
    """Wrap a single bytes value in a bytes-list ``tf.train.Feature``."""
    payload = tf.train.BytesList(value=[values])
    return tf.train.Feature(bytes_list=payload)
def float_feature(values):
    """Wrap a float, or a list/tuple of floats, in a float ``tf.train.Feature``."""
    # A scalar is promoted to a one-element list; lists and tuples pass through.
    as_sequence = values if isinstance(values, (tuple, list)) else [values]
    return tf.train.Feature(float_list=tf.train.FloatList(value=as_sequence))
def image_to_tfexample(image_data, image_format, height, width, class_id):
    """Pack one encoded image and its metadata into a ``tf.train.Example``.

    Args:
        image_data: Encoded image bytes.
        image_format: Image format as bytes (e.g. ``b'jpg'``).
        height: Image height, stored as an int64 feature.
        width: Image width, stored as an int64 feature.
        class_id: Integer class label.

    Returns:
        A ``tf.train.Example`` with the keys ``image/encoded``,
        ``image/format``, ``image/class/label``, ``image/height`` and
        ``image/width``.
    """
    feature_map = {
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))
class ImageReader(object):
    """Helper that decodes JPEG bytes through a small TensorFlow graph."""

    def __init__(self):
        # Placeholder fed with raw JPEG bytes; the decode op requests 3 channels.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

    def read_image_dims(self, sess, image_data):
        """Return ``(height, width)`` of the image decoded from ``image_data``."""
        height, width = self.decode_jpeg(sess, image_data).shape[:2]
        return height, width

    def decode_jpeg(self, sess, image_data):
        """Decode ``image_data`` in ``sess`` and return the resulting array.

        The result is asserted to be rank 3 with 3 channels.
        """
        decoded = sess.run(self._decode_jpeg,
                           feed_dict={self._decode_jpeg_data: image_data})
        assert len(decoded.shape) == 3
        assert decoded.shape[2] == 3
        return decoded
def get_dataset_filename(tfrecords_dir, split_name, shard_id):
    """Return the path of one shard file of a dataset split.

    Args:
        tfrecords_dir: Directory holding the tfrecord files.
        split_name: Name of the split, embedded in the file name.
        shard_id: Index of the shard.

    Returns:
        ``tfrecords_dir`` joined with a name that encodes the split, the
        shard index and the total shard count ``_NUM_SHARDS``.
    """
    name_fields = (split_name, shard_id, _NUM_SHARDS)
    shard_name = 'coat_length_%s_%05d-of-%05d.tfrecord' % name_fields
    return os.path.join(tfrecords_dir, shard_name)
def convert_dataset_to_TFRecords(split_name, filenames, class_ids, tfrecords_dir):
    """Convert a list of JPEG images into ``_NUM_SHARDS`` TFRecord files.

    Args:
        split_name: Either ``'train'`` or ``'valid'``.
        filenames: Paths of the image files to convert.
        class_ids: Integer class id for each entry of ``filenames``.
        tfrecords_dir: Directory that receives the shard files. It is
            created if it does not exist yet.
    """
    assert split_name in ['train', 'valid']

    # Make sure the output directory exists before any writer is opened.
    if tfrecords_dir and not os.path.isdir(tfrecords_dir):
        os.makedirs(tfrecords_dir)

    num_images = len(filenames)
    num_per_shard = int(math.ceil(num_images / float(_NUM_SHARDS)))

    with tf.Graph().as_default():
        image_reader = ImageReader()
        with tf.Session('') as sess:
            for shard_id in range(_NUM_SHARDS):
                output_filename = get_dataset_filename(tfrecords_dir,
                                                       split_name,
                                                       shard_id)
                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id + 1) * num_per_shard, num_images)
                    for idx in range(start_ndx, end_ndx):
                        sys.stdout.write('\r>> Converting image %d/%d shard %d' %
                                         (idx + 1, num_images, shard_id))
                        sys.stdout.flush()
                        # Read the encoded image bytes; the context manager
                        # closes the file handle after each image.
                        with tf.gfile.FastGFile(filenames[idx], 'rb') as image_file:
                            image_data = image_file.read()
                        height, width = image_reader.read_image_dims(sess, image_data)
                        class_id = class_ids[idx]
                        example = image_to_tfexample(image_data=image_data,
                                                     image_format=b'jpg',
                                                     height=height,
                                                     width=width,
                                                     class_id=class_id)
                        tfrecord_writer.write(example.SerializeToString())
    # Terminate the '\r' progress line.
    sys.stdout.write('\n')
    sys.stdout.flush()
def _parse_annotations(lines, dataset_dir):
    """Turn annotation lines into parallel lists of image paths and class ids.

    Each line is expected to hold ``<image_name> <label_string>`` separated
    by a single space. Lines with fewer than two fields are skipped, as are
    entries whose image file does not exist under ``dataset_dir``.
    """
    filenames = []
    class_ids = []
    for line in lines:
        fields = line.split(' ')
        if len(fields) < 2:
            continue
        image_path = os.path.join(dataset_dir, fields[0])
        if not os.path.exists(image_path):
            continue
        filenames.append(image_path)
        # The class id is the position of the first 'y' in the label string.
        class_ids.append(fields[1].strip().index('y'))
    return filenames, class_ids


def main():
    """Split the annotation file 90/10 and write train/valid TFRecords."""
    print('[INFO] Converting TFRecords...')
    dataset_dir = '/path/to/coat_length_datas/'
    # Close the annotation file as soon as its lines have been read.
    with open(os.path.join(dataset_dir, 'coat_length.txt')) as annotation_file:
        datas = annotation_file.readlines()
    train_datas, valid_datas = train_test_split(datas, test_size=0.1, random_state=42)

    class_names = ['Invisible', 'High Waist Length', 'Regular Length', 'Long Length',
                   'Micro Length', 'Knee Length', 'Midi Length', 'Ankle&Floor Length']

    train_filenames, train_classids = _parse_annotations(train_datas, dataset_dir)
    valid_filenames, valid_classids = _parse_annotations(valid_datas, dataset_dir)

    tfrecords_dir = './datas/'
    convert_dataset_to_TFRecords('train', train_filenames, train_classids, tfrecords_dir)
    convert_dataset_to_TFRecords('valid', valid_filenames, valid_classids, tfrecords_dir)

    labels_to_class_names = dict(enumerate(class_names))
    write_label_file(labels_to_class_names, tfrecords_dir)
    print('[INFO] Finished converting!')


if __name__ == '__main__':
    main()
2. 读取 TFRecords 文件
#!--*-- coding:utf-8 --*--
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
import matplotlib.pyplot as plt
import tensorflow as tf
# Session configuration used by read_TFRecords below:
# per-process GPU memory fraction set to 0.2.
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction=0.2
# config.gpu_options.allow_growth = True
def read_TFRecords(tfrecords_list):
    """Read ten examples from the given TFRecord files and display them.

    Args:
        tfrecords_list: List of TFRecord file paths to read from.
    """
    print('[INFO] Reading TFRecords...')
    reader = tf.TFRecordReader()
    queue = tf.train.string_input_producer(tfrecords_list)
    _, serialized_example = reader.read(queue)
    features = tf.parse_single_example(
        serialized_example,
        features={'image/encoded': tf.FixedLenFeature([], tf.string),
                  'image/height': tf.FixedLenFeature([], tf.int64),
                  'image/width': tf.FixedLenFeature([], tf.int64),
                  'image/class/label': tf.FixedLenFeature([], tf.int64)
                  }
    )
    image_raw = tf.image.decode_image(features['image/encoded'])
    label = tf.cast(features['image/class/label'], tf.int32)
    # Height and width are parsed as int64 already, so no cast is needed.
    height = features['image/height']
    width = features['image/width']

    # The context manager closes the session; the finally clause stops and
    # joins the queue-runner threads even if reading or plotting fails.
    with tf.Session(config=config) as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for idx in range(10):
                img, l, h, w = sess.run([image_raw, label, height, width])
                print(img.shape, type(img))
                plt.imshow(img)
                plt.show()
        finally:
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    # NOTE(review): the conversion script in this file writes its records to
    # './datas/'; confirm this relative path points at the same directory.
    tfrecords_list = ['../datas/coat_length_train_00000-of-00005.tfrecord']
    read_TFRecords(tfrecords_list)
3. 相关函数说明
3.1 tf.placeholder
placeholder
是占位符的意思.
tf.placeholder(dtype, shape=None, name=None)
# dtype - 数据类型,如 tf.float32,tf.float64 等.
# shape - 数据形状,默认是 None,表示不限定形状(可以喂入任意形状的数据);也可以指定多维形状,比如 [None, 3] 表示列数为 3,行数待定.
# name - 名称.
# 函数返回值为 Tensor 类型.
用法:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import tensorflow as tf
# Session configuration: per-process GPU memory fraction set to 0.2.
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.2
# config.gpu_options.allow_growth = True

# Declare two placeholders; they hold no concrete value until fed.
variable1 = tf.placeholder(tf.float32)
variable2 = tf.placeholder(tf.float32)

# Define the multiplication op on the two placeholders.
output = tf.multiply(variable1, variable2)

# The multiplication is only executed inside a session, and concrete values
# for both placeholders must be supplied through feed_dict.
with tf.Session(config=config) as sess:
    feed_values = {variable1: [12.], variable2: [19.]}
    product = sess.run(output, feed_dict=feed_values)
    print('[Output:] ', product)
TensorFlow 采用计算流图的设计理念,代码编程时,其首先创建静态图Graph,但并不会立即生效. 然后,启动一个 session,才是真正的运行代码.
而 tf.placeholder()
函数的作用是,在构建 graph 模型时提供占位,但并未把待输入的数据送入模型中,只是分配必要的内存等. 在建立 session 后,通过 sess.run() 的 feed_dict
参数将数据送入占位符中.
3.2 tf.image.decode_jpeg
TensorFlow 提供了 jpeg
和 png
格式图像的编码和解码函数,进行图像的读取,如 tf.gfile.FastGFile(jpgfile,'r').read()
(tf.gfile.FastGFile(jpgfile,'rb').read()
- Python3),但读取的结果是最原始的图像,其为一个字符串,并不是解码后的图像的像素值.
TensorFlow 提供的解码函数有两个:tf.image.decode_jpeg
和 tf.image.decode_png
,分别解码 jpeg
和 png
格式的图像,得到图像的像素值,可以进行图像显示等.
import matplotlib.pyplot as plt
import tensorflow as tf
# `jpgfile` and `pngfile` are not defined in this snippet; set them to the
# paths of a JPEG and a PNG image before running.
# Read the encoded files in binary mode ('rb'): they contain compressed image
# bytes, and text mode fails to decode them on Python 3.
with tf.gfile.FastGFile(jpgfile, 'rb') as f:
    image_raw_data_jpg = f.read()
with tf.gfile.FastGFile(pngfile, 'rb') as f:
    image_raw_data_png = f.read()

with tf.Session() as sess:
    # Decoding yields tensors.
    img_data_jpg = tf.image.decode_jpeg(image_raw_data_jpg)
    img_data_jpg = tf.image.convert_image_dtype(img_data_jpg, dtype=tf.uint8)
    img_data_png = tf.image.decode_png(image_raw_data_png)
    img_data_png = tf.image.convert_image_dtype(img_data_png, dtype=tf.uint8)

    # Evaluate each tensor once and reuse the arrays for printing and display.
    jpg_pixels = img_data_jpg.eval()
    png_pixels = img_data_png.eval()

    # Print the decoded 3-D arrays.
    print(jpg_pixels)
    print(png_pixels)

    # Show both images side by side.
    plt.subplot(1, 2, 1)
    plt.imshow(jpg_pixels)
    plt.subplot(1, 2, 2)
    plt.imshow(png_pixels)
    plt.show()

    # Re-encode the decoded image as JPEG and save it to disk.
    encoded_image = tf.image.encode_jpeg(img_data_jpg)
    with tf.gfile.GFile('/path/to/save_jpg', 'wb') as f:
        f.write(encoded_image.eval())
3.3 tf.decode_raw
tf.decode_raw
也记作 tf.io.decode_raw()
.
tf.io.decode_raw(bytes,
out_type,
little_endian=True,
name=None
)
tf.decode_raw
函数用于将 to_bytes
函数(例如 NumPy 数组的 tobytes() / tostring() 方法)所编码的字符串类型变量重新解码,常用于数据集 TFRecords 文件中. 因为在创建 TFRecords 文件时,一般是以 to_bytes
形式保存原图片数据,即字符串格式保存.
注:需要保证数据格式与解析格式的一致.
如果原图像数据是由
tf.float64
类型再进行to_bytes
写入,则tf.decode_raw
解码时则也需要使用tf.float64
数据类型.如果不一致,会出现
Input to reshape is a tensor with xxx values, but the requested shape has xxx
的类似错误.
3.4 tf.cast
TensorFlow 数据类型转换函数,不会改变张量的形状 shape;但元素值可能因类型转换而变化(例如浮点数转为整型时小数部分会被截断).
# `features`, `heights` and `widths` are not defined in this snippet; they
# presumably come from the surrounding parsing code — TODO confirm.
# Decode the stored bytes into a flat uint8 tensor.
image_raw = tf.decode_raw(features['image/encoded'], tf.uint8)
# Restore the height x width x 3 layout.
image = tf.reshape(image_raw, [heights, widths, 3])
# Cast the reshaped tensor (`image`, not the undefined name `images`) to float32.
image = tf.cast(image, tf.float32)
4. Github 上的实现[转]
# -*-coding: utf-8 -*-
"""
@File : create_tfrecord.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date : 2018-07-27 17:19:54
@desc : 将图片数据保存为单个tfrecord文件
"""
import tensorflow as tf
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import random
from PIL import Image
def _int64_feature(value):
    """Wrap a single integer in an int64 ``tf.train.Feature``."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
# Build a bytes (string) feature.
def _bytes_feature(value):
    """Wrap a single bytes value in a bytes-list ``tf.train.Feature``."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)
# Build a float feature.
def float_list_feature(value):
    """Wrap a sequence of floats in a float-list ``tf.train.Feature``.

    ``value`` is passed through unchanged, so it must already be a sequence.
    """
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)
def get_example_nums(tf_records_filenames):
    """Count the examples stored in a TFRecord file.

    Args:
        tf_records_filenames: Path of the TFRecord file.

    Returns:
        The number of records yielded by iterating over the file.
    """
    record_iter = tf.python_io.tf_record_iterator(tf_records_filenames)
    return sum(1 for _ in record_iter)
def show_image(title,image):
    """Display an image in a matplotlib window.

    Args:
        title: Title shown above the image.
        image: Image data passed to ``plt.imshow``.
    """
    plt.imshow(image)
    # Axes stay visible; pass 'off' instead to hide them.
    plt.axis('on')
    plt.title(title)
    plt.show()
def load_labels_file(filename,labels_num=1,shuffle=False):
    """Load an annotation text file of image paths and integer labels.

    Each non-blank line holds whitespace-separated fields: the image path
    followed by one or more labels, e.g. ``test_image/1.jpg 0 2``.

    Args:
        filename: Path of the annotation text file.
        labels_num: Number of label columns to read per line.
        shuffle: Whether to shuffle the lines before parsing.

    Returns:
        A tuple ``(images, labels)``. ``images`` is a list of path strings
        and ``labels`` is a list of ``labels_num``-long lists of ints.

    Raises:
        IndexError: If a non-blank line has fewer than ``labels_num`` labels.
        ValueError: If a label field is not an integer.
    """
    with open(filename) as f:
        lines_list = f.readlines()
    if shuffle:
        random.shuffle(lines_list)

    images = []
    labels = []
    for raw_line in lines_list:
        # split() with no argument tolerates repeated spaces, tabs and the
        # trailing newline; a blank line produces no fields and is skipped.
        fields = raw_line.split()
        if not fields:
            continue
        label = [int(fields[i + 1]) for i in range(labels_num)]
        images.append(fields[0])
        labels.append(label)
    return images, labels
def read_image(filename, resize_height, resize_width,normalization=False):
    """Read an image file as an RGB array, optionally resized and normalized.

    Args:
        filename: Path of the image file.
        resize_height: Target height; resizing happens only when both
            ``resize_height`` and ``resize_width`` are greater than 0.
        resize_width: Target width.
        normalization: If True, divide the pixel values by 255.0.

    Returns:
        The RGB image array; the pixel values are divided by 255.0 when
        ``normalization`` is True.

    Raises:
        IOError: If OpenCV cannot read ``filename``.
    """
    bgr_image = cv2.imread(filename)
    if bgr_image is None:
        # cv2.imread reports failure by returning None instead of raising,
        # which would otherwise surface as an AttributeError below.
        raise IOError('Failed to read image: %s' % filename)
    if len(bgr_image.shape) == 2:
        # A 2-D result is a grayscale image; expand it to three channels.
        print("Warning:gray image",filename)
        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)
    # OpenCV loads images as BGR; convert to RGB.
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    if resize_height > 0 and resize_width > 0:
        # cv2.resize takes the target size as (width, height).
        rgb_image = cv2.resize(rgb_image, (resize_width, resize_height))
    rgb_image = np.asanyarray(rgb_image)
    if normalization:
        # Keep the float literal: the original author warns against
        # writing `rgb_image / 255` here.
        rgb_image = rgb_image / 255.0
    return rgb_image
def get_batch_images(images,labels,batch_size,labels_nums,
                     one_hot=False,shuffle=False,num_threads=1):
    """Group single image/label tensors into batches.

    Args:
        images: Image tensor for one example.
        labels: Label tensor for one example.
        batch_size: Number of examples per batch.
        labels_nums: Number of label classes, used as the one-hot depth.
        one_hot: Whether to convert the label batch to one-hot form.
        shuffle: Whether to shuffle examples (typically True for training
            and False for validation).
        num_threads: Number of threads used to enqueue examples.

    Returns:
        A tuple ``(images_batch, labels_batch)``.
    """
    min_after_dequeue = 200
    # The capacity is kept larger than min_after_dequeue.
    capacity = min_after_dequeue + 3 * batch_size
    shared_kwargs = dict(batch_size=batch_size,
                         capacity=capacity,
                         num_threads=num_threads)
    if shuffle:
        images_batch, labels_batch = tf.train.shuffle_batch(
            [images, labels],
            min_after_dequeue=min_after_dequeue,
            **shared_kwargs)
    else:
        images_batch, labels_batch = tf.train.batch(
            [images, labels],
            **shared_kwargs)
    if one_hot:
        labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)
    return images_batch, labels_batch
def read_records(filename,resize_height, resize_width,type=None):
    """Build the ops that read one image/label pair from a TFRecord file.

    The image is decoded from the ``image_raw`` feature as uint8 and reshaped
    to ``[resize_height, resize_width, 3]``, so both sizes must match the
    shape the images were stored with.

    Args:
        filename: Path of the TFRecord file.
        resize_height: Height of the stored images.
        resize_width: Width of the stored images.
        type: Selects how the image tensor is returned:
            None: cast to float32, values unscaled.
            'normalization': cast to float32 and scaled by 1/255.
            'centralization': scaled by 1/255, then 0.5 is subtracted.
            Any other value leaves the image tensor as uint8.

    Returns:
        A tuple ``(tf_image, tf_label)``; the label is cast to int32.
    """
    # Filename queue with no limit on the number of reads.
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    # Pull one serialized example from the queue.
    _, serialized_example = reader.read(filename_queue)

    # height, width and depth are parsed but only image and label are returned.
    feature_spec = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'depth': tf.FixedLenFeature([], tf.int64),
        'label': tf.FixedLenFeature([], tf.int64)
    }
    features = tf.parse_single_example(serialized_example,
                                       features=feature_spec)

    # Raw image bytes -> flat uint8 tensor.
    tf_image = tf.decode_raw(features['image_raw'], tf.uint8)
    tf_label = tf.cast(features['label'], tf.int32)
    # Restore the image layout from the flat tensor.
    tf_image = tf.reshape(tf_image, [resize_height, resize_width, 3])

    if type is None:
        return tf.cast(tf_image, tf.float32), tf_label
    if type == 'normalization':
        return tf.cast(tf_image, tf.float32) * (1. / 255.0), tf_label
    if type == 'centralization':
        # Scale by 1/255, then subtract an assumed mean of 0.5.
        return tf.cast(tf_image, tf.float32) * (1. / 255) - 0.5, tf_label
    # Unrecognized type: the uint8 image is returned unchanged.
    return tf_image, tf_label
def create_records(image_dir,file, output_record_dir,
                   resize_height, resize_width,shuffle,log=5):
    """Write images and their first label into a single TFRecord file.

    Each image is read with ``read_image`` (uint8 by default) and stored as
    raw bytes together with its height, width, depth and label.

    Args:
        image_dir: Directory containing the original images.
        file: Annotation txt file; ``image_dir`` joined with each listed
            name forms the image path.
        output_record_dir: Path of the TFRecord file to write. Its parent
            directory is created if it does not exist.
        resize_height: Target height passed to ``read_image``.
        resize_width: Target width passed to ``read_image``; no resize is
            done when either size is 0.
        shuffle: Whether to shuffle the annotation lines.
        log: Progress is printed every ``log`` images and for the last one.
            A value <= 0 prints only the last one.
    """
    # Only a single label column is loaded.
    images_list, labels_list = load_labels_file(file, 1, shuffle)

    output_parent = os.path.dirname(output_record_dir)
    if output_parent and not os.path.isdir(output_parent):
        os.makedirs(output_parent)

    last_index = len(images_list) - 1
    # The context manager closes the writer even if an image fails midway.
    with tf.python_io.TFRecordWriter(output_record_dir) as writer:
        for i, (image_name, labels) in enumerate(zip(images_list, labels_list)):
            image_path = os.path.join(image_dir, image_name)
            if not os.path.exists(image_path):
                print('Err:no image',image_path)
                continue
            image = read_image(image_path, resize_height, resize_width)
            # tobytes() replaces the deprecated ndarray.tostring() alias.
            image_raw = image.tobytes()
            # Guard against log <= 0, which would divide by zero.
            if (log > 0 and i % log == 0) or i == last_index:
                print('------------processing:%d-th------------' % (i))
                print('current image_path=%s' % (image_path),
                      'shape:{}'.format(image.shape),
                      'labels:{}'.format(labels))
            # Only one label is saved; add more 'label' features for multi-label data.
            label = labels[0]
            example = tf.train.Example(features=tf.train.Features(feature={
                'image_raw': _bytes_feature(image_raw),
                'height': _int64_feature(image.shape[0]),
                'width': _int64_feature(image.shape[1]),
                'depth': _int64_feature(image.shape[2]),
                'label': _int64_feature(label)
            }))
            writer.write(example.SerializeToString())
def disp_records(record_file,resize_height, resize_width,show_nums=4):
    """Parse a TFRecord file and display its first ``show_nums`` images.

    Mainly useful to check that a record file was generated correctly.

    Args:
        record_file: Path of the TFRecord file.
        resize_height: Height the images were stored with.
        resize_width: Width the images were stored with.
        show_nums: Number of images to display.
    """
    tf_image, tf_label = read_records(record_file,
                                      resize_height,
                                      resize_width,
                                      type='normalization')
    # Same initializer as batch_test; tf.initialize_all_variables is deprecated.
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(show_nums):
                # Fetch one image and its label from the session.
                image, label = sess.run([tf_image, tf_label])
                print('shape:{},tpye:{},labels:{}'.format(
                    image.shape, image.dtype, label))
                show_image("image:%d"%(label),image)
        finally:
            # Stop and join the queue-runner threads even on failure.
            coord.request_stop()
            coord.join(threads)
def batch_test(record_file,resize_height, resize_width):
    """Read a TFRecord file in batches and show the first image of each batch.

    Four batches of four examples are fetched. The batched tensors are what
    would normally be fed to a network.

    Args:
        record_file: Path of the TFRecord file.
        resize_height: Height the images were stored with.
        resize_width: Width the images were stored with.
    """
    tf_image, tf_label = read_records(record_file,
                                      resize_height,
                                      resize_width,
                                      type='normalization')
    image_batch, label_batch = get_batch_images(tf_image,
                                                tf_label,
                                                batch_size=4,
                                                labels_nums=5,
                                                one_hot=False,
                                                shuffle=False)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        # Pass the session explicitly, as disp_records does.
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(4):
                # Fetch one batch of images and labels.
                images, labels = sess.run([image_batch, label_batch])
                # Only the first image of each batch is displayed.
                show_image("image", images[0, :, :, :])
                print('shape:{},tpye:{},labels:{}'.format(
                    images.shape, images.dtype, labels))
        finally:
            # Stop and join all queue-runner threads, even on failure.
            coord.request_stop()
            coord.join(threads)
if __name__ == '__main__':
    # Height and width every stored image is resized to.
    resize_height = 224
    resize_width = 224
    shuffle = True
    log = 5

    # Generate the training record file and report its example count.
    image_dir = 'dataset/train'
    train_labels = 'dataset/train.txt'  # annotation file for the train images
    train_record_output = 'dataset/record/train.tfrecords'
    create_records(image_dir=image_dir,
                   file=train_labels,
                   output_record_dir=train_record_output,
                   resize_height=resize_height,
                   resize_width=resize_width,
                   shuffle=shuffle,
                   log=log)
    train_nums = get_example_nums(train_record_output)
    print("save train example nums={}".format(train_nums))

    # Generate the validation record file and report its example count.
    image_dir = 'dataset/val'
    val_labels = 'dataset/val.txt'  # annotation file for the val images
    val_record_output = 'dataset/record/val.tfrecords'
    create_records(image_dir=image_dir,
                   file=val_labels,
                   output_record_dir=val_record_output,
                   resize_height=resize_height,
                   resize_width=resize_width,
                   shuffle=shuffle,
                   log=log)
    val_nums = get_example_nums(val_record_output)
    print("save val example nums={}".format(val_nums))

    # Exercise the reading path on the training records.
    # disp_records(train_record_output,resize_height, resize_width)
    batch_test(train_record_output, resize_height, resize_width)