Paper Reading - Deep Image Matting

[Deep Image Matting reproduction summary (reposted)](https://www.aiuai.cn/aifarm330.html)

GitHub - huochaitiantang/pytorch-deep-image-matting

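The overall loss used in Deep Image Matting (CVPR 2017) is:

$\mathcal{L}_{overall} = w_l \cdot \mathcal{L}_{\alpha} + (1 - w_l) \cdot \mathcal{L}_c$

where $w_l$ balances the two terms (the paper sets $w_l = 0.5$). The loss therefore has two parts: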

1. alpha-prediction loss

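The alpha-prediction loss is the absolute difference between the ground-truth alpha value and the predicted alpha value at each pixel.

Because the absolute value is not differentiable at zero, a smooth approximation is used instead: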

$\mathcal{L}_{\alpha}^{i} = \sqrt{(\alpha _p^i - \alpha _g^i)^2 + \epsilon^2 }, \alpha _p^i, \alpha _g^i \in [0, 1]$

where $\alpha _p^i$ is the output of the network's prediction layer at pixel $i$, with values in [0, 1];

$\alpha _g^i$ is the ground-truth alpha value at pixel $i$, also in [0, 1];

$\epsilon = 10^{-6}$ is a small constant.
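To see why $\sqrt{x^2 + \epsilon^2}$ is a good stand-in for $|x|$: it always satisfies $|x| \le \sqrt{x^2 + \epsilon^2} \le |x| + \epsilon$, so it never differs from $|x|$ by more than $\epsilon$, yet it is smooth at $x = 0$. A minimal pure-Python check follows; the helper name `smooth_abs` is hypothetical and not part of the repository.

```python
import math

EPS = 1e-6  # epsilon from the text

def smooth_abs(x, eps=EPS):
    """Differentiable approximation of |x|: sqrt(x^2 + eps^2)."""
    return math.sqrt(x * x + eps * eps)

if __name__ == "__main__":
    # The gap smooth_abs(x) - |x| is largest (about eps) at x = 0
    # and shrinks toward 0 as |x| grows.
    for x in [0.0, 1e-7, 1e-6, 1e-3, 0.5]:
        gap = smooth_abs(x) - abs(x)
        print(f"x={x:.0e}  gap={gap:.3e}")
```

Since alpha differences lie in [-1, 1], an error of at most $10^{-6}$ per pixel is negligible compared with typical prediction errors.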

The derivative of $\mathcal{L}_{\alpha}^i$ with respect to the prediction $\alpha _p^i$ is:

$\frac{\partial \mathcal{L}_{\alpha}^i }{\partial {\alpha _p^i}} = \frac{\alpha _p^i - \alpha _g^i}{ \sqrt{(\alpha _p^i - \alpha _g^i)^2 + \epsilon ^2}}$
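The derivative can be sanity-checked numerically. The sketch below (pure Python; the function names `alpha_loss`, `alpha_loss_grad` and `numeric_grad` are hypothetical) compares the analytic formula with a central finite difference. Note that the gradient stays in $(-1, 1)$ and is exactly 0 when the prediction equals the ground truth, whereas $|x|$ has no derivative there.

```python
import math

EPS = 1e-6

def alpha_loss(pred, gt, eps=EPS):
    """Per-pixel alpha-prediction loss sqrt((pred - gt)^2 + eps^2)."""
    return math.sqrt((pred - gt) ** 2 + eps ** 2)

def alpha_loss_grad(pred, gt, eps=EPS):
    """Analytic derivative with respect to pred, as in the formula above."""
    d = pred - gt
    return d / math.sqrt(d * d + eps * eps)

def numeric_grad(pred, gt, h=1e-7):
    """Central finite difference of alpha_loss with respect to pred."""
    return (alpha_loss(pred + h, gt) - alpha_loss(pred - h, gt)) / (2 * h)

if __name__ == "__main__":
    for pred, gt in [(0.8, 0.3), (0.3, 0.8), (0.5, 0.5)]:
        print(pred, gt, alpha_loss_grad(pred, gt), numeric_grad(pred, gt))
```

Away from zero the gradient is essentially $\pm 1$, i.e. the same as for a plain L1 loss.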

PyTorch implementation, `gen_alpha_pred_loss` (the loss is averaged over the unknown region of the trimap only):

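```python
import torch

def gen_alpha_pred_loss(alpha, pred_alpha, trimap):
    # Weight mask: 1 inside the unknown region (trimap == 128), 0 elsewhere.
    # Built from trimap itself, so it lives on the same device (CPU or GPU).
    wi = (trimap == 128).float()
    unknown_region_size = wi.sum().clamp(min=1.)  # avoid division by zero

    # Ground-truth alpha is stored in [0, 255]; rescale it to [0, 1]
    alpha = alpha / 255.
    # sqrt(diff^2 + eps^2) with eps = 1e-6, i.e. eps^2 = 1e-12
    alpha_loss = torch.sqrt((pred_alpha - alpha) ** 2 + 1e-12)
    # Average over unknown-region pixels only
    alpha_loss = (alpha_loss * wi).sum() / unknown_region_size

    return alpha_loss
```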
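The PyTorch function averages the per-pixel loss over the unknown region only (trimap value 128); pixels marked 0 or 255 contribute nothing, however wrong the prediction there is. A minimal pure-Python sketch of the same masking on flat lists, assuming the ground-truth alpha is given in [0, 255] as in the code above; the name `masked_alpha_loss` is hypothetical.

```python
import math

def masked_alpha_loss(alpha_gt, pred_alpha, trimap, eps=1e-6):
    """Mean of sqrt((pred - gt)^2 + eps^2) over pixels where trimap == 128.

    alpha_gt is in [0, 255]; pred_alpha is in [0, 1]; all inputs are flat lists.
    """
    total, count = 0.0, 0
    for gt, pred, t in zip(alpha_gt, pred_alpha, trimap):
        if t != 128:  # known foreground/background: ignored
            continue
        total += math.sqrt((pred - gt / 255.0) ** 2 + eps ** 2)
        count += 1
    return total / max(count, 1)  # guard against an empty unknown region

if __name__ == "__main__":
    alpha_gt = [0, 255, 51, 204]
    pred = [0.9, 0.1, 0.2, 0.8]  # first two are badly wrong but lie in known regions
    trimap = [0, 255, 128, 128]
    print(masked_alpha_loss(alpha_gt, pred, trimap))
```

Here the two unknown pixels are predicted exactly, so the result is about $\epsilon$ even though the known-region predictions are far off.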

2. compositional loss

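The compositional loss is the absolute difference between the ground-truth RGB colors and the predicted RGB colors, where the latter are composited from the ground-truth foreground, the ground-truth background, and the predicted alpha matte.

Similarly, a smooth approximation is used: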

$\mathcal{L}_{c}^{i} = \sqrt{(c_p^i - c_g^i)^2 + \epsilon^2 }$

where $c$ denotes an RGB channel value,

$p$ denotes the image composited with the predicted alpha,

$g$ denotes the image composited with the ground-truth alpha.
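A worked single-pixel example of the compositional loss, using the compositing equation $I = \alpha F + (1 - \alpha) B$ per RGB channel with colors scaled to [0, 1]. The colors and alpha values are made up for illustration, and the names `composite` and `comp_loss` are hypothetical.

```python
import math

EPS = 1e-6

def composite(alpha, fg, bg):
    """Composite one RGB pixel: alpha * F + (1 - alpha) * B, per channel."""
    return [alpha * f + (1.0 - alpha) * b for f, b in zip(fg, bg)]

def comp_loss(pred_alpha, gt_alpha, fg, bg, eps=EPS):
    """Mean over RGB of sqrt((c_p - c_g)^2 + eps^2); colors in [0, 1]."""
    c_p = composite(pred_alpha, fg, bg)  # image composited with predicted alpha
    c_g = composite(gt_alpha, fg, bg)    # image composited with ground-truth alpha
    return sum(math.sqrt((p - g) ** 2 + eps ** 2) for p, g in zip(c_p, c_g)) / 3

if __name__ == "__main__":
    fg = [1.0, 0.0, 0.0]  # red foreground
    bg = [0.0, 0.0, 1.0]  # blue background
    print(comp_loss(0.7, 0.5, fg, bg))  # alpha off by 0.2
    print(comp_loss(0.7, 0.5, fg, fg))  # F == B: compositing error vanishes
```

Because $c_p - c_g = (\alpha_p - \alpha_g)(F - B)$, the compositional loss carries little signal where foreground and background colors are similar, which is why keeping the alpha-prediction term alongside it is useful.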

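The compositional loss constrains the network to respect the compositing operation, which leads to more accurate alpha predictions.

PyTorch implementation, computing both losses over the unknown region: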

```python
import torch

def gen_loss(img, alpha, fg, bg, trimap, pred_mattes):
    # Unknown-region mask (trimap == 128); shapes assumed N x 1 x H x W
    wi = (trimap == 128).float()
    # Same mask repeated over the 3 RGB channels: N x 3 x H x W
    t3_wi = torch.cat((wi, wi, wi), 1)
    unknown_region_size = wi.sum().clamp(min=1.)  # avoid division by zero

    assert wi.shape == pred_mattes.shape
    assert t3_wi.shape == img.shape

    # alpha-prediction loss (ground-truth alpha rescaled from [0, 255] to [0, 1])
    alpha = alpha / 255.
    alpha_loss = torch.sqrt((pred_mattes - alpha) ** 2 + 1e-12)
    alpha_loss = (alpha_loss * wi).sum() / unknown_region_size

    # compositional loss: re-composite with the predicted alpha, compare to the input image
    pred_mattes_3 = torch.cat((pred_mattes, pred_mattes, pred_mattes), 1)
    comp = pred_mattes_3 * fg + (1. - pred_mattes_3) * bg
    # fg, bg and img are in [0, 255]; dividing by 255 brings the loss to [0, 1] scale
    comp_loss = torch.sqrt((comp - img) ** 2 + 1e-12) / 255.
    # average over unknown pixels and over the 3 channels
    comp_loss = (comp_loss * t3_wi).sum() / unknown_region_size / 3.

    return alpha_loss, comp_loss

# total loss; wl_weight is w_l in the overall loss (0.5 in the paper)
wl_weight = 0.5
alpha_loss, comp_loss = gen_loss(img, alpha, fg, bg, trimap, pred_mattes)
loss = alpha_loss * wl_weight + comp_loss * (1. - wl_weight)
```

3. Generating a Trimap from a Mask

gen_trimap.py

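The approach relies on morphological operations (dilation and erosion). The script below uses dilation only: pixels with alpha = 255 become known foreground (255), pixels with no nonzero alpha anywhere in the kernel window become known background (0), and everything else is marked unknown (128).

Python - OpenCV image morphology (dilation and erosion)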

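```python
import os
import random

import cv2
import numpy as np

def main():
    alpha_dir = './alpha'
    sav_dir = './trimap'
    os.makedirs(sav_dir, exist_ok=True)  # make sure the output directory exists
    img_ids = os.listdir(alpha_dir)
    print("Images count: {}".format(len(img_ids)))
    for img_id in img_ids:
        # Read the alpha matte as a single-channel grayscale image
        alpha = cv2.imread(os.path.join(alpha_dir, img_id), cv2.IMREAD_GRAYSCALE)
        if alpha is None:  # skip files OpenCV cannot decode
            continue

        # k_size = random.choice(range(20, 40))  # optional: random width of the unknown band
        k_size = 15
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_size, k_size))
        dilated = cv2.dilate(alpha, kernel)
        # eroded = cv2.erode(alpha, kernel)  # optional: also shrink the known foreground

        # Start with everything unknown (128); uint8 so cv2.imwrite saves it directly
        trimap = np.full(alpha.shape, 128, dtype=np.uint8)
        trimap[alpha >= 255] = 255  # fully opaque pixels -> known foreground
        trimap[dilated <= 0] = 0    # no foreground within the kernel window -> known background

        save_name = os.path.join(sav_dir, img_id)
        print("Write to {}".format(save_name))
        cv2.imwrite(save_name, trimap)

if __name__ == "__main__":
    main()
```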
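The script marks only fully opaque pixels as foreground and uses dilation alone to grow the unknown band outward. A common variant, hinted at by the commented-out `cv2.erode` line, also erodes the mask so the unknown band extends inward as well. Below is a minimal pure-Python sketch of that variant on a tiny mask, assuming a square (2r+1) x (2r+1) kernel and ignoring pixels outside the image; `dilate`, `erode` and `make_trimap` are hypothetical helpers, not OpenCV functions.

```python
def _window(img, y, x, r):
    """Values in the (2r+1) x (2r+1) window centred at (y, x), clipped to the image."""
    h, w = len(img), len(img[0])
    return [img[j][i]
            for j in range(max(0, y - r), min(h, y + r + 1))
            for i in range(max(0, x - r), min(w, x + r + 1))]

def dilate(img, r):
    """Grey-scale dilation: each pixel becomes the max over its window."""
    return [[max(_window(img, y, x, r)) for x in range(len(img[0]))] for y in range(len(img))]

def erode(img, r):
    """Grey-scale erosion: each pixel becomes the min over its window."""
    return [[min(_window(img, y, x, r)) for x in range(len(img[0]))] for y in range(len(img))]

def make_trimap(alpha, r=1):
    """0 = background, 255 = foreground, 128 = unknown band around the edge."""
    dilated, eroded = dilate(alpha, r), erode(alpha, r)
    trimap = []
    for y in range(len(alpha)):
        row = []
        for x in range(len(alpha[0])):
            if eroded[y][x] >= 255:
                row.append(255)  # still fully opaque after erosion
            elif dilated[y][x] <= 0:
                row.append(0)    # no foreground anywhere in the window
            else:
                row.append(128)  # near the boundary
        trimap.append(row)
    return trimap

if __name__ == "__main__":
    alpha = [[0] * 9 for _ in range(9)]
    for y in range(2, 7):
        for x in range(2, 7):
            alpha[y][x] = 255  # 5x5 opaque square in a 9x9 image
    for row in make_trimap(alpha, r=1):
        print(row)
```

With `r=1`, the 5x5 square keeps a 3x3 core of 255, gets a 128 band one pixel wide on each side of its edge, and the outermost ring stays 0. Larger kernels widen the band, which is what `k_size` controls in the script.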
Last modification: October 23rd, 2018 at 05:19 pm