本文汇总整理 BackgroundMattingV2 中与数据集、数据增强相关的一些辅助函数.
1. ImagesDataset
import os
import glob
from torch.utils.data import Dataset
from PIL import Image
class ImagesDataset(Dataset):
    """Dataset over the ``.jpg`` and ``.png`` files found recursively under a directory.

    Args:
        root: Directory that is searched (including sub-directories) for images.
        mode: Mode string handed to ``convert`` for every opened image.
        transforms: Optional callable applied to the converted image.
    """

    def __init__(self, root, mode='RGB', transforms=None):
        self.transforms = transforms
        self.mode = mode
        jpg_paths = glob.glob(os.path.join(root, '**', '*.jpg'), recursive=True)
        png_paths = glob.glob(os.path.join(root, '**', '*.png'), recursive=True)
        # Both extensions share one sorted list, so indices follow path order.
        self.filenames = sorted(jpg_paths + png_paths)

    def __len__(self):
        """Return the number of image files that were discovered."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Open file ``idx``, convert it to ``self.mode`` and apply the transforms."""
        path = self.filenames[idx]
        with Image.open(path) as opened:
            img = opened.convert(self.mode)
        if not self.transforms:
            return img
        return self.transforms(img)
2. VideoDataset
import cv2
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
class VideoDataset(Dataset):
    """Frame-indexed dataset backed by a ``cv2.VideoCapture``.

    The width, height, frame rate and frame count are read from the capture
    once at construction time. The object can be used as a context manager;
    leaving the ``with`` block releases the capture.

    Args:
        path: Path handed to ``cv2.VideoCapture``.
        transforms: Optional callable applied to each frame after it has been
            converted to a PIL image.
    """

    def __init__(self, path: str, transforms: any = None):
        self.cap = cv2.VideoCapture(path)
        self.transforms = transforms

        capture = self.cap
        self.width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.frame_rate = capture.get(cv2.CAP_PROP_FPS)
        self.frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        """Return the frame count reported by the capture."""
        return self.frame_count

    def __getitem__(self, idx):
        """Return frame ``idx`` (or a list of frames when ``idx`` is a slice).

        Raises:
            IndexError: If the capture fails to read the requested frame.
        """
        if isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            return [self[pos] for pos in range(start, stop, step)]

        # Seek only when the capture is not already positioned on this frame.
        if self.cap.get(cv2.CAP_PROP_POS_FRAMES) != idx:
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx)

        ok, frame = self.cap.read()
        if not ok:
            raise IndexError(f'Idx: {idx} out of length: {len(self)}')

        # Reorder channels from BGR to RGB before wrapping in a PIL image.
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(rgb)
        if self.transforms:
            img = self.transforms(img)
        return img

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.cap.release()
3. ZipDataset
from torch.utils.data import Dataset
from typing import List
class ZipDataset(Dataset):
    """Combine several datasets so that one index yields one sample from each.

    The combined length is that of the longest dataset. Shorter datasets are
    cycled: each one is indexed with ``idx % len(dataset)``.

    Args:
        datasets: Datasets to combine.
        transforms: Optional callable invoked as ``transforms(*samples)``; its
            return value is returned in place of the tuple of samples.
        assert_equal_length: When true, assert at construction time that all
            datasets have the same length.
    """

    def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False):
        self.datasets = datasets
        self.transforms = transforms

        if assert_equal_length:
            for previous, current in zip(datasets, datasets[1:]):
                assert len(current) == len(previous), 'Datasets are not equal in length.'

    def __len__(self):
        """Return the length of the longest dataset (0 when there are none)."""
        # default=0 keeps len() usable when the list of datasets is empty.
        return max((len(d) for d in self.datasets), default=0)

    def __getitem__(self, idx):
        """Return the tuple of samples at ``idx`` (or the transforms' result).

        Negative indices are not range-checked; each dataset resolves them
        through the same modulo used for wrap-around.

        Raises:
            IndexError: If ``idx`` is greater than or equal to ``len(self)``.
        """
        # The modulo below wraps every index, so without this bound no
        # IndexError is ever raised and plain iteration would never stop.
        if idx >= len(self):
            raise IndexError(f'Idx: {idx} out of length: {len(self)}')
        x = tuple(d[idx % len(d)] for d in self.datasets)
        if self.transforms:
            x = self.transforms(*x)
        return x
4. 数据增强
这些数据增强函数主要用于对成对(或成组)的图像施加完全相同的增强变换, 例如:
img1, img2, ..., imgN = transforms(img1, img2, ..., imgN)
完整实现如下:
import random
import torch
import numpy as np
import math
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image, ImageFilter
"""
Pair transforms are modified versions of the regular transforms that take in multiple images
and apply exactly the same transform to all of them. This is especially useful when a pair
of images must be transformed identically.
Example:
img1, img2, ..., imgN = transforms(img1, img2, ..., imgN)
"""
class PairCompose(T.Compose):
    """Chain pair transforms, each taking and returning a group of images."""

    def __call__(self, *x):
        """Run every transform in order and return the last transform's result."""
        images = x
        for step in self.transforms:
            # Each step receives the current group unpacked as positional args.
            images = step(*images)
        return images
#成对应用变换
class PairApply:
    """Apply one single-image transform independently to every image passed in.

    Args:
        transforms: Callable taking one image and returning the transformed image.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *x):
        """Return a list holding the transformed version of each argument."""
        return list(map(self.transforms, x))
#成对仅对特定索引的图像进行变换
class PairApplyOnlyAtIndices:
    """Apply a single-image transform only to the arguments at selected positions.

    Args:
        indices: Container of positions (tested with ``in``) to transform.
        transforms: Callable taking one image and returning the transformed image.
    """

    def __init__(self, indices, transforms):
        self.indices = indices
        self.transforms = transforms

    def __call__(self, *x):
        """Return a list where selected positions are transformed and the rest kept."""
        out = []
        for position, item in enumerate(x):
            if position in self.indices:
                item = self.transforms(item)
            out.append(item)
        return out
#成对随机仿射变换
class PairRandomAffine(T.RandomAffine):
    """Random affine transform that uses one set of sampled parameters for all images.

    Args:
        degrees, translate, scale, shear: Forwarded to ``T.RandomAffine``.
        resamples: Optional sequence with one resample mode per image; when
            falsy, the mode stored by the base class is used for every image.
        fillcolor: Forwarded to ``T.RandomAffine``.

    NOTE(review): relies on a torchvision version whose ``RandomAffine``
    accepts ``resample``/``fillcolor`` positionally and exposes them as
    ``self.resample``/``self.fillcolor`` -- confirm against the pinned version.
    """

    def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0):
        super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor)
        self.resamples = resamples

    def __call__(self, *x):
        """Sample parameters once (from the first image's size) and apply them to all images."""
        if len(x) == 0:
            return []
        params = self.get_params(self.degrees, self.translate, self.scale, self.shear, x[0].size)
        modes = self.resamples if self.resamples else [self.resample] * len(x)
        out = []
        for position, image in enumerate(x):
            out.append(F.affine(image, *params, modes[position], self.fillcolor))
        return out
#成对随机尺寸调整和裁剪
class PairRandomResizedCrop(T.RandomResizedCrop):
    """Random resized crop that uses one sampled crop box for all images.

    Args:
        size, scale, ratio: Forwarded to ``T.RandomResizedCrop``.
        interpolations: Optional sequence with one interpolation mode per
            image; when falsy, the mode stored by the base class is used for
            every image.
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolations=None):
        super().__init__(size, scale, ratio, Image.BILINEAR)
        self.interpolations = interpolations

    def __call__(self, *x):
        """Sample one crop box from the first image and apply it to every image."""
        if not x:
            return []
        top, left, height, width = self.get_params(x[0], self.scale, self.ratio)
        interpolations = self.interpolations or [self.interpolation] * len(x)
        # The loop index needs its own name: reusing the name of the crop
        # offset inside the comprehension would pass the image's position in
        # the group as the crop's top coordinate instead of the sampled one.
        return [
            F.resized_crop(xi, top, left, height, width, self.size, interpolations[idx])
            for idx, xi in enumerate(x)
        ]
#成对随机水平翻转
class PairRandomHorizontalFlip(T.RandomHorizontalFlip):
    """Horizontally flip all images together, with probability ``self.p``."""

    def __call__(self, *x):
        """Return a list of flipped images, or the inputs unchanged when no flip is drawn."""
        # One random draw decides for the whole group.
        do_flip = torch.rand(1) < self.p
        return [F.hflip(image) for image in x] if do_flip else x
#随机方框模糊
class RandomBoxBlur:
    """Apply a box blur of random radius to an image with a given probability.

    Args:
        prob: Probability threshold compared against ``torch.rand(1)``.
        max_radius: Largest radius that may be chosen (inclusive; 0 is possible).
    """

    def __init__(self, prob, max_radius):
        self.prob = prob
        self.max_radius = max_radius

    def __call__(self, img):
        """Return ``img`` blurred, or unchanged when the random draw does not trigger."""
        if not (torch.rand(1) < self.prob):
            return img
        blur = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
        return img.filter(blur)
#成对随机方框模糊
class PairRandomBoxBlur(RandomBoxBlur):
    """Apply one randomly sized box blur to all images together, with probability ``prob``."""

    def __call__(self, *x):
        """Return a list of blurred images, or the inputs unchanged when not triggered."""
        if not (torch.rand(1) < self.prob):
            return x
        # A single filter instance is shared so every image gets the same radius.
        blur = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
        return [image.filter(blur) for image in x]
#随机锐化
class RandomSharpen:
    """Apply PIL's ``ImageFilter.SHARPEN`` to an image with a given probability.

    Args:
        prob: Probability threshold compared against ``torch.rand(1)``.
    """

    def __init__(self, prob):
        self.prob = prob
        self.filter = ImageFilter.SHARPEN

    def __call__(self, img):
        """Return ``img`` sharpened, or unchanged when the random draw does not trigger."""
        triggered = torch.rand(1) < self.prob
        return img.filter(self.filter) if triggered else img
#成对随机锐化
class PairRandomSharpen(RandomSharpen):
    """Sharpen all images together, with probability ``prob``."""

    def __call__(self, *x):
        """Return a list of sharpened images, or the inputs unchanged when not triggered."""
        # One random draw decides for the whole group.
        triggered = torch.rand(1) < self.prob
        return [image.filter(self.filter) for image in x] if triggered else x
#成对随机仿射变换和尺寸调整
class PairRandomAffineAndResize:
    """Random affine transform plus center crop to a target size, shared by all images.

    The scale and translate ranges are multiplied by a factor derived from the
    first image's size and the target size, and images smaller than the target
    are padded first.

    Args:
        size: Target size, indexed as ``(height, width)``.
        degrees, shear: Passed unchanged to ``T.RandomAffine.get_params``.
        translate: Pair of translate fractions, rescaled before sampling.
        scale: Pair ``(min, max)`` of scale factors, rescaled before sampling.
        ratio: Stored on the instance; not read by ``__call__``.
        resample: Resample mode passed to ``F.affine``.
        fillcolor: Fill value passed to ``F.affine``.
    """

    def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0):
        self.size = size
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.ratio = ratio
        self.resample = resample
        self.fillcolor = fillcolor

    def __call__(self, *x):
        """Sample one affine transform and apply pad -> affine -> center crop to each image."""
        if len(x) == 0:
            return []

        target_h, target_w = self.size[0], self.size[1]
        w, h = x[0].size

        # Factor by which the first image must grow to cover the target size.
        scale_factor = max(target_w / w, target_h / h)

        # Per-side padding so the padded image is at least as large as the target.
        pad_h = int(math.ceil((max(h, target_h) - h) / 2))
        pad_w = int(math.ceil((max(w, target_w) - w) / 2))
        needs_padding = pad_h > 0 or pad_w > 0

        scale = self.scale[0] * scale_factor, self.scale[1] * scale_factor
        translate = self.translate[0] * scale_factor, self.translate[1] * scale_factor
        # Parameters are sampled once, from the unpadded size of the first image.
        affine_params = T.RandomAffine.get_params(self.degrees, translate, scale, self.shear, (w, h))

        out = []
        for image in x:
            if needs_padding:
                image = F.pad(image, (pad_w, pad_h))
            image = F.affine(image, *affine_params, self.resample, self.fillcolor)
            out.append(F.center_crop(image, self.size))
        return out
#随机仿射变换和尺寸调整
class RandomAffineAndResize(PairRandomAffineAndResize):
    """Single-image variant of ``PairRandomAffineAndResize``."""

    def __call__(self, img):
        """Transform ``img`` via the pair implementation and return the single result."""
        results = super().__call__(img)
        return results[0]