[Summary of reproducing Deep Image Matting (repost)](https://www.aiuai.cn/aifarm330.html)
Deep Image Matting (CVPR 2017) uses the following loss:
$\mathcal{L}_{overall} = w_l \cdot \mathcal{L}_{\alpha} + (1 - w_l) \cdot \mathcal{L}_c$
It has two parts:
1. alpha-prediction loss
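The alpha-prediction loss is the absolute difference between the ground-truth alpha value and the predicted alpha value at each pixel.
Because the absolute value is not differentiable at zero, a smooth approximation is used instead: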
$\mathcal{L}_{\alpha}^{i} = \sqrt{(\alpha _p^i - \alpha _g^i)^2 + \epsilon^2 }, \alpha _p^i, \alpha _g^i \in [0, 1]$
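where $\alpha _p^i$ is the output of the network's prediction layer at pixel i, with values in [0, 1];
$\alpha _g^i$ is the ground-truth alpha value at pixel i, with values in [0, 1];
$\epsilon = 10^{-6}$ is a very small constant.
The derivative of $\mathcal{L} _{\alpha}^i$ is then: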
$\frac{\partial \mathcal{L}_{\alpha}^i }{\partial {\alpha _p^i}} = \frac{\alpha _p^i - \alpha _g^i}{ \sqrt{(\alpha _p^i - \alpha _g^i)^2 + \epsilon ^2}}$
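The derivative above can be sanity-checked numerically. The following is a minimal pure-Python sketch, not part of the original post: the helper names `smooth_abs`, `smooth_abs_grad`, and `numeric_grad` are hypothetical, and `d` stands for the difference $\alpha _p^i - \alpha _g^i$.

```python
import math

EPS = 1e-6  # epsilon from the loss definition


def smooth_abs(d, eps=EPS):
    """Smooth approximation of |d|: sqrt(d^2 + eps^2)."""
    return math.sqrt(d * d + eps * eps)


def smooth_abs_grad(d, eps=EPS):
    """Closed-form derivative: d / sqrt(d^2 + eps^2)."""
    return d / math.sqrt(d * d + eps * eps)


def numeric_grad(d, h=1e-7):
    """Central finite difference of smooth_abs, used only as a cross-check."""
    return (smooth_abs(d + h) - smooth_abs(d - h)) / (2 * h)


for d in (-0.3, 0.0, 0.3):
    print(d, smooth_abs_grad(d), numeric_grad(d))
```

Away from zero the gradient is close to the sign of the difference, as it would be for the plain absolute value. At zero it is exactly 0 rather than undefined, which is the point of the approximation.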
A PyTorch implementation of the alpha-prediction loss, gen_alpha_pred_loss:
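import torch

def gen_alpha_pred_loss(alpha, pred_alpha, trimap):
    # Weight mask: 1 inside the unknown region (trimap == 128), 0 elsewhere.
    # It follows the prediction's device and dtype instead of a hard-coded .cuda().
    t_wi = (trimap == 128).to(device=pred_alpha.device, dtype=pred_alpha.dtype)
    unknown_region_size = t_wi.sum()

    # alpha diff
    alpha = alpha / 255.  # scale ground truth to [0, 1]
    alpha_loss = torch.sqrt((pred_alpha - alpha) ** 2 + 1e-12)  # 1e-12 = epsilon^2
    alpha_loss = (alpha_loss * t_wi).sum() / unknown_region_size
    return alpha_loss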
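To see what the masked average computes, here is a small worked example in plain Python. It is only an illustrative sketch: `masked_alpha_loss` is a hypothetical helper, the 2x2 values are made up, and the ground-truth alpha is assumed to be already scaled to [0, 1], so there is no division by 255.

```python
import math


def masked_alpha_loss(pred_alpha, gt_alpha, trimap, eps=1e-6):
    """Mean of sqrt((a_p - a_g)^2 + eps^2) over pixels where trimap == 128."""
    total, count = 0.0, 0
    for p_row, g_row, t_row in zip(pred_alpha, gt_alpha, trimap):
        for a_p, a_g, t in zip(p_row, g_row, t_row):
            if t == 128:  # only unknown-region pixels contribute
                total += math.sqrt((a_p - a_g) ** 2 + eps ** 2)
                count += 1
    return total / max(count, 1)  # guard against an empty unknown region


pred = [[0.9, 0.4],
        [0.2, 0.0]]
gt = [[1.0, 0.5],
      [0.5, 0.0]]
trimap = [[255, 128],
          [128, 0]]

loss = masked_alpha_loss(pred, gt, trimap)
print(loss)  # approximately 0.2: the mean of the two unknown-pixel errors 0.1 and 0.3
```

The error of 0.1 at the top-left pixel is ignored because the trimap marks that pixel as known foreground.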
2. compositional loss
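The compositional loss is the absolute difference between the ground-truth RGB colors and the predicted RGB colors, where the predicted image is composited from the ground-truth foreground, the ground-truth background, and the predicted alpha matte.
As before, a smooth approximation is used: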
$\mathcal{L}_{c}^{i} = \sqrt{(c_p^i - c_g^i)^2 + \epsilon^2 }$
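where $c$ denotes the RGB channel values,
$p$ denotes the image composited with the predicted alpha, and
$g$ denotes the image composited with the ground-truth alpha.
The compositional loss constrains the network so that it produces more accurate alpha predictions.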
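A short derivation makes the relation to the alpha-prediction loss explicit. It assumes that the ground-truth image is itself composited from the ground-truth foreground and background. The symbols $F^i$ and $B^i$ for the foreground and background colors at pixel i are introduced here for illustration and do not appear in the original post.

$c_p^i = \alpha _p^i F^i + (1 - \alpha _p^i) B^i, \quad c_g^i = \alpha _g^i F^i + (1 - \alpha _g^i) B^i$

$c_p^i - c_g^i = (\alpha _p^i - \alpha _g^i)(F^i - B^i)$

Under that assumption, the compositional error at a pixel is the alpha error scaled by the color contrast between foreground and background. Alpha errors therefore cost more where the foreground and background colors differ strongly, and cost almost nothing where they are nearly identical.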
A PyTorch implementation that returns both losses:
import torch

def gen_loss(img, alpha, fg, bg, trimap, pred_mattes):
    # Weight mask: 1 inside the unknown region (trimap == 128), 0 elsewhere.
    # It follows the prediction's device and dtype instead of a hard-coded .cuda().
    t_wi = (trimap == 128).to(device=pred_mattes.device, dtype=pred_mattes.dtype)
    t3_wi = torch.cat((t_wi, t_wi, t_wi), 1)  # repeat the mask over the 3 RGB channels
    unknown_region_size = t_wi.sum()

    assert t_wi.shape == pred_mattes.shape
    assert t3_wi.shape == img.shape

    # alpha diff
    alpha = alpha / 255.  # scale ground truth to [0, 1]
    alpha_loss = torch.sqrt((pred_mattes - alpha) ** 2 + 1e-12)
    alpha_loss = (alpha_loss * t_wi).sum() / unknown_region_size

    # composite rgb loss
    pred_mattes_3 = torch.cat((pred_mattes, pred_mattes, pred_mattes), 1)
    comp = pred_mattes_3 * fg + (1. - pred_mattes_3) * bg
    # The division by 255 assumes img, fg and bg are in [0, 255].
    comp_loss = torch.sqrt((comp - img) ** 2 + 1e-12) / 255.
    comp_loss = (comp_loss * t3_wi).sum() / unknown_region_size / 3.
    # print("Loss: AlphaLoss:{} CompLoss:{}".format(alpha_loss, comp_loss))
    return alpha_loss, comp_loss

# total loss; wl_weight corresponds to w_l in the overall loss
alpha_loss, comp_loss = gen_loss(img, alpha, fg, bg, trimap, pred_mattes)
loss = alpha_loss * wl_weight + comp_loss * (1. - wl_weight)
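3. Generating a Trimap from a Mask
The approach is based on morphological dilation and erosion of the mask. The script below applies dilation and leaves the erosion step commented out.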
import cv2
import os
import numpy as np
import random

def main():
    alpha_dir = './alpha'
    sav_dir = './trimap'
    os.makedirs(sav_dir, exist_ok=True)  # cv2.imwrite does not create directories
    img_ids = os.listdir(alpha_dir)
    print("Images count: {}".format(len(img_ids)))
    for img_id in img_ids:
        alpha = cv2.imread(os.path.join(alpha_dir, img_id), 0)  # 0: load as grayscale
        if alpha is None:  # skip files OpenCV cannot decode
            continue
        # k_size = random.choice(range(20, 40))
        k_size = 15
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_size, k_size))
        dilated = cv2.dilate(alpha, kernel)
        # eroded = cv2.erode(alpha, kernel)
        # Start with everything unknown (128), then mark sure foreground and background.
        trimap = np.full(alpha.shape, 128, dtype=np.uint8)
        trimap[alpha >= 255] = 255  # fully opaque pixels are foreground
        trimap[dilated <= 0] = 0    # pixels still 0 after dilation are background
        save_name = os.path.join(sav_dir, img_id)
        print("Write to {}".format(save_name))
        cv2.imwrite(save_name, trimap)

if __name__ == "__main__":
    main()
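The script above only dilates the mask. As an illustration of how dilation and erosion combine, here is a dependency-free sketch on a single row of pixels. It is a variant that also erodes the foreground, not the original author's code: `dilate_1d`, `erode_1d`, and `trimap_1d` are hypothetical helpers, and the window is simply clipped at the row ends.

```python
def dilate_1d(row, k):
    """Grayscale dilation: maximum over a window of width k centered on each pixel."""
    r, n = k // 2, len(row)
    return [max(row[max(0, i - r):min(n, i + r + 1)]) for i in range(n)]


def erode_1d(row, k):
    """Grayscale erosion: minimum over a window of width k centered on each pixel."""
    r, n = k // 2, len(row)
    return [min(row[max(0, i - r):min(n, i + r + 1)]) for i in range(n)]


def trimap_1d(mask_row, k=3):
    """255 where the eroded mask is still opaque, 0 where the dilated mask is
    still empty, 128 (unknown) everywhere else."""
    trimap = []
    for d, e in zip(dilate_1d(mask_row, k), erode_1d(mask_row, k)):
        if e >= 255:
            trimap.append(255)  # sure foreground
        elif d <= 0:
            trimap.append(0)    # sure background
        else:
            trimap.append(128)  # unknown band around the boundary
    return trimap


mask = [0, 0, 0, 255, 255, 255, 255, 0, 0, 0]
trimap = trimap_1d(mask, k=3)
print(trimap)  # [0, 0, 128, 128, 255, 255, 128, 128, 0, 0]
```

With erosion enabled, the unknown band extends one pixel into the foreground as well as one pixel into the background. A larger `k` widens the band on both sides.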