使用 TensorFlow 和 TF-Slim 时,对于图片数据集往往需要将数据集转换为 TFRecords 文件.

这里参考 TF-Slim 中 flowers 数据集的 TFRecords 创建代码,学习 TFRecords 文件的创建与读取.

以阿里天池竞赛中的服装属性识别的 coat_length_labels 数据集为例.

# Class names indexed by class id (main() builds labels via zip(range(...), names)).
coat_length_labels = [
    'Invisible',            # 0
    'High Waist Length',    # 1
    'Regular Length',       # 2
    'Long Length',          # 3
    'Micro Length',         # 4
    'Knee Length',          # 5
    'Midi Length',          # 6
    'Ankle&Floor Length',   # 7
]

1. 创建 TFRecords 文件

#!--*-- coding:utf-8 --*--
"""Convert the coat_length image list into sharded TFRecord files (TensorFlow 1.x APIs)."""
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
import math

from sklearn.model_selection import train_test_split
import tensorflow as tf

# Session options: limit this process to 20% of the visible GPU's memory.
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.2
# config.gpu_options.allow_growth = True

# Each split ('train' / 'valid') is written as _NUM_SHARDS tfrecord files.
_NUM_SHARDS = 5


def write_label_file(labels_to_class_names, dataset_dir, filename='labels.txt'):
    """Write one `<label>:<class_name>` line per entry to dataset_dir/filename."""
    labels_filename = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(labels_filename, 'w') as f:
        for label in labels_to_class_names:
            class_name = labels_to_class_names[label]
            f.write('%d:%s\n' % (label, class_name))


def int64_feature(values):
    """Wrap an int (or a tuple/list of ints) in a tf.train.Feature."""
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def bytes_feature(values):
    """Wrap a single bytes value in a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def float_feature(values):
    """Wrap a float (or a tuple/list of floats) in a tf.train.Feature."""
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def image_to_tfexample(image_data, image_format, height, width, class_id):
    """Build a tf.train.Example holding the encoded image bytes and its metadata."""
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
    }))


class ImageReader(object):
    """Decodes JPEG bytes to an RGB array using a small reusable TF graph."""

    def __init__(self):
        # Build the decode op once; each call only feeds new bytes through it.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

    def read_image_dims(self, sess, image_data):
        """Return (height, width) of the JPEG-encoded image_data."""
        image = self.decode_jpeg(sess, image_data)
        return image.shape[0], image.shape[1]

    def decode_jpeg(self, sess, image_data):
        """Decode JPEG bytes to an (H, W, 3) array.

        Raises:
            ValueError: if the decoded array is not 3-D with 3 channels.
        """
        image = sess.run(self._decode_jpeg,
                         feed_dict={self._decode_jpeg_data: image_data})
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if image.ndim != 3 or image.shape[2] != 3:
            raise ValueError('Expected an RGB image, got shape %s' % (image.shape,))
        return image


def get_dataset_filename(tfrecords_dir, split_name, shard_id):
    """Return the path of shard `shard_id` for the split `split_name`."""
    output_filename = 'coat_length_%s_%05d-of-%05d.tfrecord' % (
        split_name, shard_id, _NUM_SHARDS)
    return os.path.join(tfrecords_dir, output_filename)


def convert_dataset_to_TFRecords(split_name, filenames, class_ids, tfrecords_dir):
    """Write (filenames[i], class_ids[i]) pairs into _NUM_SHARDS tfrecord shards.

    Args:
        split_name: 'train' or 'valid'; used in the shard file names.
        filenames: paths of JPEG image files.
        class_ids: integer class id for each entry of filenames.
        tfrecords_dir: output directory; created if it does not exist.

    Raises:
        ValueError: if split_name is not 'train' or 'valid'.
    """
    if split_name not in ('train', 'valid'):
        raise ValueError('split_name must be "train" or "valid", got %r' % (split_name,))
    if not tf.gfile.Exists(tfrecords_dir):
        tf.gfile.MakeDirs(tfrecords_dir)

    num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))

    with tf.Graph().as_default():
        image_reader = ImageReader()
        # Pass `config` so the GPU memory cap defined above actually takes effect.
        with tf.Session(config=config) as sess:
            for shard_id in range(_NUM_SHARDS):
                output_filename = get_dataset_filename(tfrecords_dir, split_name, shard_id)
                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))
                    for idx in range(start_ndx, end_ndx):
                        sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                            idx + 1, len(filenames), shard_id))
                        sys.stdout.flush()
                        # Read the raw (still JPEG-encoded) bytes; `with` closes the file.
                        with tf.gfile.GFile(filenames[idx], 'rb') as f:
                            image_data = f.read()
                        height, width = image_reader.read_image_dims(sess, image_data)
                        example = image_to_tfexample(
                            image_data=image_data, image_format=b'jpg',
                            height=height, width=width, class_id=class_ids[idx])
                        tfrecord_writer.write(example.SerializeToString())

    sys.stdout.write('\n')
    sys.stdout.flush()


def _parse_split(lines, dataset_dir):
    """Turn `<image_name> <label_string>` lines into (image_paths, class_ids).

    The class id is the index of the first 'y' in the label string
    (presumably a one-hot 'n'/'y' string such as 'nnynnnnn' -> 2; confirm
    against the data). Lines without both fields, or whose image file does
    not exist under dataset_dir, are skipped.
    """
    filenames = []
    class_ids = []
    for line in lines:
        fields = line.split()
        if len(fields) < 2:
            continue
        image_path = os.path.join(dataset_dir, fields[0])
        if os.path.exists(image_path):
            filenames.append(image_path)
            class_ids.append(fields[1].index('y'))
    return filenames, class_ids


def main():
    """Split the image list 90/10 into train/valid and write both as TFRecords."""
    print('[INFO] Converting TFRecords...')
    dataset_dir = '/path/to/coat_length_datas/'
    with open(os.path.join(dataset_dir, 'coat_length.txt')) as f:
        datas = f.readlines()
    train_datas, valid_datas = train_test_split(datas, test_size=0.1, random_state=42)

    class_names = ['Invisible', 'High Waist Length', 'Regular Length', 'Long Length',
                   'Micro Length', 'Knee Length', 'Midi Length', 'Ankle&Floor Length']

    train_filenames, train_classids = _parse_split(train_datas, dataset_dir)
    valid_filenames, valid_classids = _parse_split(valid_datas, dataset_dir)

    tfrecords_dir = './datas/'
    convert_dataset_to_TFRecords('train', train_filenames, train_classids, tfrecords_dir)
    convert_dataset_to_TFRecords('valid', valid_filenames, valid_classids, tfrecords_dir)

    labels_to_class_names = dict(zip(range(len(class_names)), class_names))
    write_label_file(labels_to_class_names, tfrecords_dir)
    print('[INFO] Finished converting!')


if __name__ == '__main__':
    main()

2. 读取 TFRecords 文件

#!--*-- coding:utf-8 --*--
"""Read back and display examples from the coat_length TFRecord shards (TensorFlow 1.x APIs)."""
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
import matplotlib.pyplot as plt
import tensorflow as tf

# Session options: limit this process to 20% of the visible GPU's memory.
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.2
# config.gpu_options.allow_growth = True


def read_TFRecords(tfrecords_list, num_examples=10):
    """Decode and display the first `num_examples` examples from tfrecords_list.

    Uses the queue-based TF 1.x input pipeline. The feature keys must match
    those written by the conversion script ('image/encoded', 'image/height',
    'image/width', 'image/class/label').

    Args:
        tfrecords_list: list of tfrecord file paths.
        num_examples: number of examples to decode and show (default 10).
    """
    print('[INFO] Reading TFRecords...')
    reader = tf.TFRecordReader()
    queue = tf.train.string_input_producer(tfrecords_list)
    _, serialized_example = reader.read(queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/class/label': tf.FixedLenFeature([], tf.int64),
        })
    # The conversion script stores the JPEG file bytes, so decode the image
    # (decode_image) rather than reinterpreting the bytes with decode_raw.
    image_raw = tf.image.decode_image(features['image/encoded'])
    label = tf.cast(features['image/class/label'], tf.int32)
    height = features['image/height']
    width = features['image/width']

    with tf.Session(config=config) as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(num_examples):
                img, l, h, w = sess.run([image_raw, label, height, width])
                print(img.shape, type(img))
                plt.imshow(img)
                plt.show()
        finally:
            # Stop and join the queue-runner threads before the session closes.
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    tfrecords_list = ['../datas/coat_length_train_00000-of-00005.tfrecord']
    read_TFRecords(tfrecords_list)

3. 相关函数说明

3.1 tf.placeholder

placeholder 是占位符的意思.

tf.placeholder(dtype, shape=None, name=None)
# dtype - element type of the tensor to be fed, e.g. tf.float32, tf.float64.
# shape - shape of the tensor to be fed. Defaults to None, which leaves the
#         shape unconstrained (a value of any shape can be fed). A partial
#         shape such as [None, 3] fixes 3 columns and leaves the number of
#         rows open.
# name  - optional name for the op.
# Returns a Tensor whose value must be supplied through feed_dict when it is run.

用法:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import tensorflow as tf

# Cap this process at 20% of GPU memory.
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.2
# config.gpu_options.allow_growth = True

# Two graph inputs; they hold no value until one is fed at run time.
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

# Element-wise product node; building it computes nothing yet.
product = tf.multiply(x, y)

# Concrete values for the placeholders, supplied when the graph runs.
inputs = {x: [12.], y: [19.]}

with tf.Session(config=config) as sess:
    result = sess.run(product, feed_dict=inputs)
    print('[Output:] ', result)

TensorFlow (1.x) 采用计算图的设计理念:编写代码时首先构建静态计算图 (Graph),此时并不会立即执行计算;之后启动一个 Session,才会真正运行图中的运算.

tf.placeholder() 函数的作用是在构建 graph 时为输入数据预留位置,只声明数据类型和形状,并不包含实际数据. 建立 session 后,在运行时通过 sess.run() 的 feed_dict 参数将具体数据送入占位符中.

3.2 tf.image.decode_jpeg

TensorFlow 提供了 jpeg 和 png 格式图像的编码和解码函数. 图像文件可以用 tf.gfile.FastGFile(jpgfile, 'rb').read() 读取 (Python 3 中必须使用 'rb' 二进制模式,Python 2 中也可用 'r'),但读取的结果是原始的编码数据,即一个字节字符串,并不是解码后的图像像素值.

TensorFlow 提供的解码函数有两个:tf.image.decode_jpeg 和 tf.image.decode_png,分别解码 jpeg 和 png 格式的图像,得到图像的像素值,之后可以进行图像显示等操作.

import matplotlib.pyplot as plt
import tensorflow as tf

# Placeholder paths -- replace with real image files.
jpgfile = '/path/to/image.jpg'
pngfile = '/path/to/image.png'

# Read the raw, still-encoded bytes. Binary mode ('rb') is required on
# Python 3 because the decode ops expect bytes, not str; `with` closes the files.
with tf.gfile.FastGFile(jpgfile, 'rb') as f:
    image_raw_data_jpg = f.read()
with tf.gfile.FastGFile(pngfile, 'rb') as f:
    image_raw_data_png = f.read()

with tf.Session() as sess:
    # Decoding yields uint8 image tensors.
    img_data_jpg = tf.image.decode_jpeg(image_raw_data_jpg)
    img_data_jpg = tf.image.convert_image_dtype(img_data_jpg, dtype=tf.uint8)
    img_data_png = tf.image.decode_png(image_raw_data_png)
    img_data_png = tf.image.convert_image_dtype(img_data_png, dtype=tf.uint8)

    # Evaluate once and reuse the numpy arrays instead of re-running the graph.
    jpg_pixels, png_pixels = sess.run([img_data_jpg, img_data_png])

    # Print the decoded 3-D pixel arrays.
    print(jpg_pixels)
    print(png_pixels)

    # Show both images side by side.
    plt.subplot(1, 2, 1)
    plt.imshow(jpg_pixels)
    plt.subplot(1, 2, 2)
    plt.imshow(png_pixels)
    plt.show()

    # Re-encode the decoded JPEG pixels as JPEG and save them to disk.
    encoded_image = tf.image.encode_jpeg(img_data_jpg)
    with tf.gfile.GFile('/path/to/save_jpg', 'wb') as f:
        f.write(sess.run(encoded_image))

3.3 tf.decode_raw

tf.decode_raw 也可写作 tf.io.decode_raw() (较新版本中推荐使用后者).

# Reinterpret the bytes of each string element as a flat vector of
# `out_type` values; `little_endian` gives the byte order for multi-byte types.
tf.io.decode_raw(bytes, out_type, little_endian=True, name=None )

tf.decode_raw 函数用于将 numpy 数组经 tobytes() (旧版写法为 tostring()) 得到的字节字符串重新解码为数值向量,常用于读取 TFRecords 文件. 因为在创建 TFRecords 文件时,一般是用 tobytes() 将原图片的像素数据保存为字节字符串.

注:需要保证数据格式与解析格式的一致.

如果原图像数据是 float64 类型并经 tobytes() 写入,则 tf.decode_raw 解码时也需要使用 tf.float64 数据类型.

如果不一致,解码得到的元素个数会与预期不符,在后续 reshape 时会出现类似 Input to reshape is a tensor with xxx values, but the requested shape has xxx 的错误.

3.4 tf.cast

TensorFlow 数据类型转换函数,返回与输入形状 (shape) 相同、数据类型为目标类型的新张量. 像 uint8 转 float32 这样的转换不会改变元素的数值,但浮点数转整数等转换会截断小数部分.

# Assumes `features` comes from tf.parse_single_example and that
# 'image/encoded' stores raw uint8 pixel bytes (e.g. ndarray.tobytes()),
# not a JPEG/PNG-compressed file -- decode_raw does no image decoding.
heights = tf.cast(features['image/height'], tf.int32)
widths = tf.cast(features['image/width'], tf.int32)
image_raw = tf.decode_raw(features['image/encoded'], tf.uint8)
# decode_raw returns a flat vector; restore the (height, width, 3) layout.
image = tf.reshape(image_raw, [heights, widths, 3])
# uint8 -> float32 keeps the pixel values (still in [0, 255]).
image = tf.cast(image, tf.float32)

4. GitHub 上的实现[转]

PanJinquan - create_tf_record.py

# -*-coding: utf-8 -*-
"""
    @File   : create_tfrecord.py
    @Author : panjq
    @E-mail : pan_jinquan@163.com
    @Date   : 2018-07-27 17:19:54
    @desc   : Save image data (with labels) into a single tfrecord file.
"""

import tensorflow as tf
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import random
from PIL import Image


def _int64_feature(value):
    """Wrap a single int in a tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def float_list_feature(value):
    """Wrap a list of floats in a tf.train.Feature."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def get_example_nums(tf_records_filenames):
    """Return the number of serialized examples in the tfrecord file tf_records_filenames."""
    return sum(1 for _ in tf.python_io.tf_record_iterator(tf_records_filenames))


def show_image(title, image):
    """Display `image` with matplotlib under the given title."""
    plt.imshow(image)
    plt.axis('on')  # use 'off' to hide the axes
    plt.title(title)
    plt.show()


def load_labels_file(filename, labels_num=1, shuffle=False):
    """Load a list file whose lines are `image_path label1 label2 ...`.

    Example line: `test_image/1.jpg 0 2`.

    Args:
        filename: path of the text file.
        labels_num: number of integer labels to read after the image path.
        shuffle: shuffle the lines (with `random.shuffle`) before parsing.

    Returns:
        (images, labels): a list of image paths and a parallel list of
        integer label lists, each of length labels_num.
    """
    images = []
    labels = []
    with open(filename) as f:
        lines_list = f.readlines()
    if shuffle:
        random.shuffle(lines_list)
    for lines in lines_list:
        # split() tolerates repeated spaces/tabs; blank lines are skipped.
        line = lines.split()
        if not line:
            continue
        label = [int(line[i + 1]) for i in range(labels_num)]
        images.append(line[0])
        labels.append(label)
    return images, labels


def read_image(filename, resize_height, resize_width, normalization=False):
    """Read an image file as an RGB numpy array.

    Args:
        filename: image path.
        resize_height, resize_width: target size; the image is resized only
            when both are > 0.
        normalization: if True, scale values to floats in [0.0, 1.0].

    Returns:
        The RGB image: uint8 in [0, 255] by default, float in [0, 1] when
        normalization is True.

    Raises:
        IOError: if OpenCV cannot read the file.
    """
    bgr_image = cv2.imread(filename)
    if bgr_image is None:
        # cv2.imread returns None instead of raising on missing/unreadable files.
        raise IOError('Cannot read image: %s' % filename)
    if len(bgr_image.shape) == 2:  # grayscale -> 3 channels
        print("Warning:gray image", filename)
        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
    if resize_height > 0 and resize_width > 0:
        # cv2.resize takes the target size as (width, height).
        rgb_image = cv2.resize(rgb_image, (resize_width, resize_height))
    rgb_image = np.asanyarray(rgb_image)
    if normalization:
        # Divide by a float literal: `/ 255` would be integer division on Python 2.
        rgb_image = rgb_image / 255.0
    return rgb_image


def get_batch_images(images, labels, batch_size, labels_nums, one_hot=False, shuffle=False, num_threads=1):
    """Group single (image, label) tensors into batches via a TF 1.x queue.

    Args:
        images: single-example image tensor.
        labels: single-example label tensor.
        batch_size: examples per batch.
        labels_nums: number of classes; used only when one_hot is True.
        one_hot: convert labels to one-hot vectors.
        shuffle: use tf.train.shuffle_batch (typically for training) instead
            of tf.train.batch (typically for validation).
        num_threads: number of enqueuing threads.

    Returns:
        (images_batch, labels_batch) tensors.
    """
    min_after_dequeue = 200
    # capacity must be larger than min_after_dequeue.
    capacity = min_after_dequeue + 3 * batch_size
    if shuffle:
        images_batch, labels_batch = tf.train.shuffle_batch(
            [images, labels],
            batch_size=batch_size,
            capacity=capacity,
            min_after_dequeue=min_after_dequeue,
            num_threads=num_threads)
    else:
        images_batch, labels_batch = tf.train.batch(
            [images, labels],
            batch_size=batch_size,
            capacity=capacity,
            num_threads=num_threads)
    if one_hot:
        labels_batch = tf.one_hot(labels_batch, labels_nums, 1, 0)
    return images_batch, labels_batch


def read_records(filename, resize_height, resize_width, type=None):
    """Build ops that parse one example at a time from a tfrecord file.

    The stored image data is RGB uint8 in [0, 255].

    Args:
        filename: tfrecord file path.
        resize_height, resize_width: size the images were stored with. When
            either is <= 0 (images stored without resizing), the per-example
            'height'/'width'/'depth' features are used to restore the shape.
        type: output format of the image tensor:
            None             -> float32 in [0, 255]
            'normalization'  -> float32 in [0, 1]
            'centralization' -> float32 scaled to [0, 1] then minus 0.5
            any other value  -> left as uint8

    Returns:
        (tf_image, tf_label); tf_label is int32.
    """
    # Filename queue with no num_epochs limit: the file is read repeatedly.
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    # Read one serialized example from the queue.
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    # 'image_raw' holds the raw uint8 pixel bytes written by create_records.
    tf_image = tf.decode_raw(features['image_raw'], tf.uint8)
    tf_height = features['height']
    tf_width = features['width']
    tf_depth = features['depth']
    tf_label = tf.cast(features['label'], tf.int32)
    # decode_raw gives a flat vector; the reshape must match the stored shape.
    if resize_height > 0 and resize_width > 0:
        # Fixed (static) shape known from the resize settings.
        tf_image = tf.reshape(tf_image, [resize_height, resize_width, 3])
    else:
        # Not resized at write time: use the shape stored with each example.
        tf_image = tf.reshape(tf_image, tf.stack([tf_height, tf_width, tf_depth]))
    # Stored as uint8; training usually needs float32.
    if type is None:
        tf_image = tf.cast(tf_image, tf.float32)
    elif type == 'normalization':
        # Scale to [0, 1].
        tf_image = tf.cast(tf_image, tf.float32) * (1. / 255.0)
    elif type == 'centralization':
        # Scale to [0, 1], then center assuming a mean of 0.5.
        tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) - 0.5
    return tf_image, tf_label


def create_records(image_dir, file, output_record_dir, resize_height, resize_width, shuffle, log=5):
    """Write the images, labels and sizes listed in `file` into one tfrecord file.

    Each image is stored as raw uint8 bytes (ndarray.tobytes()) together with
    its height/width/depth, so readers must decode_raw + reshape it.

    Args:
        image_dir: directory that the image paths in `file` are relative to.
        file: list file in the format read by load_labels_file.
        output_record_dir: path of the tfrecord file to write; its parent
            directory is created if missing.
        resize_height, resize_width: resize target; no resizing unless both > 0.
        shuffle: shuffle the list before writing.
        log: print progress every `log` images.
    """
    # Only the first label of each line is loaded.
    images_list, labels_list = load_labels_file(file, 1, shuffle)
    out_dir = os.path.dirname(output_record_dir)
    if out_dir and not os.path.exists(out_dir):
        os.makedirs(out_dir)
    with tf.python_io.TFRecordWriter(output_record_dir) as writer:
        for i, (image_name, labels) in enumerate(zip(images_list, labels_list)):
            image_path = os.path.join(image_dir, image_name)
            if not os.path.exists(image_path):
                print('Err:no image', image_path)
                continue
            image = read_image(image_path, resize_height, resize_width)
            # tobytes() is the non-deprecated equivalent of tostring().
            image_raw = image.tobytes()
            if i % log == 0 or i == len(images_list) - 1:
                print('------------processing:%d-th------------' % (i))
                print('current image_path=%s' % (image_path),
                      'shape:{}'.format(image.shape), 'labels:{}'.format(labels))
            # Only one label is stored; add more 'label' entries for multi-label data.
            label = labels[0]
            example = tf.train.Example(features=tf.train.Features(feature={
                'image_raw': _bytes_feature(image_raw),
                'height': _int64_feature(image.shape[0]),
                'width': _int64_feature(image.shape[1]),
                'depth': _int64_feature(image.shape[2]),
                'label': _int64_feature(label),
            }))
            writer.write(example.SerializeToString())


def disp_records(record_file, resize_height, resize_width, show_nums=4):
    """Decode and display the first `show_nums` images of a tfrecord file.

    Mainly used to check that create_records produced a valid file.
    """
    tf_image, tf_label = read_records(record_file, resize_height, resize_width, type='normalization')
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(show_nums):
                image, label = sess.run([tf_image, tf_label])
                print('shape:{},tpye:{},labels:{}'.format(
                    image.shape, image.dtype, label))
                show_image("image:%d" % (label), image)
        finally:
            # Always stop and join the queue-runner threads.
            coord.request_stop()
            coord.join(threads)


def batch_test(record_file, resize_height, resize_width):
    """Read a tfrecord file in batches of 4 and show the first image of 4 batches.

    image_batch / label_batch are what would normally be fed to a network.
    """
    tf_image, tf_label = read_records(record_file, resize_height, resize_width, type='normalization')
    image_batch, label_batch = get_batch_images(tf_image, tf_label, batch_size=4, labels_nums=5,
                                                one_hot=False, shuffle=False)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(4):
                images, labels = sess.run([image_batch, label_batch])
                # Show only the first image of each batch.
                show_image("image", images[0, :, :, :])
                print('shape:{},tpye:{},labels:{}'.format(
                    images.shape, images.dtype, labels))
        finally:
            # Stop all queue-runner threads.
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    resize_height = 224  # stored image height
    resize_width = 224  # stored image width
    shuffle = True
    log = 5
    # Build train.tfrecords
    image_dir = 'dataset/train'
    train_labels = 'dataset/train.txt'  # image list file
    train_record_output = 'dataset/record/train.tfrecords'
    create_records(image_dir, train_labels, train_record_output, resize_height, resize_width, shuffle, log)
    train_nums = get_example_nums(train_record_output)
    print("save train example nums={}".format(train_nums))

    # Build val.tfrecords
    image_dir = 'dataset/val'
    val_labels = 'dataset/val.txt'  # image list file
    val_record_output = 'dataset/record/val.tfrecords'
    create_records(image_dir, val_labels, val_record_output, resize_height, resize_width, shuffle, log)
    val_nums = get_example_nums(val_record_output)
    print("save val example nums={}".format(val_nums))

    # Display test (uncomment to view single records):
    # disp_records(train_record_output, resize_height, resize_width)
    batch_test(train_record_output, resize_height, resize_width)
Last modification:November 21st, 2018 at 09:44 pm