BackgroundMattingV2 provides matting implementations for images, videos, and webcams, and supports exporting the model to ONNX and TorchScript.

1. HomographicAlignment (Image Alignment)

An image-alignment helper that warps the background to match the source image.

inference_utils.py

import numpy as np
import cv2
from PIL import Image

class HomographicAlignment:
    """
    Apply homographic alignment on background to match with the source image.
    """
    
    def __init__(self):
        # ORB uses FAST (Features from Accelerated Segment Test) to detect
        # keypoints. The core idea of FAST: compare a pixel against the pixels
        # around it, and if it differs from most of them, treat it as a
        # feature point.
        self.detector = cv2.ORB_create()
        # Brute-force matcher for the ORB descriptors (feature matching).
        self.matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE)

    def __call__(self, src, bgr):
        src = np.asarray(src)
        bgr = np.asarray(bgr)

        keypoints_src, descriptors_src = self.detector.detectAndCompute(src, None)
        keypoints_bgr, descriptors_bgr = self.detector.detectAndCompute(bgr, None)

        matches = self.matcher.match(descriptors_bgr, descriptors_src, None)
        # Keep the best 15% of matches (smallest descriptor distance).
        # sorted() also works on newer OpenCV builds where match() returns a tuple.
        matches = sorted(matches, key=lambda x: x.distance)
        num_good_matches = int(len(matches) * 0.15)
        matches = matches[:num_good_matches]

        points_src = np.zeros((len(matches), 2), dtype=np.float32)
        points_bgr = np.zeros((len(matches), 2), dtype=np.float32)
        for i, match in enumerate(matches):
            points_src[i, :] = keypoints_src[match.trainIdx].pt
            points_bgr[i, :] = keypoints_bgr[match.queryIdx].pt
        # Estimate the background-to-source homography with RANSAC.
        H, _ = cv2.findHomography(points_bgr, points_src, cv2.RANSAC)

        h, w = src.shape[:2]
        bgr = cv2.warpPerspective(bgr, H, (w, h))  # perspective warp
        msk = cv2.warpPerspective(np.ones((h, w)), H, (w, h))

        # For areas that are outside of the background,
        # we just copy pixels from the source.
        bgr[msk != 1] = src[msk != 1]

        src = Image.fromarray(src)
        bgr = Image.fromarray(bgr)
        
        return src, bgr
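
A minimal usage sketch (the file names here are placeholders, not from the repo):

from PIL import Image
from inference_utils import HomographicAlignment

src = Image.open('src.png').convert('RGB')    # photo with the subject
bgr = Image.open('bgr.png').convert('RGB')    # clean background plate
src, bgr = HomographicAlignment()(src, bgr)   # bgr is warped to match src
bgr.save('bgr_aligned.png')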

For more on image alignment, see:

[1] - Image Alignment (Feature Based) using OpenCV (C++/Python)

[2] - Feature-Based Image Alignment Using OpenCV (in Chinese)

2. Image Matting

inference_images.py

Performs matting on a directory of images (output files are written in worker threads).

import argparse
import torch
import os
import shutil

from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.transforms.functional import to_pil_image
from threading import Thread
from tqdm import tqdm

from dataset import ImagesDataset, ZipDataset
from dataset import augmentation as A
from model import MattingBase, MattingRefine
from inference_utils import HomographicAlignment

# --------------- Arguments ---------------
parser = argparse.ArgumentParser(description='Inference images')
parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine']) 
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)

parser.add_argument('--images-src', type=str, required=True)
parser.add_argument('--images-bgr', type=str, required=True)

parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--preprocess-alignment', action='store_true')

parser.add_argument('--output-dir', type=str, required=True)
parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])
parser.add_argument('-y', action='store_true')

args = parser.parse_args()

assert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \
    'Only mattingbase and mattingrefine support err output'
assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \
    'Only mattingrefine supports ref output'


# --------------- Main ---------------
device = torch.device(args.device)

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size)

model = model.to(device).eval()
model.load_state_dict(torch.load(args.model_checkpoint), strict=False)


# Load the image datasets
dataset = ZipDataset([
    ImagesDataset(args.images_src),
    ImagesDataset(args.images_bgr),
], assert_equal_length=True, transforms=A.PairCompose([
    HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
    A.PairApply(T.ToTensor())
]))
dataloader = DataLoader(dataset, batch_size=1, num_workers=8, pin_memory=True)

# Create output directory
if os.path.exists(args.output_dir):
    if args.y or input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
        shutil.rmtree(args.output_dir)
    else:
        exit()

for output_type in args.output_types:
    os.makedirs(os.path.join(args.output_dir, output_type))
    
# Worker function
def writer(img, path):
    img = to_pil_image(img[0].cpu())
    img.save(path)
    
# Conversion loop
with torch.no_grad():
    for i, (src, bgr) in enumerate(tqdm(dataloader)):
        filename = dataset.datasets[0].filenames[i]
        src = src.to(device, non_blocking=True)
        bgr = bgr.to(device, non_blocking=True)
        
        if args.model_type == 'mattingbase':
            pha, fgr, err, _ = model(src, bgr)
        elif args.model_type == 'mattingrefine':
            pha, fgr, _, _, err, ref = model(src, bgr)
        elif args.model_type == 'mattingbm':  # legacy branch; unreachable with the --model-type choices above
            pha, fgr = model(src, bgr)
            
        if 'com' in args.output_types:
            com = torch.cat([fgr * pha.ne(0), pha], dim=1)
            Thread(target=writer, args=(com, filename.replace(args.images_src, os.path.join(args.output_dir, 'com')).replace('.jpg', '.png'))).start()
        if 'pha' in args.output_types:
            Thread(target=writer, args=(pha, filename.replace(args.images_src, os.path.join(args.output_dir, 'pha')).replace('.png', '.jpg'))).start()
        if 'fgr' in args.output_types:
            Thread(target=writer, args=(fgr, filename.replace(args.images_src, os.path.join(args.output_dir, 'fgr')).replace('.png', '.jpg'))).start()
        if 'err' in args.output_types:
            err = F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False)
            Thread(target=writer, args=(err, filename.replace(args.images_src, os.path.join(args.output_dir, 'err')).replace('.png', '.jpg'))).start()
        if 'ref' in args.output_types:
            ref = F.interpolate(ref, src.shape[2:], mode='nearest')
            Thread(target=writer, args=(ref, filename.replace(args.images_src, os.path.join(args.output_dir, 'ref')).replace('.png', '.jpg'))).start()

Example usage:

python inference_images.py \
        --model-type mattingrefine \
        --model-backbone resnet50 \
        --model-backbone-scale 0.25 \
        --model-refine-mode sampling \
        --model-refine-sample-pixels 80000 \
        --model-checkpoint "PATH_TO_CHECKPOINT" \
        --images-src "PATH_TO_IMAGES_SRC_DIR" \
        --images-bgr "PATH_TO_IMAGES_BGR_DIR" \
        --output-dir "PATH_TO_OUTPUT_DIR" \
        --output-types com fgr pha
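
The exported pha (alpha) and fgr (foreground) outputs can be recombined over any new background with standard alpha compositing, com = fgr * pha + bgr * (1 - pha). A small NumPy/PIL sketch (the file names are placeholders, not produced verbatim by the script):

import numpy as np
from PIL import Image

fgr = np.asarray(Image.open('out/fgr/0000.jpg'), dtype=np.float32) / 255
pha = np.asarray(Image.open('out/pha/0000.jpg').convert('L'), dtype=np.float32)[..., None] / 255

# New background, resized to the source resolution (PIL's resize takes (width, height)).
new_bgr = Image.open('new_background.jpg').resize((fgr.shape[1], fgr.shape[0]))
new_bgr = np.asarray(new_bgr, dtype=np.float32) / 255

com = fgr * pha + new_bgr * (1 - pha)  # alpha compositing
Image.fromarray(np.uint8(com * 255)).save('com_new.png')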

3. Video Matting

inference_video.py

Performs matting on a video.

import argparse
import cv2
import torch
import os
import shutil

from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.transforms.functional import to_pil_image
from threading import Thread
from tqdm import tqdm
from PIL import Image

from dataset import VideoDataset, ZipDataset
from dataset import augmentation as A
from model import MattingBase, MattingRefine
from inference_utils import HomographicAlignment


# --------------- Arguments ---------------
parser = argparse.ArgumentParser(description='Inference video')
parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)

parser.add_argument('--video-src', type=str, required=True)
parser.add_argument('--video-bgr', type=str, required=True)
parser.add_argument('--video-resize', type=int, default=None, nargs=2)

parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--preprocess-alignment', action='store_true')

parser.add_argument('--output-dir', type=str, required=True)
parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])
parser.add_argument('--output-format', type=str, default='video', choices=['video', 'image_sequences'])

args = parser.parse_args()

assert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \
    'Only mattingbase and mattingrefine support err output'
assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \
    'Only mattingrefine supports ref output'

# --------------- Utils ---------------
# Video writer class: encodes batches of frames into an mp4 file
class VideoWriter:
    def __init__(self, path, frame_rate, width, height):
        self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
        
    def add_batch(self, frames):
        frames = frames.mul(255).byte()
        frames = frames.cpu().permute(0, 2, 3, 1).numpy()
        for i in range(frames.shape[0]):
            frame = frames[i]
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            self.out.write(frame)
            
# Image-sequence writer class: saves frames as numbered image files
class ImageSequenceWriter:
    def __init__(self, path, extension):
        self.path = path
        self.extension = extension
        self.index = 0
        os.makedirs(path)
        
    def add_batch(self, frames):
        # Write the batch asynchronously in a background thread.
        Thread(target=self._add_batch, args=(frames, self.index)).start()
        self.index += frames.shape[0]
            
    def _add_batch(self, frames, index):
        frames = frames.cpu()
        for i in range(frames.shape[0]):
            frame = frames[i]
            frame = to_pil_image(frame)
            frame.save(os.path.join(self.path, str(index + i).zfill(5) + '.' + self.extension))

# --------------- Main ---------------
device = torch.device(args.device)

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size)

model = model.to(device).eval()
model.load_state_dict(torch.load(args.model_checkpoint), strict=False)

# Load the video dataset and the background image
vid = VideoDataset(args.video_src)
bgr = [Image.open(args.video_bgr).convert('RGB')]
dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([
    A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
    HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
    A.PairApply(T.ToTensor())
]))

# Create output directory
if os.path.exists(args.output_dir):
    if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
        shutil.rmtree(args.output_dir)
    else:
        exit()
os.makedirs(args.output_dir)

# Prepare writers
if args.output_format == 'video':
    h = args.video_resize[1] if args.video_resize is not None else vid.height
    w = args.video_resize[0] if args.video_resize is not None else vid.width
    if 'com' in args.output_types:
        com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)
    if 'pha' in args.output_types:
        pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)
    if 'fgr' in args.output_types:
        fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)
    if 'err' in args.output_types:
        err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)
    if 'ref' in args.output_types:
        ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)
else:
    if 'com' in args.output_types:
        com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')
    if 'pha' in args.output_types:
        pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')
    if 'fgr' in args.output_types:
        fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')
    if 'err' in args.output_types:
        err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')
    if 'ref' in args.output_types:
        ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')
    
# Conversion loop
with torch.no_grad():
    for src, bgr in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
        src = src.to(device, non_blocking=True)
        bgr = bgr.to(device, non_blocking=True)
        
        if args.model_type == 'mattingbase':
            pha, fgr, err, _ = model(src, bgr)
        elif args.model_type == 'mattingrefine':
            pha, fgr, _, _, err, ref = model(src, bgr)
        elif args.model_type == 'mattingbm':  # legacy branch; unreachable with the --model-type choices above
            pha, fgr = model(src, bgr)

        if 'com' in args.output_types:
            if args.output_format == 'video':
                # Output the composite over a green-screen background
                bgr_green = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)
                com = fgr * pha + bgr_green * (1 - pha)
                com_writer.add_batch(com)
            else:
                # Output the composite as RGBA PNG images
                com = torch.cat([fgr * pha.ne(0), pha], dim=1)
                com_writer.add_batch(com)
        if 'pha' in args.output_types:
            pha_writer.add_batch(pha)
        if 'fgr' in args.output_types:
            fgr_writer.add_batch(fgr)
        if 'err' in args.output_types:
            err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
        if 'ref' in args.output_types:
            ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))

Example usage:

python inference_video.py \
        --model-type mattingrefine \
        --model-backbone resnet50 \
        --model-backbone-scale 0.25 \
        --model-refine-mode sampling \
        --model-refine-sample-pixels 80000 \
        --model-checkpoint "PATH_TO_CHECKPOINT" \
        --video-src "PATH_TO_VIDEO_SRC" \
        --video-bgr "PATH_TO_VIDEO_BGR" \
        --video-resize 1920 1080 \
        --output-dir "PATH_TO_OUTPUT_DIR" \
        --output-types com fgr pha err ref
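
As with images, the exported fgr.mp4 and pha.mp4 can be recomposited over a new still background. An OpenCV sketch (file names are placeholders; note that mp4 compression slightly degrades the alpha):

import cv2
import numpy as np

fgr_cap = cv2.VideoCapture('out/fgr.mp4')
pha_cap = cv2.VideoCapture('out/pha.mp4')
fps = fgr_cap.get(cv2.CAP_PROP_FPS)
w = int(fgr_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(fgr_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

new_bgr = cv2.resize(cv2.imread('new_background.jpg'), (w, h)).astype(np.float32) / 255
out = cv2.VideoWriter('com_new.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))

while True:
    ok_f, fgr = fgr_cap.read()
    ok_p, pha = pha_cap.read()
    if not (ok_f and ok_p):
        break
    fgr = fgr.astype(np.float32) / 255
    pha = pha.astype(np.float32) / 255  # gray alpha decoded as 3 channels
    out.write(np.uint8((fgr * pha + new_bgr * (1 - pha)) * 255))
out.release()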

4. Webcam Matting

inference_webcam.py

Runs the model on webcam input.

In the author's implementation, the script starts in background-capture mode when the webcam comes up; pressing B toggles between background-capture mode and matting mode. The frame displayed when B is pressed is taken as the background for matting.

Press Q to quit the script.

import argparse, os, shutil, time
import cv2
import torch

from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Resize
from torchvision.transforms.functional import to_pil_image
from threading import Thread, Lock
from tqdm import tqdm
from PIL import Image

from dataset import VideoDataset
from model import MattingBase, MattingRefine


# --------------- Arguments ---------------
parser = argparse.ArgumentParser(description='Inference from web-cam')
parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)

parser.add_argument('--hide-fps', action='store_true')
parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
args = parser.parse_args()

# ----------- Utility classes -------------
# A wrapper that reads data from cv2.VideoCapture in its own thread so reads
# never block. Call .read() in a tight loop to get the newest frame.
class Camera:
    def __init__(self, device_id=0, width=1280, height=720):
        self.capture = cv2.VideoCapture(device_id)
        self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, width)
        self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
        self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
        self.success_reading, self.frame = self.capture.read()
        self.read_lock = Lock()
        self.thread = Thread(target=self.__update, args=())
        self.thread.daemon = True
        self.thread.start()

    def __update(self):
        while self.success_reading:
            grabbed, frame = self.capture.read()
            with self.read_lock:
                self.success_reading = grabbed
                self.frame = frame

    def read(self):
        with self.read_lock:
            frame = self.frame.copy()
        return frame
    
    def __exit__(self, exec_type, exc_value, traceback):
        self.capture.release()

# An FPS tracker that computes an exponentially moving average FPS
class FPSTracker:
    def __init__(self, ratio=0.5):
        self._last_tick = None
        self._avg_fps = None
        self.ratio = ratio
        
    def tick(self):
        if self._last_tick is None:
            self._last_tick = time.time()
            return None
        t_new = time.time()
        fps_sample = 1.0 / (t_new - self._last_tick)
        self._avg_fps = self.ratio * fps_sample + (1 - self.ratio) * self._avg_fps if self._avg_fps is not None else fps_sample
        self._last_tick = t_new
        return self.get()
    
    def get(self):
        return self._avg_fps

# Wrapper for playing a stream with cv2.imshow().
# It can accept an image and return keypress info for basic interactivity.
# It also tracks FPS and optionally overlays info onto the stream.
class Displayer:
    def __init__(self, title, width=None, height=None, show_info=True):
        self.title, self.width, self.height = title, width, height
        self.show_info = show_info
        self.fps_tracker = FPSTracker()
        cv2.namedWindow(self.title, cv2.WINDOW_NORMAL)
        
        if width is not None and height is not None:
            cv2.resizeWindow(self.title, width, height)
            
    # Update the currently showing frame and return key press char code
    def step(self, image):
        fps_estimate = self.fps_tracker.tick()
        if self.show_info and fps_estimate is not None:
            message = f"{int(fps_estimate)} fps | {self.width}x{self.height}"
            cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))
        cv2.imshow(self.title, image)
        
        return cv2.waitKey(1) & 0xFF


# --------------- Main ---------------
# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold)

model = model.cuda().eval()
model.load_state_dict(torch.load(args.model_checkpoint), strict=False)

width, height = args.resolution
cam = Camera(width=width, height=height)
dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps))

def cv2_frame_to_cuda(frame):
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).cuda()

with torch.no_grad():
    while True:
        bgr = None
        while True: # grab bgr
            frame = cam.read()
            key = dsp.step(frame)
            if key == ord('b'):
                bgr = cv2_frame_to_cuda(cam.read())
                break
            elif key == ord('q'):
                exit()
                
        while True: # matting
            frame = cam.read()
            src = cv2_frame_to_cuda(frame)
            pha, fgr = model(src, bgr)[:2]
            res = pha * fgr + (1 - pha) * torch.ones_like(fgr)
            res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
            res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)
            key = dsp.step(res)
            if key == ord('b'):
                break
            elif key == ord('q'):
                exit()

Example usage:

python inference_webcam.py \
        --model-type mattingrefine \
        --model-backbone resnet50 \
        --model-checkpoint "PATH_TO_CHECKPOINT" \
        --resolution 1280 720
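
The matting loop above composites over plain white (torch.ones_like(fgr)). Swapping in a custom background only changes that target tensor; a sketch of the substitution, reusing cam and cv2_frame_to_cuda from the script and assuming a hypothetical tgt_bgr.jpg:

# Load and resize the replacement background once, before the matting loop.
tgt_bgr = cv2_frame_to_cuda(cv2.resize(cv2.imread('tgt_bgr.jpg'), (cam.width, cam.height)))

# Then, inside the matting loop, replace the white composite with:
res = pha * fgr + (1 - pha) * tgt_bgr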

5. Matting Speed Test

inference_speed_test.py

Tests the matting speed in two settings:

[1] - Random-noise input, for the fixed-computation setting.

[2] - A provided image input, for the dynamic-computation setting.

import argparse
import torch
from torchvision.transforms.functional import to_tensor
from tqdm import tqdm
from PIL import Image

from model import MattingBase, MattingRefine

# --------------- Arguments ---------------
parser = argparse.ArgumentParser()

parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, default=None)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)

parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--resolution', type=int, default=None, nargs=2)
parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
parser.add_argument('--backend', type=str, default='pytorch', choices=['pytorch', 'torchscript'])
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')

parser.add_argument('--image-src', type=str, default=None)
parser.add_argument('--image-bgr', type=str, default=None)

args = parser.parse_args()

assert type(args.image_src) == type(args.image_bgr),  'Image source and background must be provided together.'
assert (not args.image_src) != (not args.resolution), 'Must provide either a resolution or an image and not both.'


# --------------- Run Loop ---------------
device = torch.device(args.device)

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size,
        refine_prevent_oversampling=False)

if args.model_checkpoint:
    model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
    
if args.precision == 'float32':
    precision = torch.float32
else:
    precision = torch.float16
    
if args.backend == 'torchscript':
    model = torch.jit.script(model)

model = model.eval().to(device=device, dtype=precision)

# Load data
if not args.image_src:
    src = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision)
    bgr = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision)
else:
    src = to_tensor(Image.open(args.image_src)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision)
    bgr = to_tensor(Image.open(args.image_bgr)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision)
    
# Loop
with torch.no_grad():
    for _ in tqdm(range(1000)):
        model(src, bgr)

使用示例如:

# Run inference on random noise input for the fixed-computation setting
# (i.e. mode in ['full', 'sampling']).
python inference_speed_test.py \
        --model-type mattingrefine \
        --model-backbone resnet50 \
        --model-backbone-scale 0.25 \
        --model-refine-mode sampling \
        --model-refine-sample-pixels 80000 \
        --batch-size 1 \
        --resolution 1920 1080 \
        --backend pytorch \
        --precision float32

# Run inference on a provided image input for the dynamic-computation setting
# (i.e. mode in ['thresholding']).
python inference_speed_test.py \
        --model-type mattingrefine \
        --model-backbone resnet50 \
        --model-backbone-scale 0.25 \
        --model-checkpoint "PATH_TO_CHECKPOINT" \
        --model-refine-mode thresholding \
        --model-refine-threshold 0.7 \
        --batch-size 1 \
        --backend pytorch \
        --precision float32 \
        --image-src "PATH_TO_IMAGE_SRC" \
        --image-bgr "PATH_TO_IMAGE_BGR"