原文:Keras: Training on Large Datasets

如果训练网络时,针对的是大规模数据集,如图像数据集,其不能完全读取加载到内存里,那么就需要利用用到 data generator了. data generator 将数据分为 batches,再送入网络进行训练.

TnesorFlow 有对应的 API,但其 API 比较复杂,且容易出错.

对于习惯于 Keras 的人来说,Keras 减少了学习繁琐 API 的成本(未来可能会发生变化),只需关注于模型设计.

另一个采用 Keras 的 Sequence class 作为 batch data generator 的优势在于,Keras 能够处理所有的 multi-threading 和并行化(parallelization),以确保训练过程中不需要 batch data generation 的数据等待. 其背后的原理是,采用了 multiplt CPUs 核提前拉取了 batches 数据.

1. 采用 Keras 的 Sequence Class

1.1 Keras 的 Sequence 文档

Keras Doc - Sequence

keras.utils.Sequence()

每个 Sequence 必须包含 __getitem____len__ 方法的实现.

如果需要修改自定义数据集的 epochs 间的数据,则需要实现 on_epoch_end.

__getitem__返回一个完整的 batch.

Sequence 是进行 Multiprocessing 的更加安全的方式. 其保证了网络只对每个 epoch 内的每个样本训练一次.

例示:

from skimage.io import imread
from skimage.transform import resize
import numpy as np

class CIFAR10Sequence(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        # x_set - 图片路径的列表
        # y_set - 对应的类别
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        return np.array([
            resize(imread(file_name), (200, 200))
                for file_name in batch_x]), np.array(batch_y)

1.2 Keras 的 Sequence 使用

修改自上面的 CIFAR10Sequence(Sequence).

from skimage.io import imread
from skimage.transform import resize
import numpy as np

class MY_Generator(Sequence): # generator 继承自 Sequence

    def __init__(self, image_filenames, labels, batch_size):
        # image_filenames - 图片路径
        # labels - 图片对应的类别标签
        self.image_filenames, self.labels = image_filenames, labels
        self.batch_size = batch_size

    def __len__(self):
        # 计算 generator要生成的 batches 数,
        return np.ceil(len(self.image_filenames) / float(self.batch_size))

    def __getitem__(self, idx):
        # idx - 给定的 batch 数,以构建 batch 数据 [images_batch, GT]
        batch_x = self.image_filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]

        return np.array([
            resize(imread(file_name), (200, 200))
               for file_name in batch_x]), np.array(batch_y)
    
# generator 的使用
my_training_batch_generator = My_Generator(training_filenames, 
                                           GT_training, 
                                           batch_size)
my_validation_batch_generator = My_Generator(validation_filenames, 
                                             GT_validation, 
                                             batch_size)

model.fit_generator(generator=my_training_batch_generator,
                    steps_per_epoch=(num_training_samples // batch_size),
                    epochs=num_epochs,
                    verbose=1,
                                         validation_data=my_validation_batch_generator,
                    validation_steps=(num_validation_samples // batch_size),
                    use_multiprocessing=True,
                    workers=16,
                    max_queue_size=32)
# 如果有多个 CPU 核,可以设置 use_multiprocessing=True,即可在 CPU 上并行运行.
# 设置 workers=CPU 核数,用于 batch 数据生成. 
Last modification:November 3rd, 2018 at 06:18 pm