如果训练网络时,针对的是大规模数据集,如图像数据集,其不能完全读取加载到内存里,那么就需要利用用到 data generator了. data generator 将数据分为 batches,再送入网络进行训练.
TnesorFlow 有对应的 API,但其 API 比较复杂,且容易出错.
对于习惯于 Keras 的人来说,Keras 减少了学习繁琐 API 的成本(未来可能会发生变化),只需关注于模型设计.
另一个采用 Keras 的 Sequence class
作为 batch data generator 的优势在于,Keras 能够处理所有的 multi-threading 和并行化(parallelization),以确保训练过程中不需要 batch data generation 的数据等待. 其背后的原理是,采用了 multiplt CPUs 核提前拉取了 batches 数据.
1. 采用 Keras 的 Sequence Class
1.1 Keras 的 Sequence 文档
keras.utils.Sequence()
每个 Sequence
必须包含 __getitem__
和 __len__
方法的实现.
如果需要修改自定义数据集的 epochs 间的数据,则需要实现 on_epoch_end
.
__getitem__
返回一个完整的 batch.
Sequence 是进行 Multiprocessing 的更加安全的方式. 其保证了网络只对每个 epoch 内的每个样本训练一次.
例示:
from skimage.io import imread
from skimage.transform import resize
import numpy as np
class CIFAR10Sequence(Sequence):
def __init__(self, x_set, y_set, batch_size):
# x_set - 图片路径的列表
# y_set - 对应的类别
self.x, self.y = x_set, y_set
self.batch_size = batch_size
def __len__(self):
return int(np.ceil(len(self.x) / float(self.batch_size)))
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
return np.array([
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)
1.2 Keras 的 Sequence 使用
修改自上面的 CIFAR10Sequence(Sequence)
.
from skimage.io import imread
from skimage.transform import resize
import numpy as np
class MY_Generator(Sequence): # generator 继承自 Sequence
def __init__(self, image_filenames, labels, batch_size):
# image_filenames - 图片路径
# labels - 图片对应的类别标签
self.image_filenames, self.labels = image_filenames, labels
self.batch_size = batch_size
def __len__(self):
# 计算 generator要生成的 batches 数,
return np.ceil(len(self.image_filenames) / float(self.batch_size))
def __getitem__(self, idx):
# idx - 给定的 batch 数,以构建 batch 数据 [images_batch, GT]
batch_x = self.image_filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_y = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]
return np.array([
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)
# generator 的使用
my_training_batch_generator = My_Generator(training_filenames,
GT_training,
batch_size)
my_validation_batch_generator = My_Generator(validation_filenames,
GT_validation,
batch_size)
model.fit_generator(generator=my_training_batch_generator,
steps_per_epoch=(num_training_samples // batch_size),
epochs=num_epochs,
verbose=1,
validation_data=my_validation_batch_generator,
validation_steps=(num_validation_samples // batch_size),
use_multiprocessing=True,
workers=16,
max_queue_size=32)
# 如果有多个 CPU 核,可以设置 use_multiprocessing=True,即可在 CPU 上并行运行.
# 设置 workers=CPU 核数,用于 batch 数据生成.
2 comments
请问__getitem__(self, idx)里的idx值是怎么来的?
这个可以的 讲解的很清晰