BackgroundMattingV2 提供了针对图像、视频和网络摄像头的抠图实现,并支持将模型导出为 ONNX 和 TorchScript 格式.
1. HomographicAlignment 图像对齐
图像对齐函数,用于通过单应性变换将背景图与源图对齐.
import numpy as np
import cv2
from PIL import Image
class HomographicAlignment:
    """
    Apply homographic alignment on background to match with the source image.

    ORB keypoints are detected in both images and matched by descriptor
    distance. The best fraction of the matches is used to estimate a RANSAC
    homography, which warps the background onto the source image plane.
    Pixels that the warped background does not fully cover are copied from
    the source image.
    """

    def __init__(self, good_match_ratio=0.15):
        """
        Args:
            good_match_ratio: fraction of matches (best first, by descriptor
                distance) kept for homography estimation.
        """
        # ORB detects keypoints with FAST (features from accelerated segment
        # test): a pixel is a keypoint when it differs from most of the pixels
        # on a circle around it.
        self.detector = cv2.ORB_create()
        # ORB descriptors are binary strings, so they must be compared with the
        # Hamming distance rather than the L2 norm.
        self.matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_HAMMING)
        self.good_match_ratio = good_match_ratio

    def __call__(self, src, bgr):
        """
        Align the background `bgr` to the source image `src`.

        Args:
            src: source image (anything accepted by np.asarray, e.g. a PIL image).
            bgr: background image to warp onto the source image plane.

        Returns:
            (src, bgr) as PIL images. bgr is warped to src's size. When fewer
            than 4 usable matches are found (no homography can be estimated),
            bgr is returned unaligned instead of raising.
        """
        src = np.asarray(src)
        bgr = np.asarray(bgr)

        keypoints_src, descriptors_src = self.detector.detectAndCompute(src, None)
        keypoints_bgr, descriptors_bgr = self.detector.detectAndCompute(bgr, None)

        H = None
        # detectAndCompute returns None descriptors when no keypoints are found.
        if descriptors_src is not None and descriptors_bgr is not None:
            # `match` returns a tuple in recent OpenCV releases, so build a new
            # sorted list instead of calling list.sort() in place.
            matches = sorted(self.matcher.match(descriptors_bgr, descriptors_src, None),
                             key=lambda m: m.distance)
            matches = matches[:int(len(matches) * self.good_match_ratio)]
            # findHomography needs at least 4 point correspondences.
            if len(matches) >= 4:
                points_src = np.float32([keypoints_src[m.trainIdx].pt for m in matches])
                points_bgr = np.float32([keypoints_bgr[m.queryIdx].pt for m in matches])
                H, _ = cv2.findHomography(points_bgr, points_src, cv2.RANSAC)

        if H is None:
            # Not enough reliable correspondences: skip the alignment.
            return Image.fromarray(src), Image.fromarray(bgr)

        h, w = src.shape[:2]
        # The valid-region mask lives in the background's own frame (which may
        # differ in size from src) and is warped exactly like the image.
        msk = cv2.warpPerspective(np.ones(bgr.shape[:2]), H, (w, h))
        bgr = cv2.warpPerspective(bgr, H, (w, h))  # perspective transform
        # For areas that are outside of the warped background, copy pixels from
        # the source. Interpolated border pixels (mask value != 1) are replaced too.
        outside = msk != 1
        bgr[outside] = src[outside]
        return Image.fromarray(src), Image.fromarray(bgr)
关于图像对齐可参考:
[1] - Image Alignment (Feature Based) using OpenCV (C++/Python)
[2] - 使用OpenCV实现基于特征的图像对齐
2. 图像抠图
对图像进行抠图(抠图结果通过多线程写入磁盘).
import argparse
import torch
import os
import shutil
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.transforms.functional import to_pil_image
from threading import Thread
from tqdm import tqdm
from dataset import ImagesDataset, ZipDataset
from dataset import augmentation as A
from model import MattingBase, MattingRefine
from inference_utils import HomographicAlignment
# --------------- Arguments ---------------
parser = argparse.ArgumentParser(description='Inference images')
parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)
parser.add_argument('--images-src', type=str, required=True)
parser.add_argument('--images-bgr', type=str, required=True)
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--preprocess-alignment', action='store_true')
parser.add_argument('--output-dir', type=str, required=True)
parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])
parser.add_argument('-y', action='store_true')
args = parser.parse_args()

assert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \
    'Only mattingbase and mattingrefine support err output'
assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \
    'Only mattingrefine support ref output'

# --------------- Main ---------------
device = torch.device(args.device)

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size)
model = model.to(device).eval()
# map_location lets a checkpoint saved from a GPU be loaded with --device cpu.
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)

# Load the paired source / background image datasets.
dataset = ZipDataset([
    ImagesDataset(args.images_src),
    ImagesDataset(args.images_bgr),
], assert_equal_length=True, transforms=A.PairCompose([
    # `nn` is not imported by this script, so reach Identity through `torch.nn`.
    HomographicAlignment() if args.preprocess_alignment else A.PairApply(torch.nn.Identity()),
    A.PairApply(T.ToTensor()),
]))
# batch_size must stay 1: the conversion loop maps batch index i to filename i.
# Pinned host memory only helps when copying to a CUDA device.
dataloader = DataLoader(dataset, batch_size=1, num_workers=8, pin_memory=device.type == 'cuda')

# Create output directory (asks before deleting an existing one unless -y is given).
if os.path.exists(args.output_dir):
    if args.y or input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
        shutil.rmtree(args.output_dir)
    else:
        # `exit()` is meant for the interactive shell; raise SystemExit directly.
        raise SystemExit()
for output_type in args.output_types:
    os.makedirs(os.path.join(args.output_dir, output_type))
# Worker function
def writer(img, path):
    """Save the first image of the batch tensor `img` to `path`.

    Meant to run on a background thread. Missing parent directories are
    created, so inputs stored in sub-directories of the source folder can be
    mirrored into the output folder.
    """
    directory = os.path.dirname(path)
    if directory:
        # exist_ok: several writer threads may create the same directory concurrently.
        os.makedirs(directory, exist_ok=True)
    img = to_pil_image(img[0].cpu())
    img.save(path)
# Conversion loop
def _output_path(filename, output_type, extension):
    """Map a source image path to its output path for `output_type`.

    The path relative to --images-src is mirrored under
    --output-dir/<output_type>, and the file extension is replaced by
    `extension` (e.g. '.png').
    """
    # relpath (unlike str.replace) is insensitive to a trailing slash on --images-src.
    # NOTE(review): assumes dataset filenames live under --images-src, as the
    # previous replace()-based mapping also did.
    relative = os.path.relpath(filename, args.images_src)
    stem = os.path.splitext(relative)[0]
    return os.path.join(args.output_dir, output_type, stem + extension)


with torch.no_grad():
    for i, (src, bgr) in enumerate(tqdm(dataloader)):
        # Relies on batch_size=1 and an unshuffled loader, so batch i is file i.
        filename = dataset.datasets[0].filenames[i]
        src = src.to(device, non_blocking=True)
        bgr = bgr.to(device, non_blocking=True)

        if args.model_type == 'mattingbase':
            pha, fgr, err, _ = model(src, bgr)
        elif args.model_type == 'mattingrefine':
            pha, fgr, _, _, err, ref = model(src, bgr)

        if 'com' in args.output_types:
            # RGBA composite: foreground (zeroed where alpha is 0) plus alpha.
            # It has an alpha channel, so it is always written as PNG.
            com = torch.cat([fgr * pha.ne(0), pha], dim=1)
            Thread(target=writer, args=(com, _output_path(filename, 'com', '.png'))).start()
        if 'pha' in args.output_types:
            Thread(target=writer, args=(pha, _output_path(filename, 'pha', '.jpg'))).start()
        if 'fgr' in args.output_types:
            Thread(target=writer, args=(fgr, _output_path(filename, 'fgr', '.jpg'))).start()
        if 'err' in args.output_types:
            # Resize the error map to the source resolution before saving.
            err = F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False)
            Thread(target=writer, args=(err, _output_path(filename, 'err', '.jpg'))).start()
        if 'ref' in args.output_types:
            # Nearest-neighbour keeps the refinement map's discrete values.
            ref = F.interpolate(ref, src.shape[2:], mode='nearest')
            Thread(target=writer, args=(ref, _output_path(filename, 'ref', '.jpg'))).start()
使用示例如下:
python inference_images.py \
--model-type mattingrefine \
--model-backbone resnet50 \
--model-backbone-scale 0.25 \
--model-refine-mode sampling \
--model-refine-sample-pixels 80000 \
--model-checkpoint "PATH_TO_CHECKPOINT" \
--images-src "PATH_TO_IMAGES_SRC_DIR" \
--images-bgr "PATH_TO_IMAGES_BGR_DIR" \
--output-dir "PATH_TO_OUTPUT_DIR" \
--output-types com fgr pha
3. 视频抠图
对视频进行抠图.
import argparse
import cv2
import torch
import os
import shutil
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.transforms.functional import to_pil_image
from threading import Thread
from tqdm import tqdm
from PIL import Image
from dataset import VideoDataset, ZipDataset
from dataset import augmentation as A
from model import MattingBase, MattingRefine
from inference_utils import HomographicAlignment
# --------------- Arguments ---------------
parser = argparse.ArgumentParser(description='Inference video')

# Command-line options as (flag, add_argument keyword arguments), registered in order.
_video_cli = [
    ('--model-type', {'type': str, 'required': True, 'choices': ['mattingbase', 'mattingrefine']}),
    ('--model-backbone', {'type': str, 'required': True, 'choices': ['resnet101', 'resnet50', 'mobilenetv2']}),
    ('--model-backbone-scale', {'type': float, 'default': 0.25}),
    ('--model-checkpoint', {'type': str, 'required': True}),
    ('--model-refine-mode', {'type': str, 'default': 'sampling', 'choices': ['full', 'sampling', 'thresholding']}),
    ('--model-refine-sample-pixels', {'type': int, 'default': 80_000}),
    ('--model-refine-threshold', {'type': float, 'default': 0.7}),
    ('--model-refine-kernel-size', {'type': int, 'default': 3}),
    ('--video-src', {'type': str, 'required': True}),
    ('--video-bgr', {'type': str, 'required': True}),
    ('--video-resize', {'type': int, 'default': None, 'nargs': 2}),
    ('--device', {'type': str, 'choices': ['cpu', 'cuda'], 'default': 'cuda'}),
    ('--preprocess-alignment', {'action': 'store_true'}),
    ('--output-dir', {'type': str, 'required': True}),
    ('--output-types', {'type': str, 'required': True, 'nargs': '+', 'choices': ['com', 'pha', 'fgr', 'err', 'ref']}),
    ('--output-format', {'type': str, 'default': 'video', 'choices': ['video', 'image_sequences']}),
]
for _flag, _kwargs in _video_cli:
    parser.add_argument(_flag, **_kwargs)
args = parser.parse_args()

# Reject output types that the selected model cannot produce.
assert not ('err' in args.output_types and args.model_type not in ['mattingbase', 'mattingrefine']), \
    'Only mattingbase and mattingrefine support err output'
assert not ('ref' in args.output_types and args.model_type not in ['mattingrefine']), \
    'Only mattingrefine support ref output'
# --------------- Utils ---------------
# Video writer class
class VideoWriter:
    """Write batches of frames to an mp4 file through cv2.VideoWriter ('mp4v' codec)."""

    def __init__(self, path, frame_rate, width, height):
        self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))

    def add_batch(self, frames):
        """Append a batch of frames to the video.

        Args:
            frames: tensor of shape (B, C, H, W) with values expected in [0, 1].
                C is 3 (RGB) or 1 (single-channel maps such as alpha).
        """
        # cvtColor(RGB2BGR) and the colour VideoWriter need 3 channels, so
        # replicate single-channel maps into grey RGB.
        if frames.size(1) == 1:
            frames = frames.repeat(1, 3, 1, 1)
        # Clamp before casting so values slightly outside [0, 1] cannot overflow uint8.
        frames = frames.mul(255).clamp(0, 255).byte()
        frames = frames.cpu().permute(0, 2, 3, 1).numpy()
        for frame in frames:
            self.out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

    def release(self):
        """Finalize and close the video file."""
        self.out.release()
# Image-sequence writer class
class ImageSequenceWriter:
    """Save batches of frames as sequentially numbered image files.

    Each batch is written by its own background thread, so disk I/O does not
    block the caller.
    """

    def __init__(self, path, extension):
        self.path = path
        self.extension = extension
        self.index = 0
        os.makedirs(path)

    def add_batch(self, frames):
        """Schedule `frames` (batch-first tensor) for saving and advance the frame counter."""
        worker = Thread(target=self._add_batch, args=(frames, self.index))
        worker.start()
        self.index += len(frames)

    def _add_batch(self, frames, index):
        """Write frame k of `frames` to '<path>/<index + k, zero-padded to 5 digits>.<extension>'."""
        for offset, frame in enumerate(frames.cpu()):
            file_name = f'{index + offset:05d}.{self.extension}'
            to_pil_image(frame).save(os.path.join(self.path, file_name))
# --------------- Main ---------------
device = torch.device(args.device)

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size)
model = model.to(device).eval()
# map_location lets a checkpoint saved from a GPU be loaded with --device cpu.
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)

# Load the video frames and the background image.
vid = VideoDataset(args.video_src)
# --video-bgr is opened as a single still image.
# NOTE(review): presumably ZipDataset reuses this one-element dataset for every frame — confirm.
bgr = [Image.open(args.video_bgr).convert('RGB')]
dataset = ZipDataset(
    [vid, bgr],
    transforms=A.PairCompose([
        # --video-resize is (width, height); T.Resize expects (height, width), hence [::-1].
        A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
        HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
        A.PairApply(T.ToTensor()),
    ]))

# Create output directory
if os.path.exists(args.output_dir):
    if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
        shutil.rmtree(args.output_dir)
    else:
        # `exit()` is meant for the interactive shell; raise SystemExit directly.
        raise SystemExit()
os.makedirs(args.output_dir)

# Prepare one writer per requested output type (fixed order; duplicates ignored).
requested_types = [t for t in ['com', 'pha', 'fgr', 'err', 'ref'] if t in args.output_types]
if args.output_format == 'video':
    h = args.video_resize[1] if args.video_resize is not None else vid.height
    w = args.video_resize[0] if args.video_resize is not None else vid.width
    writers = {
        t: VideoWriter(os.path.join(args.output_dir, t + '.mp4'), vid.frame_rate, w, h)
        for t in requested_types
    }
else:
    # The RGBA composite needs PNG; the other outputs are written as JPEG.
    writers = {
        t: ImageSequenceWriter(os.path.join(args.output_dir, t), 'png' if t == 'com' else 'jpg')
        for t in requested_types
    }

# Green-screen colour for the video composite (video frames carry no alpha channel).
bgr_green = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)

# Conversion loop
with torch.no_grad():
    # Pinned host memory only helps when copying to a CUDA device.
    for src, bgr in tqdm(DataLoader(dataset, batch_size=1, pin_memory=device.type == 'cuda')):
        src = src.to(device, non_blocking=True)
        bgr = bgr.to(device, non_blocking=True)

        if args.model_type == 'mattingbase':
            pha, fgr, err, _ = model(src, bgr)
        elif args.model_type == 'mattingrefine':
            pha, fgr, _, _, err, ref = model(src, bgr)

        if 'com' in writers:
            if args.output_format == 'video':
                # Output composite with green background
                com = fgr * pha + bgr_green * (1 - pha)
            else:
                # Output composite as RGBA png images (foreground zeroed where alpha is 0)
                com = torch.cat([fgr * pha.ne(0), pha], dim=1)
            writers['com'].add_batch(com)
        if 'pha' in writers:
            writers['pha'].add_batch(pha)
        if 'fgr' in writers:
            writers['fgr'].add_batch(fgr)
        if 'err' in writers:
            writers['err'].add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
        if 'ref' in writers:
            writers['ref'].add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))

# Release the video writers explicitly so the mp4 files are finalized on disk.
if args.output_format == 'video':
    for video_writer in writers.values():
        video_writer.out.release()
使用示例如下:
python inference_video.py \
--model-type mattingrefine \
--model-backbone resnet50 \
--model-backbone-scale 0.25 \
--model-refine-mode sampling \
--model-refine-sample-pixels 80000 \
--model-checkpoint "PATH_TO_CHECKPOINT" \
--video-src "PATH_TO_VIDEO_SRC" \
--video-bgr "PATH_TO_BGR_IMAGE" \
--video-resize 1920 1080 \
--output-dir "PATH_TO_OUTPUT_DIR" \
--output-types com fgr pha err ref
4. 网络摄像头抠图
使用模型对网络摄像头的输入进行抠图.
作者提供的实现中,网络摄像头启动后,脚本处于背景采集模式; 按 B 键可在背景采集模式和抠图模式之间切换, 按下 B 键时采集到的视频帧会作为背景用于抠图.
按 Q 键退出脚本程序.
import argparse, os, shutil, time
import cv2
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Resize
from torchvision.transforms.functional import to_pil_image
from threading import Thread, Lock
from tqdm import tqdm
from PIL import Image
from dataset import VideoDataset
from model import MattingBase, MattingRefine
# --------------- Arguments ---------------
parser = argparse.ArgumentParser(description='Inference from web-cam')
# Register every option in order from (flag, keyword-arguments) pairs.
for _flag, _kwargs in (
    ('--model-type', {'type': str, 'required': True, 'choices': ['mattingbase', 'mattingrefine']}),
    ('--model-backbone', {'type': str, 'required': True, 'choices': ['resnet101', 'resnet50', 'mobilenetv2']}),
    ('--model-backbone-scale', {'type': float, 'default': 0.25}),
    ('--model-checkpoint', {'type': str, 'required': True}),
    ('--model-refine-mode', {'type': str, 'default': 'sampling', 'choices': ['full', 'sampling', 'thresholding']}),
    ('--model-refine-sample-pixels', {'type': int, 'default': 80_000}),
    ('--model-refine-threshold', {'type': float, 'default': 0.7}),
    ('--hide-fps', {'action': 'store_true'}),
    ('--resolution', {'type': int, 'nargs': 2, 'metavar': ('width', 'height'), 'default': (1280, 720)}),
):
    parser.add_argument(_flag, **_kwargs)
args = parser.parse_args()
# ----------- Utility classes -------------
# A wrapper that reads data from cv2.VideoCapture in its own thread to optimize.
# Use .read() in a tight loop to get the newest frame
class Camera:
    """Continuously grab frames from a cv2.VideoCapture on a daemon thread.

    `read()` returns a copy of the most recently grabbed frame. Usable as a
    context manager (`with Camera() as cam: ...`), which releases the device
    on exit.
    """

    def __init__(self, device_id=0, width=1280, height=720):
        self.capture = cv2.VideoCapture(device_id)
        # Request a resolution, then read back what the driver actually applied.
        self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, width)
        self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
        self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Optionally shrink the driver buffer to reduce latency:
        # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
        self.success_reading, self.frame = self.capture.read()
        self.read_lock = Lock()
        self.thread = Thread(target=self.__update, args=())
        self.thread.daemon = True
        self.thread.start()

    def __update(self):
        # Keep grabbing frames until a read fails.
        while self.success_reading:
            grabbed, frame = self.capture.read()
            with self.read_lock:
                self.success_reading = grabbed
                self.frame = frame

    def read(self):
        """Return a copy of the newest frame.

        Raises:
            RuntimeError: if the camera did not deliver a frame.
        """
        with self.read_lock:
            frame = self.frame
            if frame is None:
                raise RuntimeError('Failed to read a frame from the camera')
            return frame.copy()

    def release(self):
        """Release the underlying capture device."""
        self.capture.release()

    def __enter__(self):
        return self

    def __exit__(self, exec_type, exc_value, traceback):
        self.release()
# An FPS tracker that computes an exponential moving average of the FPS
class FPSTracker:
    """Track frames per second as an exponential moving average.

    Args:
        ratio: weight of the newest per-frame sample in the moving average.
    """

    def __init__(self, ratio=0.5):
        self._last_tick = None
        self._avg_fps = None
        self.ratio = ratio

    def tick(self):
        """Record a frame and return the updated FPS estimate.

        Returns None on the first call, because no interval exists yet.
        Intervals are measured with time.perf_counter (a high-resolution
        performance counter). Wall-clock time.time() can be set back and may
        have coarse resolution, which can yield a zero interval
        (ZeroDivisionError) or a negative FPS.
        """
        t_new = time.perf_counter()
        if self._last_tick is None:
            self._last_tick = t_new
            return None
        elapsed = t_new - self._last_tick
        if elapsed <= 0:
            # Interval too small to measure: keep the previous estimate.
            return self.get()
        fps_sample = 1.0 / elapsed
        if self._avg_fps is None:
            self._avg_fps = fps_sample
        else:
            self._avg_fps = self.ratio * fps_sample + (1 - self.ratio) * self._avg_fps
        self._last_tick = t_new
        return self.get()

    def get(self):
        """Return the current FPS estimate, or None before two ticks."""
        return self._avg_fps
# Wrapper for playing a stream with cv2.imshow().
# It can accept an image and return keypress info for basic interactivity.
# It also tracks FPS and optionally overlays info onto the stream.
class Displayer:
    """Show frames in a resizable OpenCV window, optionally overlaying FPS and size."""

    def __init__(self, title, width=None, height=None, show_info=True):
        self.title = title
        self.width = width
        self.height = height
        self.show_info = show_info
        self.fps_tracker = FPSTracker()
        cv2.namedWindow(self.title, cv2.WINDOW_NORMAL)
        has_size = width is not None and height is not None
        if has_size:
            cv2.resizeWindow(self.title, width, height)

    def step(self, image):
        """Display `image` and return the low 8 bits of cv2.waitKey(1).

        When enabled, the info overlay is drawn onto `image` in place.
        """
        fps_estimate = self.fps_tracker.tick()
        if fps_estimate is not None and self.show_info:
            overlay = f"{int(fps_estimate)} fps | {self.width}x{self.height}"
            cv2.putText(image, overlay, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))
        cv2.imshow(self.title, image)
        pressed = cv2.waitKey(1)
        return pressed & 0xFF
# --------------- Main ---------------
# Load model (the webcam demo always runs on CUDA).
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
elif args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold)
model = model.cuda().eval()
checkpoint_state = torch.load(args.model_checkpoint)
model.load_state_dict(checkpoint_state, strict=False)

# Open the camera at the requested resolution; size the window to what the camera reports.
width, height = args.resolution
cam = Camera(width=width, height=height)
dsp = Displayer('MattingV2', cam.width, cam.height, show_info=not args.hide_fps)
def cv2_frame_to_cuda(frame):
    """Convert an OpenCV BGR frame into a batched (1, C, H, W) RGB tensor on the GPU."""
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    tensor = ToTensor()(Image.fromarray(rgb_frame))
    return tensor.unsqueeze_(0).cuda()
try:
    with torch.no_grad():
        while True:
            bgr = None
            # Background-capture mode: show the live feed until 'b' (capture) or 'q' (quit).
            while True:
                frame = cam.read()
                key = dsp.step(frame)
                if key == ord('b'):
                    # Read a fresh frame: `frame` may now carry the FPS overlay.
                    bgr = cv2_frame_to_cuda(cam.read())
                    break
                elif key == ord('q'):
                    # `exit()` is meant for the interactive shell; raise SystemExit directly.
                    raise SystemExit()
            # Matting mode: composite over white until 'b' (recapture) or 'q' (quit).
            while True:
                frame = cam.read()
                src = cv2_frame_to_cuda(frame)
                pha, fgr = model(src, bgr)[:2]
                res = pha * fgr + (1 - pha) * torch.ones_like(fgr)
                res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
                res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)
                key = dsp.step(res)
                if key == ord('b'):
                    break
                elif key == ord('q'):
                    raise SystemExit()
finally:
    # Close the preview window on every exit path (quit key, error, Ctrl+C).
    cv2.destroyAllWindows()
使用示例如下:
python inference_webcam.py \
--model-type mattingrefine \
--model-backbone resnet50 \
--model-checkpoint "PATH_TO_CHECKPOINT" \
--resolution 1280 720
5. 抠图速度测试
测试抠图的速度,主要包括两种:
[1] - 采用随机噪声输入,进行固定计算测试.
[2] - 采用提供的图像输入,进行动态计算测试.
import argparse
import torch
from torchvision.transforms.functional import to_tensor
from tqdm import tqdm
from PIL import Image
from model import MattingBase, MattingRefine
# --------------- Arguments ---------------
parser = argparse.ArgumentParser()
parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, default=None)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--resolution', type=int, default=None, nargs=2)
parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
parser.add_argument('--backend', type=str, default='pytorch', choices=['pytorch', 'torchscript'])
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--image-src', type=str, default=None)
parser.add_argument('--image-bgr', type=str, default=None)
args = parser.parse_args()

# Both image paths or neither.
assert (args.image_src is None) == (args.image_bgr is None), 'Image source and background must be provided together.'
assert (not args.image_src) != (not args.resolution), 'Must provide either a resolution or an image and not both.'

# --------------- Run Loop ---------------
device = torch.device(args.device)

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size,
        refine_prevent_oversampling=False)

if args.model_checkpoint:
    # The model is still on the CPU here; map_location='cpu' also lets a
    # GPU-saved checkpoint load on machines without CUDA.
    model.load_state_dict(torch.load(args.model_checkpoint, map_location='cpu'), strict=False)

precision = torch.float32 if args.precision == 'float32' else torch.float16

if args.backend == 'torchscript':
    model = torch.jit.script(model)
model = model.eval().to(device=device, dtype=precision)


def _load_image_batch(path):
    """Load `path` as RGB and repeat it --batch-size times on the target device/dtype."""
    # convert('RGB') drops alpha / expands greyscale so the input always has 3 channels.
    image = to_tensor(Image.open(path).convert('RGB'))
    return image.unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision)


# Load data
if not args.image_src:
    # --resolution is (width, height); tensors are laid out (..., height, width).
    src = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision)
    bgr = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision)
else:
    src = _load_image_batch(args.image_src)
    bgr = _load_image_batch(args.image_bgr)

# Loop
with torch.no_grad():
    for _ in tqdm(range(1000)):
        model(src, bgr)
使用示例如下:
#Run inference on random noise input for fixed computation setting.
#(i.e. mode in ['full', 'sampling'])
python inference_speed_test.py \
--model-type mattingrefine \
--model-backbone resnet50 \
--model-backbone-scale 0.25 \
--model-refine-mode sampling \
--model-refine-sample-pixels 80000 \
--batch-size 1 \
--resolution 1920 1080 \
--backend pytorch \
--precision float32
#Run inference on provided image input for dynamic computation setting.
#(i.e. mode in ['thresholding'])
python inference_speed_test.py \
--model-type mattingrefine \
--model-backbone resnet50 \
--model-backbone-scale 0.25 \
--model-checkpoint "PATH_TO_CHECKPOINT" \
--model-refine-mode thresholding \
--model-refine-threshold 0.7 \
--batch-size 1 \
--backend pytorch \
--precision float32 \
--image-src "PATH_TO_IMAGE_SRC" \
--image-bgr "PATH_TO_IMAGE_BGR"