Data augmentation (数据增强), also known as data expansion or data enrichment, applies transformations to the training set to make it more varied, thereby improving the model's ability to generalize.
torchvision transforms provides the following main augmentation operations:
[1] - Cropping
- transforms.CenterCrop - crop at the center
- transforms.RandomCrop - crop at a random location
- transforms.RandomResizedCrop - random crop with random scale and aspect ratio, then resize
- transforms.FiveCrop - crop the four corners and the center (5 crops)
- transforms.TenCrop - FiveCrop plus a flipped copy of each crop (10 crops)
[2] - Flipping and rotation
- transforms.RandomHorizontalFlip - random horizontal flip
- transforms.RandomVerticalFlip - random vertical flip
- transforms.RandomRotation - random rotation
[3] - Image transforms
- transforms.Pad - pad the image borders
- transforms.ColorJitter - randomly change brightness, contrast and saturation
- transforms.Grayscale - convert to grayscale
- transforms.RandomGrayscale - randomly convert to grayscale
- transforms.RandomAffine - random affine transformation
- transforms.LinearTransformation - linear transformation (e.g. whitening)
- transforms.RandomErasing - randomly erase a rectangular region
- transforms.Lambda - user-defined transform
- transforms.Resize - resize
- transforms.ToTensor - convert to a Tensor and scale pixel values to [0, 1]
- transforms.Normalize - normalize with per-channel mean and std
[4] - Operations on transforms
- transforms.RandomChoice(transforms) - apply a single transform randomly picked from the list
- transforms.RandomApply(transforms, p=0.5) - apply the given list of transforms with probability p
- transforms.RandomOrder - apply the transforms in a random order
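The semantics of the three wrapper operations in [4] can be sketched in plain Python. The helper names below (`random_choice`, `random_apply`, `random_order`) are illustrative only; torchvision's actual classes wrap the same logic as callable objects:

```python
import random

def random_choice(ts, x):
    # RandomChoice: pick exactly one transform from the list and apply it
    return random.choice(ts)(x)

def random_apply(ts, x, p=0.5):
    # RandomApply: with probability p apply the whole list in order, else identity
    if random.random() < p:
        for t in ts:
            x = t(x)
    return x

def random_order(ts, x):
    # RandomOrder: apply every transform, but in a shuffled order
    ts = list(ts)
    random.shuffle(ts)
    for t in ts:
        x = t(x)
    return x
```

With non-commuting transforms such as `x + 1` and `x * 2`, `random_order` yields either `(x + 1) * 2` or `x * 2 + 1`; that ordering randomness is the only difference from a plain `Compose`.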
Example - FixRes
The following snippet is from FixRes:
import numpy as np
import torch
import torchvision.transforms.functional as F
from torchvision import transforms

class Resize(transforms.Resize):
    """
    Resize with a ``largest=False'' argument
    allowing to resize to a common largest side without cropping
    """
    def __init__(self, size, largest=False, **kwargs):
        super().__init__(size, **kwargs)
        self.largest = largest

    @staticmethod
    def target_size(w, h, size, largest=False):
        if h < w and largest:
            w, h = size, int(size * h / w)
        else:
            w, h = int(size * w / h), size
        size = (h, w)
        return size

    def __call__(self, img):
        size = self.size
        w, h = img.size
        target_size = self.target_size(w, h, size, self.largest)
        return F.resize(img, target_size, self.interpolation)

    def __repr__(self):
        r = super().__repr__()
        return r[:-1] + ', largest={})'.format(self.largest)
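The aspect-ratio arithmetic in `target_size` can be checked without torchvision. Below is the same logic restated as a free function, purely for illustration:

```python
def target_size(w, h, size, largest=False):
    # Keep the aspect ratio while pinning one side to `size`.
    # By default the height is set to `size`; with largest=True a landscape
    # image (h < w) instead gets its width, the longer side, set to `size`,
    # so that no side of the output exceeds `size`.
    if h < w and largest:
        w, h = size, int(size * h / w)
    else:
        w, h = int(size * w / h), size
    return (h, w)  # (height, width), the order F.resize expects
```

For example, a 400x300 (w x h) image with `size=224` maps to `(168, 224)` when `largest=True`, but to `(224, 298)` when `largest=False`, since the height is then pinned to 224.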
def get_transforms(input_size=224, test_size=224, kind='full', crop=True, need=('train', 'val')):
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    transformations = {}
    if 'train' in need:
        if kind == 'torch':
            transformations['train'] = transforms.Compose([
                transforms.RandomResizedCrop(input_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
        elif kind == 'full':
            transformations['train'] = transforms.Compose([
                transforms.RandomResizedCrop(input_size),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(0.3, 0.3, 0.3),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
        else:
            raise ValueError('Transforms kind {} unknown'.format(kind))
    if 'val' in need:
        if crop:
            transformations['val'] = transforms.Compose([
                Resize(int((256 / 224) * test_size)),  # to maintain same ratio w.r.t. 224 images
                transforms.CenterCrop(test_size),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
        else:
            transformations['val'] = transforms.Compose([
                Resize(test_size, largest=True),  # to maintain same ratio w.r.t. 224 images
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
    return transformations

transforms_list = ['torch', 'full']
Example - self-supervised learning
A typical data augmentation composition (PyTorch) for self-supervised learning:
from torchvision import transforms
...
augmentation = [
    transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
    transforms.RandomApply([
        transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    # GaussianBlur is a project-defined transform, not part of this snippet
    transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
]
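`GaussianBlur` above is not defined by torchvision in this form; in MoCo v2-style code it is a small user-defined transform that samples a blur sigma per call. A minimal PIL-based sketch, an assumed implementation rather than the exact original:

```python
import random
from PIL import Image, ImageFilter

class GaussianBlur:
    """Blur a PIL image with a sigma drawn uniformly from [sigma[0], sigma[1]]."""
    def __init__(self, sigma=(0.1, 2.0)):
        self.sigma = sigma

    def __call__(self, img):
        s = random.uniform(self.sigma[0], self.sigma[1])
        return img.filter(ImageFilter.GaussianBlur(radius=s))
```

Because it exposes a plain `__call__`, it composes with the other transforms exactly like any built-in, e.g. `transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5)`.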