TF-Slim, which ships with TensorFlow, provides a data-handling API: TF-Slim Data.
TF-Slim Data is a data loading library for reading data stored in different formats.
The module is built from several layers of abstraction, which makes it flexible with respect to file storage types (e.g. TFRecords, text files), data encodings, and feature naming schemes.
Loading data consists of two main parts:
- [1] - a specification of how the data is represented, so that it can be read and interpreted;
- [2] - instructions for providing the data to the consumers of the dataset.
For the second part, one must specify how the data is actually provided and housed in memory.
For example, if the data is sharded across multiple sources, should those sources be read in parallel or serially? Should the data be shuffled in memory?
<h2>1. Dataset Specification</h2>
TF-Slim defines a dataset as a set of files (encoded or not) representing a finite set of samples, which can be read to provide a predefined set of entities or items.
For example, a dataset might be stored in thousands of files or in a single file. The files may store the data as plain text or in some more advanced encoding. A dataset may provide a single item, such as an image, or multiple items, such as an image, a class label, and a scene label.
More concretely, the Dataset class provided by TF-Slim's dataset.py encapsulates the following elements of a dataset specification:
- data_sources - a list of the file paths that together make up the dataset
- reader - the TensorFlow Reader appropriate for the file type of data_sources
- decoder - a TF-Slim data_decoder class used to decode the content of the read dataset files
- num_samples - the number of samples in the dataset
- items_to_descriptions - a map from the items provided by the dataset to a description of each item
In short, a dataset is read by first opening the files in data_sources using the given reader class; the file contents are then decoded using the given decoder; finally, the user may request a list of items, which are returned as Tensors.
The Dataset class:
"""Contains the definition of a Dataset.
A Dataset is a collection of several components:
(1) a list of data sources
(2) a Reader class that can read those sources and return possibly encoded samples of data
(3) a decoder that decodes each sample of data provided by the reader
(4) the total number of samples
(5) an optional dictionary mapping the list of items returned to a description of those items.
Data can be loaded from a dataset specification using a dataset_data_provider:
dataset = CreateMyDataset(...)
provider = dataset_data_provider.DatasetDataProvider(dataset, shuffle=False)
image, label = provider.get(['image', 'label'])
See slim.data.dataset_data_provider for additional examples.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class Dataset(object):
"""Represents a Dataset specification."""
def __init__(self, data_sources, reader, decoder, num_samples,
items_to_descriptions, **kwargs):
"""Initializes the dataset.
Args:
data_sources: A list of files that make up the dataset.
reader: The reader class, a subclass of BaseReader such as TextLineReader or TFRecordReader.
decoder: An instance of a data_decoder.
num_samples: The number of samples in the dataset.
items_to_descriptions: A map from the items that the dataset provides to
the descriptions of those items.
**kwargs: Any remaining dataset-specific fields.
"""
kwargs['data_sources'] = data_sources
kwargs['reader'] = reader
kwargs['decoder'] = decoder
kwargs['num_samples'] = num_samples
kwargs['items_to_descriptions'] = items_to_descriptions
self.__dict__.update(kwargs)
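To make the fields above concrete, here is a minimal sketch of building a Dataset specification for TFRecord files; the file path, the sample count, and the my_decoder placeholder (an instance of a slim data_decoder, see section 3) are assumptions for illustration only.
import tensorflow as tf

slim = tf.contrib.slim

# my_decoder is assumed to be a slim data_decoder instance, e.g. a
# TFExampleDecoder configured for this dataset (see section 3).
my_dataset = slim.dataset.Dataset(
    data_sources=['/tmp/train-00000-of-00004.tfrecord'],  # assumed file path
    reader=tf.TFRecordReader,
    decoder=my_decoder,
    num_samples=10000,  # assumed number of samples
    items_to_descriptions={
        'image': 'A color image.',
        'label': 'An integer class label.',
    })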
<h2>2. Data Decoders</h2>
A data_decoder is a class which, given some (possibly serialized or encoded) data, returns a list of Tensors.
In particular, a given data decoder is able to decode a predefined list of items and can return a subset of them or all of them:
# Load the data
my_encoded_data = ...
data_decoder = MyDataDecoder()
# Decode the inputs and labels:
decoded_input, decoded_labels = data_decoder.Decode(my_encoded_data, ['input', 'labels'])
# Decode just the inputs:
decoded_input = data_decoder.Decode(my_encoded_data, ['input'])
# Check which items a data decoder knows how to decode:
for item in data_decoder.list_items():
print(item)
"""Contains helper functions and classes necessary for decoding data.
While data providers read data from disk, sstables or other formats, data
decoders decode the data (if necessary).
A data decoder is provided with a serialized or encoded piece of data
as well as a list of items and returns a set of tensors,
each of which corresponds to the requested list of items extracted from the data:
def Decode(self, data, items):
...
For example, if data is a compressed map, the implementation might be:
def Decode(self, data, items):
decompressed_map = _Decompress(data)
outputs = []
for item in items:
outputs.append(decompressed_map[item])
return outputs.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
class DataDecoder(object):
"""An abstract class which is used to decode data for a provider."""
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def decode(self, data, items):
"""Decodes the data to returns the tensors specified by the list of items.
Args:
data: A possibly encoded data format.
items: A list of strings, each of which indicate a particular data type.
Returns:
A list of `Tensors`, whose length matches the length of `items`, where
each `Tensor` corresponds to each item.
Raises:
ValueError: If any of the items cannot be satisfied.
"""
pass
@abc.abstractmethod
def list_items(self):
"""Lists the names of the items that the decoder can decode.
Returns:
A list of string names.
"""
pass
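As an illustration of this interface, the following is a minimal sketch of a concrete decoder, assuming the data is a serialized tf.train.Example holding a fixed-length float feature named 'input' and an int64 feature named 'label' (these names and shapes are assumptions, not part of TF-Slim):
import tensorflow as tf


class MyExampleDecoder(DataDecoder):
  """Decodes serialized tf.train.Example protos into 'input' and 'label'."""

  def decode(self, data, items):
    # Parse the assumed feature layout from the serialized proto.
    features = tf.parse_single_example(
        data,
        features={
            'input': tf.FixedLenFeature([4], tf.float32),
            'label': tf.FixedLenFeature([1], tf.int64),
        })
    # Return one tensor per requested item, in the requested order.
    return [features[item] for item in items]

  def list_items(self):
    return ['input', 'label']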
<h2>3. Example: TFExampleDecoder</h2>
The tfexample_decoder.py provides a data decoder that decodes serialized TFExample protocol buffers. A TFExample protocol buffer is a map from keys (strings) to either a tf.FixedLenFeature or a tf.VarLenFeature.
Consequently, to decode a TFExample, one must provide a mapping from one or more TFExample fields to each of the items that the tfexample_decoder provides.
For example, a dataset of TFExamples might store images in different formats: each TFExample contains an encoding key and a format key, and the latter is used to decode the image with the appropriate decoder (jpg, png, etc.).
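For orientation, here is a hedged sketch of how such a TFExample might be serialized on the writer side; the key names mirror those used in the decoding example below, while the image path and the label value 7 are assumptions for illustration:
import tensorflow as tf

# The image path is an assumption; any encoded PNG file would do.
image_bytes = open('/tmp/example.png', 'rb').read()

example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[image_bytes])),
    'image/format': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b'png'])),
    'image/class/label': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[7])),
}))
serialized_example = example.SerializeToString()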
The tfexample_decoder is constructed by specifying a map of TFExample keys to tf.FixedLenFeature or tf.VarLenFeature, together with a set of ItemHandlers.
An ItemHandler provides a mapping from TFExample keys to the item being provided. Because a tfexample_decoder may return multiple items, one typically constructs a tfexample_decoder using multiple ItemHandlers.
tfexample_decoder provides some predefined ItemHandlers which cover most of the common cases of converting TFExamples to images, Tensors, and SparseTensors.
For example, the following code can be used to decode a dataset of images:
keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
'image/format': tf.FixedLenFeature((), tf.string, default_value='raw'),
'image/class/label': tf.FixedLenFeature(
[1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)),
}
items_to_handlers = {
'image': tfexample_decoder.Image(
image_key = 'image/encoded',
format_key = 'image/format',
shape=[28, 28],
channels=1),
'label': tfexample_decoder.Tensor('image/class/label'),
}
decoder = tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers)
Note that the TFExample is parsed using three keys: image/encoded, image/format and image/class/label.
The first two keys are mapped to a single item - image.
As defined, this data_decoder provides two items - image and label.
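As a short usage sketch, the decoder built above turns a raw proto into the two items it was configured to provide; serialized_example is assumed to be a serialized TFExample, for instance the one built above or a scalar string Tensor read from a TFRecord file:
# serialized_example: one serialized TFExample (assumed to exist).
image, label = decoder.decode(serialized_example, ['image', 'label'])
print(decoder.list_items())  # ['image', 'label']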
<h2>4. Data Provision</h2>
A data_provider is a class - DataProvider(object) - that provides Tensors for each requested item:
my_data_provider = ...
image, class_label, bounding_box = my_data_provider.get(['image', 'label', 'bb'])
The dataset_data_provider is a data_provider that provides data from a given dataset specification:
dataset = GetDataset(...)
data_provider = dataset_data_provider.DatasetDataProvider(
dataset, common_queue_capacity=32, common_queue_min=8)
The dataset_data_provider allows control over several parameters of data provision:
- the number of readers used simultaneously
- whether the data is shuffled as it is loaded into its queue
- whether to take a single pass over the data or read it indefinitely
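For instance, a minimal sketch of taking exactly one unshuffled pass over the data with a single reader might look as follows; my_dataset is assumed to be a Dataset specification built as in section 1:
import tensorflow as tf

slim = tf.contrib.slim

# my_dataset is an assumed Dataset specification (see section 1).
# Because num_epochs is set, remember to run tf.local_variables_initializer()
# before starting the queue runners.
provider = slim.dataset_data_provider.DatasetDataProvider(
    my_dataset, num_readers=1, shuffle=False, num_epochs=1)
The module docstring of dataset_data_provider.py summarizes these options: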
"""
A DataProvider that provides data from a Dataset.
DatasetDataProviders provide data from datasets.
The provider can be configured to use multiple readers simultaneously or read via a single reader.
Additionally, the data being read can be optionally shuffled.
For example, to read data using a single thread without shuffling:
pascal_voc_data_provider = DatasetDataProvider(
slim.datasets.pascal_voc.get_split('train'),
shuffle=False)
images, labels = pascal_voc_data_provider.get(['images', 'labels'])
To read data using multiple readers simultaneously with shuffling:
pascal_voc_data_provider = DatasetDataProvider(
slim.datasets.pascal_voc.Dataset(),
num_readers=10,
shuffle=True)
images, labels = pascal_voc_data_provider.get(['images', 'labels'])
Equivalently, one may request different fields of the same sample separately:
[images] = pascal_voc_data_provider.get(['images'])
[labels] = pascal_voc_data_provider.get(['labels'])
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.slim.python.slim.data import data_provider
from tensorflow.contrib.slim.python.slim.data import parallel_reader
class DatasetDataProvider(data_provider.DataProvider):
def __init__(self,
dataset,
num_readers=1,
reader_kwargs=None,
shuffle=True,
num_epochs=None,
common_queue_capacity=256,
common_queue_min=128,
record_key='record_key',
seed=None,
scope=None):
"""
Creates a DatasetDataProvider.
Note: if `num_epochs` is not `None`, local counter `epochs` will be created
by relevant function. Use `local_variables_initializer()` to initialize
local variables.
Args:
dataset: An instance of the Dataset class.
num_readers: The number of parallel readers to use.
reader_kwargs: An optional dict of kwargs for the reader.
shuffle: Whether to shuffle the data sources and common queue when reading.
num_epochs: The number of times each data source is read. If left as None,
the data will be cycled through indefinitely.
common_queue_capacity: The capacity of the common queue.
common_queue_min: The minimum number of elements in the common queue after a dequeue.
record_key: The item name to use for the dataset record keys in the provided tensors.
seed: The seed to use if shuffling.
scope: Optional name scope for the ops.
Raises:
ValueError: If `record_key` matches one of the items in the dataset.
"""
key, data = parallel_reader.parallel_read(
dataset.data_sources,
reader_class=dataset.reader,
num_epochs=num_epochs,
num_readers=num_readers,
reader_kwargs=reader_kwargs,
shuffle=shuffle,
capacity=common_queue_capacity,
min_after_dequeue=common_queue_min,
seed=seed,
scope=scope)
items = dataset.decoder.list_items()
tensors = dataset.decoder.decode(data, items)
items_to_tensors = dict(zip(items, tensors))
if record_key in items_to_tensors:
raise ValueError('The item name used for `record_key` cannot also be '
                 'used for a dataset item: %s', record_key)
items_to_tensors[record_key] = key
super(DatasetDataProvider, self).__init__(
items_to_tensors=items_to_tensors,
num_samples=dataset.num_samples)
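Putting the pieces together, the following sketch shows one way to pull decoded samples through a provider inside a session; the dataset variable is assumed to be a Dataset specification built as in section 1, and the batch size, thread counts, and queue capacity are illustrative assumptions rather than recommended values:
import tensorflow as tf

slim = tf.contrib.slim

# `dataset` is an assumed Dataset specification (see section 1).
provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset, num_readers=4, shuffle=True)
image, label = provider.get(['image', 'label'])

# Batch the per-sample tensors; the sizes here are assumptions.
images, labels = tf.train.batch(
    [image, label], batch_size=32, num_threads=4, capacity=128)

with tf.Session() as sess:
  sess.run([tf.global_variables_initializer(),
            tf.local_variables_initializer()])
  # The provider reads through input queues, so queue runners must be started.
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  np_images, np_labels = sess.run([images, labels])
  coord.request_stop()
  coord.join(threads)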