本文汇总整理 BackgroundMattingV2 中与数据集加载、数据增强相关的辅助函数实现。

1. ImagesDataset

import os
import glob
from torch.utils.data import Dataset
from PIL import Image

class ImagesDataset(Dataset):
    """Dataset over every .jpg/.png image found recursively under ``root``.

    Each item is the image loaded with PIL, converted to ``mode``
    (default ``'RGB'``), and optionally passed through ``transforms``.
    """

    def __init__(self, root, mode='RGB', transforms=None):
        self.transforms = transforms
        self.mode = mode
        # Gather jpg/png files recursively, in a stable sorted order.
        found = []
        for pattern in ('*.jpg', '*.png'):
            found.extend(glob.glob(os.path.join(root, '**', pattern), recursive=True))
        self.filenames = sorted(found)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        # Open inside a context manager so the file handle is closed promptly;
        # convert() produces a new image that outlives the closed file.
        with Image.open(self.filenames[idx]) as fp:
            image = fp.convert(self.mode)

        if self.transforms:
            image = self.transforms(image)

        return image

2. VideoDataset

import cv2
import numpy as np
from torch.utils.data import Dataset
from PIL import Image

class VideoDataset(Dataset):
    def __init__(self, path: str, transforms: any = None):
        self.cap = cv2.VideoCapture(path)
        self.transforms = transforms
        
        self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.frame_rate = self.cap.get(cv2.CAP_PROP_FPS)
        self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    def __len__(self):
        return self.frame_count
    
    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return [self[i] for i in range(*idx.indices(len(self)))]
        
        if self.cap.get(cv2.CAP_PROP_POS_FRAMES) != idx:
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, img = self.cap.read()
        if not ret:
            raise IndexError(f'Idx: {idx} out of length: {len(self)}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img)
        if self.transforms:
            img = self.transforms(img)
        return img
    
    def __enter__(self):
        return self
    
    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.cap.release()

3. ZipDataset

from torch.utils.data import Dataset
from typing import List

class ZipDataset(Dataset):
    """Zip several datasets: item ``i`` is the tuple of each member's item.

    Shorter members wrap around modulo their own length, so the zipped
    length is that of the longest member. ``transforms``, if given, is
    called with the tuple unpacked as positional arguments.
    """

    def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False):
        self.datasets = datasets
        self.transforms = transforms

        if assert_equal_length:
            # Check each adjacent pair; equivalent to comparing all lengths.
            for prev, cur in zip(datasets, datasets[1:]):
                assert len(cur) == len(prev), 'Datasets are not equal in length.'

    def __len__(self):
        # The longest member determines the zipped length.
        return max(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        item = tuple(dataset[idx % len(dataset)] for dataset in self.datasets)
        if self.transforms:
            item = self.transforms(*item)
        return item

4. 数据增强

augmentation.py

这些数据增强变换支持对成对(多张)图像同步应用相同的变换, 例如:

img1, img2, ..., imgN = transforms(img1, img2, ..., imgN)

完整实现如下:

import random
import torch
import numpy as np
import math
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image, ImageFilter

"""
Pair transforms are MODs of regular transforms so that it takes in multiple images
and apply exact transforms on all images. This is especially useful when we want the
transforms on a pair of images.
Example:
    img1, img2, ..., imgN = transforms(img1, img2, ..., imgN)
"""

class PairCompose(T.Compose):
    """Compose pair transforms: every transform receives and returns the
    whole group of N images, keeping them in sync."""

    def __call__(self, *x):
        # Thread the full tuple of images through each transform in order.
        for pair_transform in self.transforms:
            x = pair_transform(*x)
        return x
    
#成对应用变换
class PairApply:
    """Apply a single transform independently to every image in the group."""

    def __init__(self, transforms):
        # One callable, applied per image.
        self.transforms = transforms

    def __call__(self, *images):
        results = []
        for image in images:
            results.append(self.transforms(image))
        return results

#成对仅对特定索引的图像进行变换
class PairApplyOnlyAtIndices:
    """Apply the transform only to inputs whose position is listed in
    ``indices``; all other inputs pass through unchanged."""

    def __init__(self, indices, transforms):
        self.indices = indices
        self.transforms = transforms

    def __call__(self, *images):
        out = []
        for pos, image in enumerate(images):
            out.append(self.transforms(image) if pos in self.indices else image)
        return out

#成对随机仿射变换
class PairRandomAffine(T.RandomAffine):
    """Random affine transform that samples ONE parameter set and applies it
    to every image, optionally with a per-image resampling filter."""

    def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0):
        # The base class stores a single resample mode (NEAREST); per-image
        # overrides, if any, live in self.resamples.
        super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor)
        self.resamples = resamples

    def __call__(self, *images):
        if not images:
            return []
        # Sample the shared parameters from the first image's size.
        params = self.get_params(self.degrees, self.translate, self.scale, self.shear, images[0].size)
        filters = self.resamples if self.resamples else [self.resample] * len(images)
        return [F.affine(img, *params, filters[k], self.fillcolor) for k, img in enumerate(images)]

#成对随机尺寸调整和裁剪
class PairRandomResizedCrop(T.RandomResizedCrop):
    """Random resized crop with identical crop parameters for all images,
    optionally with a per-image interpolation filter.
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolations=None):
        super().__init__(size, scale, ratio, Image.BILINEAR)
        # Optional list of per-image interpolation modes.
        self.interpolations = interpolations

    def __call__(self, *x):
        if not len(x):
            return []
        # Sample one crop box (top, left, height, width) from the first image.
        top, left, height, width = self.get_params(x[0], self.scale, self.ratio)
        interpolations = self.interpolations or [self.interpolation] * len(x)
        # BUG FIX: the comprehension variable previously shadowed the sampled
        # crop-top coordinate `i`, so every image was cropped at top=image
        # index instead of the sampled position.
        return [F.resized_crop(xi, top, left, height, width, self.size, interpolations[k])
                for k, xi in enumerate(x)]
    
#成对随机水平翻转
class PairRandomHorizontalFlip(T.RandomHorizontalFlip):
    """Flip all images together with probability ``p`` (all or none)."""

    def __call__(self, *x):
        # A single random draw decides the flip for the whole group, so
        # paired inputs (e.g. image + alpha matte) stay aligned.
        if torch.rand(1) < self.p:
            return [F.hflip(xi) for xi in x]
        return x

#随机方框模糊
class RandomBoxBlur:
    """With probability ``prob``, box-blur the image with a random radius
    drawn uniformly from [0, max_radius]."""

    def __init__(self, prob, max_radius):
        self.prob = prob
        self.max_radius = max_radius

    def __call__(self, img):
        if torch.rand(1) < self.prob:
            radius = random.choice(range(self.max_radius + 1))
            img = img.filter(ImageFilter.BoxBlur(radius))
        return img

#成对随机方框模糊
class PairRandomBoxBlur(RandomBoxBlur):
    """Box-blur variant that blurs every image with the same random radius,
    keeping paired inputs consistent."""

    def __call__(self, *x):
        # Single draw: either all images get one shared filter, or none do.
        if torch.rand(1) < self.prob:
            shared = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
            x = [img.filter(shared) for img in x]
        return x

#随机锐化
class RandomSharpen:
    """With probability ``prob``, apply PIL's SHARPEN filter to the image."""

    def __init__(self, prob):
        self.prob = prob
        # Kept as an attribute so the pair subclass can reuse the same filter.
        self.filter = ImageFilter.SHARPEN

    def __call__(self, img):
        return img.filter(self.filter) if torch.rand(1) < self.prob else img
    
#成对随机锐化    
class PairRandomSharpen(RandomSharpen):
    """Sharpen variant that applies the filter to every image together."""

    def __call__(self, *x):
        # All-or-none sharpening keeps paired inputs consistent.
        if torch.rand(1) < self.prob:
            return [img.filter(self.filter) for img in x]
        return x
    
#成对随机仿射变换和尺寸调整
class PairRandomAffineAndResize:
    """Random affine plus center-crop to a fixed output size, with one shared
    parameter set for all input images.

    Inputs are zero-padded up to at least the target size, a single random
    affine (whose scale and translation ranges are pre-multiplied by the
    zoom needed to cover the target) is applied, and a center crop of
    ``size`` is taken. ``size`` is (height, width); PIL sizes are (w, h).
    """

    def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0):
        self.size = size
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.ratio = ratio
        self.resample = resample
        self.fillcolor = fillcolor

    def __call__(self, *images):
        if not images:
            return []

        w, h = images[0].size
        # Minimum zoom so the affine output can always cover the target crop.
        cover = max(self.size[1] / w, self.size[0] / h)

        # Symmetric padding up to the target size when the input is smaller;
        # ceil so odd differences still pad enough on both sides.
        pad_w = int(math.ceil((max(w, self.size[1]) - w) / 2))
        pad_h = int(math.ceil((max(h, self.size[0]) - h) / 2))

        # Scale both the random-scale range and the translation range by the
        # cover factor, then sample ONE parameter set shared by every image.
        scale_range = (self.scale[0] * cover, self.scale[1] * cover)
        translate_range = (self.translate[0] * cover, self.translate[1] * cover)
        params = T.RandomAffine.get_params(self.degrees, translate_range, scale_range, self.shear, (w, h))

        def _apply(img):
            # Pad -> shared affine -> fixed-size center crop.
            if pad_h > 0 or pad_w > 0:
                img = F.pad(img, (pad_w, pad_h))
            img = F.affine(img, *params, self.resample, self.fillcolor)
            return F.center_crop(img, self.size)

        return [_apply(img) for img in images]

#随机仿射变换和尺寸调整
class RandomAffineAndResize(PairRandomAffineAndResize):
    """Single-image convenience wrapper around the pair version."""

    def __call__(self, img):
        # Delegate to the pair implementation and unwrap its one-element list.
        (result,) = super().__call__(img)
        return result
Last modification:December 29th, 2020 at 11:07 am