torchvision provides the transforms module for data augmentation. Albumentations is a more powerful data augmentation library. Below, the same Dataset is wrapped with each of the two pipelines.
1. A torchvision-based pipeline
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class TorchvisionDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]
        # Read an image with PIL; convert to RGB so grayscale or RGBA
        # files also work with the 3-channel Normalize below
        image = Image.open(file_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label
#
torchvision_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
#
torchvision_dataset = TorchvisionDataset(
    file_paths=['./images/image_1.jpg',
                './images/image_2.jpg',
                './images/image_3.jpg'],
    labels=[1, 2, 3],
    transform=torchvision_transform,
)
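For reference, a minimal sketch of consuming this dataset with a DataLoader (the file paths above are placeholders, and the batch size here is arbitrary):
from torch.utils.data import DataLoader

loader = DataLoader(torchvision_dataset, batch_size=2, shuffle=True)
for images, labels in loader:
    # images: FloatTensor of shape [batch, 3, 224, 224] after the pipeline above
    print(images.shape, labels)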
2. An albumentations-based pipeline
2.1. OpenCV version
import cv2
from torch.utils.data import Dataset
from albumentations import Compose, RandomCrop, Normalize, HorizontalFlip, Resize
# Note: ToTensor was removed in Albumentations 1.0; newer releases use ToTensorV2
from albumentations.pytorch import ToTensor
class AlbumentationsDataset(Dataset):
    """__init__ and __len__ functions are the same as in TorchvisionDataset"""
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]
        # Read an image with OpenCV
        image = cv2.imread(file_path)
        # By default OpenCV uses BGR color space for color images,
        # so we need to convert the image to RGB color space
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            # Albumentations transforms take and return named arguments
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label
#
albumentations_transform = Compose([
    Resize(256, 256),
    RandomCrop(224, 224),
    HorizontalFlip(),
    Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensor(),
])
#
albumentations_dataset = AlbumentationsDataset(
    file_paths=['./images/image_1.jpg',
                './images/image_2.jpg',
                './images/image_3.jpg'],
    labels=[1, 2, 3],
    transform=albumentations_transform,
)
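Since ToTensor was removed in Albumentations 1.0, here is a sketch of the same pipeline for newer releases, using ToTensorV2 (which only converts the HWC numpy array to a CHW tensor, since Normalize has already scaled the values):
from albumentations import Compose, RandomCrop, Normalize, HorizontalFlip, Resize
from albumentations.pytorch import ToTensorV2

# Equivalent pipeline for Albumentations >= 1.0
albumentations_transform_v2 = Compose([
    Resize(256, 256),
    RandomCrop(224, 224),
    HorizontalFlip(),
    Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensorV2(),  # HWC numpy array -> CHW torch tensor, no rescaling
])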
2.2. PIL version
When using PIL, the image must first be converted to a numpy array before augmentation; afterwards, the augmented numpy array can be converted back to a PIL image.
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from albumentations import Compose, RandomCrop, HorizontalFlip, Resize
class AlbumentationsPilDataset(Dataset):
    """__init__ and __len__ functions are the same as in TorchvisionDataset"""
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]
        image = Image.open(file_path)
        if self.transform:
            # Convert PIL image to numpy array
            image_np = np.array(image)
            # Apply transformations
            augmented = self.transform(image=image_np)
            # Convert numpy array back to PIL Image
            image = Image.fromarray(augmented['image'])
        return image, label
#
albumentations_pil_transform = Compose([
    Resize(256, 256),
    RandomCrop(224, 224),
    HorizontalFlip(),
])
# Note that this dataset outputs PIL images, not numpy arrays or PyTorch tensors
albumentations_pil_dataset = AlbumentationsPilDataset(
    file_paths=['./images/image_1.jpg',
                './images/image_2.jpg',
                './images/image_3.jpg'],
    labels=[1, 2, 3],
    transform=albumentations_pil_transform,
)
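Because this dataset yields PIL images, tensor conversion has to happen downstream. A minimal sketch, assuming the placeholder image files above actually exist on disk:
from torchvision import transforms

pil_to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

image, label = albumentations_pil_dataset[0]  # PIL image, 224x224 after the crop
tensor_image = pil_to_tensor(image)           # FloatTensor of shape [3, 224, 224]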
3. Equivalent transforms in torchvision and albumentations
| torchvision transform | albumentations transform | albumentations example |
| --- | --- | --- |
| Compose | Compose | Compose([Resize(256, 256), RandomCrop(224, 224)]) |
| CenterCrop | CenterCrop | CenterCrop(256, 256) |
| ColorJitter | HueSaturationValue | HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5) |
| Pad | PadIfNeeded | PadIfNeeded(min_height=512, min_width=512) |
| RandomAffine | ShiftScaleRotate | ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, p=0.5) |
| RandomCrop | RandomCrop | RandomCrop(256, 256) |
| RandomGrayscale | ToGray | ToGray(p=0.5) |
| RandomHorizontalFlip | HorizontalFlip | HorizontalFlip(p=0.5) |
| RandomRotation | Rotate | Rotate(limit=45, p=0.5) |
| RandomVerticalFlip | VerticalFlip | VerticalFlip(p=0.5) |
| Resize | Resize | Resize(256, 256) |
| Normalize | Normalize | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
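As an illustration of the table, a sketch of two pipelines intended to be roughly equivalent. Note that the parameter conventions differ: albumentations makes the application probability p explicit, while torchvision's RandomRotation always rotates (by a random angle within the limit):
import albumentations as A
from torchvision import transforms

tv_pipeline = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=45),  # always applied, angle in [-45, 45]
])

alb_pipeline = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=45, p=1.0),  # p=1.0 to match RandomRotation's always-on behavior
])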