竞赛:MICCAI 2017 Robotic Instrument Segmentation Sub-Challenge
基于 U-Net,TernausNet-11, TernausNet-16,LinkNet-34 的高分辨率语义分割.

Github项目 - robot-surgery-segmentation
Paper - Automatic Instrument Segmentation in Robot-Assisted Surgery Using Deep Learning - 2018

该项目主要描述对于竞赛 [MICCAI 2017 Endoscopic Vision Sub-Challenge: Robotic Instrument Segmentation] 的解决方案. 原始方案基于 U-Net 网络结构,采用 LinkNet 和 TernausNet 来提升了语义分割结果.
对于二类(binary) 和多类(multi-class) 的机器人仪器分割(robotic instrument segmentation) 取得优秀的分割结果.
对于医学的手术场景,能够很好的进行追踪和姿态估计.

概述

机器人仪器的语义分割是在机器人辅助医学领域中很重要的研究问题.
一个最重要的挑战是,手术场景中仪器位置的正确检测,以进行追踪和姿态估计.
精确的像素级仪器分割有待解决.
该项目采用神经网络结构取得了较为优秀的分割表现.
其主要处理二值语义分割问题,手术视频流中,图像每个像素被标注为仪器类或背景类(非仪器类).
另外,还解决多类分割问题(multi-class segmentation),对不同的机器人仪器和一个仪器的不同部分(包括背景类) 进行语义分割区分.
该项目的解决方案中,在机器人仪器分割的每个子任务中均取得了突出表现.

数据

训练数据集包括 8x255-帧高分辨率相机图像序列,由 da Vinci Xi 外科系统 在几个不同的手术场景中捕捉得到.
训练数据集序列以 2 Hz 帧率,避免冗余.
每个视频序列有两个立体通道组成,分别有左相机和右相机拍摄,为 1920x1080 分辨图的 RGB 图像.
机器人医学仪器的铰链部分,例如 rigid shaft(刚性轴),articulated wrist(铰接手腕) 和 claspers(扣环),分别在每一帧进行了手工标注.
此外,还包括仪器类型标签,将仪器划分为几类:left/right prograsp forceps, monopolar curved scissors, large needle driver, 和 a miscellaneous category for any other surgical instruments(其它手术器械的杂类).


原图序列 - 二值分割(2 类) - Parts(3 类) - Instruments(7 类)

方法

项目采用了 4 个不同的语义分割网络:U-NetTernausNet 的两个改进网络,LinkNet 的改进网络.
语义分割网络模型的输出是像素级的 mask,分别对每一个像素进行了分类.
MICCAI 2017 Endoscopic Vision Sub-Challenge 提交的版本采用的是 U-Net 网络的轻微改进版本.

作为 U-Net 的提升版,采用了类似的网络作为预训练的编码器(encoders). TernausNet 是 U-Net 类型的网络结构,采用了相对简单的预训练 VGG11 和 CGG16 网络作为编码器.

LinkNet 模型采用 ResNet 类型的网络结构作为编码器. 这里采用了预训练的 ResNet34.
网络编码器包括几个解码器(decoder) 模块,每个解码器模块与对应的解码器模块相连接. 每个编码器模块包含 1x1 卷积层,以减少 filters 的数量为 1/4,其后接 Batch Normalization 和 Transposed Convolution 层,以上采样 feature map:

训练

采用 Jaccard Index(Intersection Over Union,IoU) 作为评价度量. 其可以解释为,有限集间的相似度量.
对于两个集 A 和 B,Jaccard Index 定义为:

由于图像有像素组成,Jaccard Index 可以转换为:

${y_i}$ 和 ${\hat{y}_i}$ 分别为第 i 个像素的 groundtruth 二值标签和预测的概率值.

由于图像语义分割任务还可以看作像素分类问题,因此,新增了通用分类 loss 函数,记为 H.
对于二值分割问题,H 是二值交叉熵损失函数(binary cross entropy).
对于多类(multi-class) 分割问题,H 是类别交叉熵损失函数(categorical cross entropy).


作为模型的输出,会得到一张 mask 图片,其每一个像素值对应了属于某个类别的概率值.
输出图片尺寸与输入图片尺寸相匹配.
对于二值分割,采用 0.3 作为阈值来进行二值化像素概率值. 小于阈值 0.3 的像素设为 0. 大于阈值的像素设为 255,作为最终的预测 mask.
对于 multi-class 分割,采用类似的方法,但是对于每一类设置为不同的整数.

Results

对于二值分割,最佳的结果是 TernausNet-16,其 IoU=0.836,Dice=0.901.

预训练的模型 - models-googledrive

项目运行

依赖项:

  • Python 3.6
  • PyTorch 0.4.0
  • TorchVision 0.2.1
  • numpy 1.14.0
  • opencv-python 3.3.0.10
  • tqdm 4.19.4

如:

    sudo pip install pytorch==0.4.0

数据集组织

数据集文件夹形式:

 ├── data
    │   ├── cropped_train
    │   ├── models
    │   ├── test
    │   │   ├── instrument_dataset_1
    │   │   │   ├── left_frames
    │   │   │   └── right_frames
    |   |   .......................
    │   └── train
    │       ├── instrument_dataset_1
    │       │   ├── ground_truth
    │       │   │   ├── Left_Prograsp_Forceps_labels
    │       │   │   ├── Maryland_Bipolar_Forceps_labels
    │       │   │   ├── Other_labels
    │       │   │   └── Right_Prograsp_Forceps_labels
    │       │   ├── left_frames
    │       │   └── right_frames
    │       .......................

训练数据集只包含 8 个 videos,每个 video 包含 225 帧. 在每个 video 内,所有的帧是相关的,因此,实验中采用 4-fold 交叉验证.

数据预处理

  • prepare_data.py,合并不同 instruments 标注的 masks 为一个二值 mask;从图像和对应的 masks 裁剪空白边界.
  • prepare_train_val.py,数据集划分,4-fold 交叉验证.

模型训练.

python3 train.py --help 返回所有的可输入参数.

    import argparse
    import json
    from pathlib import Path
    from validation import validation_binary, validation_multi

    import torch
    from torch import nn
    from torch.optim import Adam
    from torch.utils.data import DataLoader
    import torch.backends.cudnn as cudnn
    import torch.backends.cudnn

    from models import UNet11, LinkNet34, UNet, UNet16, AlbuNet  ## 网络结构
    from loss import LossBinary, LossMulti  ## Loss 函数
    from dataset import RoboticsDataset   ## 数据集类
    import utils

    from prepare_train_val import get_split

    from albumentations import (
        HorizontalFlip, VerticalFlip, Normalize, Compose )  ## 数据增强库

    def main():
        parser = argparse.ArgumentParser()
        arg = parser.add_argument
        arg('--jaccard-weight', default=0.5, type=float)
        arg('--device-ids', type=str, default='0', help='For example 0,1 to run on two GPUs')
        arg('--fold', type=int, help='fold', default=0)
        arg('--root', default='runs/debug', help='checkpoint root')
        arg('--batch-size', type=int, default=1)
        arg('--n-epochs', type=int, default=100)
        arg('--lr', type=float, default=0.0001)
        arg('--workers', type=int, default=12)
        arg('--type', type=str, default='binary', choices=['binary', 'parts', 'instruments'])
        arg('--model', type=str, default='UNet', choices=['UNet', 'UNet11', 'LinkNet34', 'AlbuNet'])

        args = parser.parse_args()

        root = Path(args.root)
        root.mkdir(exist_ok=True, parents=True)

        if args.type == 'parts':
            num_classes = 4
        elif args.type == 'instruments':
            num_classes = 8
        else:
            num_classes = 1

        if args.model == 'UNet':
            model = UNet(num_classes=num_classes)
        elif args.model == 'UNet11':
            model = UNet11(num_classes=num_classes, pretrained=True)
        elif args.model == 'UNet16':
            model = UNet16(num_classes=num_classes, pretrained=True)
        elif args.model == 'LinkNet34':
            model = LinkNet34(num_classes=num_classes, pretrained=True)
        elif args.model == 'AlbuNet':
            model = AlbuNet(num_classes=num_classes, pretrained=True)
        else:
            model = UNet(num_classes=num_classes, input_channels=3)

        if torch.cuda.is_available():
            if args.device_ids:
                device_ids = list(map(int, args.device_ids.split(',')))
            else:
                device_ids = None
            model = nn.DataParallel(model, device_ids=device_ids).cuda()

        if args.type == 'binary':
            loss = LossBinary(jaccard_weight=args.jaccard_weight)
        else:
            loss = LossMulti(num_classes=num_classes, jaccard_weight=args.jaccard_weight)

        cudnn.benchmark = True

        def make_loader(file_names, shuffle=False, transform=None, problem_type='binary', batch_size=1):
            return DataLoader(
                dataset=RoboticsDataset(file_names, transform=transform, problem_type=problem_type),
                shuffle=shuffle,
                num_workers=args.workers,
                batch_size=batch_size,
                pin_memory=torch.cuda.is_available()
            )

        train_file_names, val_file_names = get_split(args.fold)

        print('num train = {}, num_val = {}'.format(len(train_file_names), len(val_file_names)))

        def train_transform(p=1):
            return Compose([VerticalFlip(p=0.5), HorizontalFlip(p=0.5),
                            Normalize(p=1)], p=p)

        def val_transform(p=1):
            return Compose([Normalize(p=1)], p=p)

        train_loader = make_loader(train_file_names, shuffle=True, transform=train_transform(p=1), 
                                   problem_type=args.type, batch_size=args.batch_size)
        valid_loader = make_loader(val_file_names, transform=val_transform(p=1), 
                                   problem_type=args.type, batch_size=len(device_ids))

        root.joinpath('params.json').write_text(
            json.dumps(vars(args), indent=True, sort_keys=True))

        if args.type == 'binary':
            valid = validation_binary
        else:
            valid = validation_multi

        utils.train(
            init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
            args=args,
            model=model,
            criterion=loss,
            train_loader=train_loader,
            valid_loader=valid_loader,
            validation=valid,
            fold=args.fold,
            num_classes=num_classes
        )

    if __name__ == '__main__':
        main()
    #!/bin/bash
    for i in 0 1 2 3
    do
        python train.py --device-ids 0,1,2,3 --batch-size 16 --fold $i --workers 12 --lr 0.0001 --n-epochs 10 --type binary --jaccard-weight 1
        python train.py --device-ids 0,1,2,3 --batch-size 16 --fold $i --workers 12 --lr 0.00001 --n-epochs 20 --type binary --jaccard-weight 1
    done

生成 Masks

python3 generate_masks.py --help 返回可设置输入参数.
如:

python3 generate_masks.py --output_path predictions/unet16/binary --model_type UNet16 --problem_type binary --model_path data/models/unet16_binary_20 --fold -1 --batch-size 4

结果评估 Evaluation

  • [1] - 二值语义分割,针对每张图像或每个 video,计算 Jaccard(Dice);然后平均化预测结果.
  • [2] - 多类语义分割,先对每张图像独立计算每类的 Jaccard(Dice)再求平均化;然后针对每个 video 计算.
python3 evaluate.py --target_path predictions/unet16 --problem_type binary --train_path data/cropped_train

Demo 实例

    import cv2
    import numpy as np
    import matplotlib.pyplot as plt

    import torch
    from utils import cuda

    from dataset import load_image
    from generate_masks import get_model

    from albumentations import Compose, Normalize
    from albumentations.torch.functional import img_to_tensor

    def img_transform(p=1):
        return Compose([Normalize(p=1)], p=p)

    def mask_overlay(image, mask, color=(0, 255, 0)):
        """
        Helper function to visualize mask on the top of the car
        """
        mask = np.dstack((mask, mask, mask)) * np.array(color)
        mask = mask.astype(np.uint8)
        weighted_sum = cv2.addWeighted(mask, 0.5, image, 0.5, 0.)
        img = image.copy()
        ind = mask[:, :, 1]
        img[ind] = weighted_sum[ind]
        return img

    model_path = 'models/unet11_binary_20/model_0.pt'
    model = get_model(model_path, model_type='UNet11', problem_type='binary')

    img_file_name = 'data/cropped_train/instrument_dataset_3/images/frame004.jpg'
    gt_file_name = 'data/cropped_train/instrument_dataset_3/binary_masks/frame004.png'

    image = load_image(img_file_name)
    gt = cv2.imread(gt_file_name, 0)

    plt.imshow(image)

    with torch.no_grad():
        input_image = torch.unsqueeze(img_to_tensor(img_transform(p=1)(image=image)['image']).cuda(), dim=0)

    mask = model(input_image)
    mask_array = mask.data[0].cpu().numpy()[0]
    plt.imshow(mask_array)

    plt.imshow(mask_overlay(image, (mask_array).astype(np.uint8)))
    plt.show()

    print('Done.')
Last modification:October 9th, 2018 at 05:38 pm