Data augmentation applies transformations to the training set so that it becomes more varied, which improves the model's ability to generalize.

The data augmentation operations in torchvision transforms fall into the following groups:

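[1] - Cropping

  • transforms.CenterCrop - crop at the center
  • transforms.RandomCrop - crop at a random location
  • transforms.RandomResizedCrop - crop a random area with a random aspect ratio, then resize to the given size
  • transforms.FiveCrop - crop the four corners and the center
  • transforms.TenCrop - the five FiveCrop crops plus their flipped versions (horizontal flip by default)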

[2] - Flipping and rotation

  • transforms.RandomHorizontalFlip - random horizontal flip
  • transforms.RandomVerticalFlip - random vertical flip
  • transforms.RandomRotation - random rotation

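[3] - Image transforms

  • transforms.Pad - padding
  • transforms.ColorJitter - random changes to brightness, contrast, saturation, and hue
  • transforms.Grayscale - convert to grayscale
  • transforms.RandomGrayscale - randomly convert to grayscale
  • transforms.RandomAffine - random affine transformation
  • transforms.LinearTransformation - linear transformation
  • transforms.RandomErasing - random erasing
  • transforms.Lambda - user-defined transform
  • transforms.Resize - resize to the given size
  • transforms.ToTensor - convert to a Tensor and scale values to [0, 1]
  • transforms.Normalize - standardize each channel with a given mean and standard deviation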

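[4] - Operations on transforms

  • transforms.RandomChoice(transforms) - apply one transform chosen at random from the given list
  • transforms.RandomApply(transforms, p=0.5) - apply the given list of transforms with probability p
  • transforms.RandomOrder - apply the given transforms in a random order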

Example - FixRes code block

From FixRes

import numpy as np
import torch
import torchvision.transforms.functional as F
from torchvision import transforms

class Resize(transforms.Resize):
    """
    Resize with a ``largest=False'' argument
    allowing to resize to a common largest side without cropping
    """
    def __init__(self, size, largest=False, **kwargs):
        super().__init__(size, **kwargs)
        self.largest = largest

    @staticmethod
    def target_size(w, h, size, largest=False):
        if h < w and largest:
            w, h = size, int(size * h / w)
        else:
            w, h = int(size * w / h), size
        size = (h, w)
        return size

    def __call__(self, img):
        size = self.size
        w, h = img.size
        target_size = self.target_size(w, h, size, self.largest)
        return F.resize(img, target_size, self.interpolation)

    def __repr__(self):
        r = super().__repr__()
        return r[:-1] + ', largest={})'.format(self.largest)

# Build the train/val transform pipelines; `kind` selects the training recipe.
def get_transforms(input_size=224, test_size=224, kind='full', crop=True, need=('train', 'val')):
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    transformations = {}
    if 'train' in need:
        if kind == 'torch':
            transformations['train'] = transforms.Compose([
                transforms.RandomResizedCrop(input_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
        elif kind == 'full':
            transformations['train'] = transforms.Compose([
                transforms.RandomResizedCrop(input_size),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(0.3, 0.3, 0.3),
                transforms.ToTensor(),            
                transforms.Normalize(mean, std),
            ])

        else:
            raise ValueError('Transforms kind {} unknown'.format(kind))
    if 'val' in need:
        if crop:
            transformations['val'] = transforms.Compose(
                [Resize(int((256 / 224) * test_size)),  # to maintain same ratio w.r.t. 224 images
                 transforms.CenterCrop(test_size),
                 transforms.ToTensor(),
                 transforms.Normalize(mean, std)])
        else:
            transformations['val'] = transforms.Compose(
                [Resize(test_size, largest=True),  # resize the larger side to test_size, no cropping
                 transforms.ToTensor(),
                 transforms.Normalize(mean, std)])
    return transformations

transforms_list = ['torch', 'full']
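To see what the `largest` flag does, the size computation from `Resize.target_size` above can be run on its own. The sketch copies that function as a free function; the 640x480 and 480x640 inputs are made-up examples.

```python
def target_size(w, h, size, largest=False):
    """Same arithmetic as Resize.target_size above; returns (height, width)."""
    if h < w and largest:
        # Landscape image: the width (larger side) becomes `size`.
        w, h = size, int(size * h / w)
    else:
        # Otherwise the height becomes `size` and the width is scaled to match.
        w, h = int(size * w / h), size
    return (h, w)

landscape_largest = target_size(640, 480, 224, largest=True)
landscape_default = target_size(640, 480, 224, largest=False)
portrait_largest = target_size(480, 640, 224, largest=True)
print(landscape_largest, landscape_default, portrait_largest)
```

With `largest=True` the longer side always ends up equal to `size`, so the whole image fits inside a `size` x `size` square and no crop is needed.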

Example - self-supervised learning

Typical data augmentation composition (PyTorch) for self-supervised learning:

from torchvision import transforms
...
augmentation = [
    transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
    transforms.RandomApply([
        transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
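    # GaussianBlur is a user-defined transform that takes a sigma range; it is not defined in this snippet
    transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),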
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
]
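The snippet above calls `GaussianBlur([0.1, 2.0])` with only a sigma range, so it relies on a custom class that is not shown. A minimal sketch of what such a class might look like follows, assuming it blurs a PIL image with a radius drawn uniformly from the range. This implementation is an assumption made for illustration, not the original author's code.

```python
import random
from PIL import Image, ImageFilter

class GaussianBlur:
    """Hypothetical stand-in: blur a PIL image with a radius sampled from `sigma`."""

    def __init__(self, sigma=(0.1, 2.0)):
        self.sigma = sigma  # (min, max) range for the blur radius

    def __call__(self, img):
        # Draw a new radius on every call so each sample gets a different blur.
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return img.filter(ImageFilter.GaussianBlur(radius=radius))

img = Image.new("RGB", (32, 32), color=(255, 255, 255))
blurred = GaussianBlur([0.1, 2.0])(img)
print(blurred.size, blurred.mode)
```

Because the class works on PIL images, it has to sit before transforms.ToTensor() in the pipeline, as it does in the list above.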


Last modification: June 20th, 2020 at 04:15 pm