From: GitHub - youngwanLEE/centermask2
基于 PyTorch 的 IOULoss 实现（比较清晰明了）。
import torch
from torch import nn
class IOULoss(nn.Module):
    """IoU-family regression loss for boxes given as four side offsets.

    Both ``pred`` and ``target`` are indexed as ``[:, 0..3]`` and those
    columns are read as (left, top, right, bottom).  Width is computed as
    ``left + right`` and height as ``top + bottom``, i.e. the four values are
    treated as distances from a shared reference location to the box sides
    rather than as corner coordinates.

    Args:
        loc_loss_type: Which loss to compute in ``forward``:
            ``'iou'`` (``-log(IoU)``), ``'linear_iou'`` (``1 - IoU``) or
            ``'giou'`` (``1 - GIoU``).  The value is only checked when
            ``forward`` runs.
    """

    # Lower bound for the enclosing-box area so the GIoU penalty never
    # divides by zero when both boxes are degenerate.
    _EPS = 1e-7

    def __init__(self, loc_loss_type='iou'):
        super().__init__()
        self.loc_loss_type = loc_loss_type

    def forward(self, pred, target, weight=None):
        """Return the summed (optionally weighted) loss over all boxes.

        Args:
            pred: Predicted offsets; indexed as ``pred[:, 0..3]``
                (presumably shape ``(N, 4)`` -- confirm against the caller).
            target: Ground-truth offsets, indexed the same way as ``pred``.
            weight: Optional per-box weights multiplied into the per-box
                losses before summation.

        Returns:
            A scalar tensor: ``sum(losses * weight)`` when ``weight`` is
            given, otherwise ``sum(losses)``.

        Raises:
            NotImplementedError: If ``loc_loss_type`` is not one of
                ``'iou'``, ``'linear_iou'`` or ``'giou'``.
        """
        pred_left = pred[:, 0]
        pred_top = pred[:, 1]
        pred_right = pred[:, 2]
        pred_bottom = pred[:, 3]

        target_left = target[:, 0]
        target_top = target[:, 1]
        target_right = target[:, 2]
        target_bottom = target[:, 3]

        target_area = (target_left + target_right) * (target_top + target_bottom)
        pred_area = (pred_left + pred_right) * (pred_top + pred_bottom)

        # Overlap extent: on each side, the smaller of the two offsets.
        w_intersect = (torch.min(pred_left, target_left)
                       + torch.min(pred_right, target_right))
        h_intersect = (torch.min(pred_bottom, target_bottom)
                       + torch.min(pred_top, target_top))

        area_intersect = w_intersect * h_intersect
        area_union = target_area + pred_area - area_intersect

        # The +1.0 on numerator and denominator keeps the ratio finite (and
        # the log below defined) when the union area is zero.
        ious = (area_intersect + 1.0) / (area_union + 1.0)

        if self.loc_loss_type == 'iou':
            losses = -torch.log(ious)
        elif self.loc_loss_type == 'linear_iou':
            losses = 1 - ious
        elif self.loc_loss_type == 'giou':
            # Smallest enclosing box: on each side, the larger offset.
            enclose_w = (torch.max(pred_left, target_left)
                         + torch.max(pred_right, target_right))
            enclose_h = (torch.max(pred_bottom, target_bottom)
                         + torch.max(pred_top, target_top))
            enclose_area = enclose_w * enclose_h
            # Clamp the denominator: a zero-area enclosing box would
            # otherwise produce 0/0 = NaN and poison the summed loss.
            gious = ious - (enclose_area - area_union) / enclose_area.clamp(
                min=self._EPS)
            losses = 1 - gious
        else:
            raise NotImplementedError

        if weight is not None:
            return (losses * weight).sum()
        return losses.sum()