Shortcuts

Source code for mmdet.models.losses.gfocal_loss

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weighted_loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def quality_focal_loss(pred, target, beta=2.0):
    r"""Compute the Quality Focal Loss (QFL).

    QFL is introduced in `Generalized Focal Loss: Learning Qualified and
    Distributed Bounding Boxes for Dense Object Detection
    <https://arxiv.org/abs/2006.04388>`_.

    Every (sample, class) entry is first scored as a negative, i.e. against
    a zero quality target. The entry of each foreground sample at its own
    category is then overwritten with a term supervised by the quality
    score.

    Args:
        pred (torch.Tensor): Predicted joint representation of classification
            and quality (IoU) estimation with shape (N, C), C is the number of
            classes.
        target (tuple([torch.Tensor])): Two elements: the target category
            label with shape (N,) and the target quality label with
            shape (N,).
        beta (float): Exponent applied to the modulating factor.
            Defaults to 2.0.

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
    """
    assert len(target) == 2, """target for QFL must be a tuple of two elements,
        including category label and quality label, respectively"""
    # First element is the category id, second is the quality score.
    category, quality = target

    prob = pred.sigmoid()

    # Negative term: binary cross entropy against an all-zero target,
    # modulated by sigmoid(pred) ** beta.
    all_zero = prob.new_zeros(pred.shape)
    negative_bce = F.binary_cross_entropy_with_logits(
        pred, all_zero, reduction='none')
    loss = negative_bce * prob.pow(beta)

    # Foreground ids lie in [0, C - 1]; the id C marks background.
    num_classes = pred.size(1)
    is_foreground = (category >= 0) & (category < num_classes)
    fg_rows = is_foreground.nonzero().squeeze(1)
    fg_cols = category[fg_rows].long()

    # Positive term: binary cross entropy against the quality score,
    # modulated by |quality - sigmoid(pred)| ** beta.
    fg_quality = quality[fg_rows]
    gap = fg_quality - prob[fg_rows, fg_cols]
    positive_bce = F.binary_cross_entropy_with_logits(
        pred[fg_rows, fg_cols], fg_quality, reduction='none')
    loss[fg_rows, fg_cols] = positive_bce * gap.abs().pow(beta)

    # Collapse the class dimension to get one value per sample.
    return loss.sum(dim=1, keepdim=False)


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def distribution_focal_loss(pred, label):
    r"""Compute the Distribution Focal Loss (DFL).

    DFL is introduced in `Generalized Focal Loss: Learning Qualified and
    Distributed Bounding Boxes for Dense Object Detection
    <https://arxiv.org/abs/2006.04388>`_.

    The continuous ``label`` is bracketed by the integer bin obtained from
    ``label.long()`` and the bin immediately after it. The loss is the sum
    of the cross entropies towards those two bins, each weighted by the
    distance from ``label`` to the *other* bin.

    Args:
        pred (torch.Tensor): Predicted general distribution of bounding boxes
            (before softmax) with shape (N, n+1), n is the max value of the
            integral set `{0, ..., n}` in paper.
        label (torch.Tensor): Target distance label for bounding boxes with
            shape (N,).

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
    """
    lower_bin = label.long()
    upper_bin = lower_bin + 1

    # The closer the label is to a bin, the larger that bin's weight.
    lower_weight = upper_bin.float() - label
    upper_weight = label - lower_bin.float()

    lower_ce = F.cross_entropy(pred, lower_bin, reduction='none')
    upper_ce = F.cross_entropy(pred, upper_bin, reduction='none')
    return lower_ce * lower_weight + upper_ce * upper_weight


@LOSSES.register_module()
class QualityFocalLoss(nn.Module):
    r"""Quality Focal Loss (QFL) is a variant of `Generalized Focal Loss:
    Learning Qualified and Distributed Bounding Boxes for Dense Object
    Detection <https://arxiv.org/abs/2006.04388>`_.

    Args:
        use_sigmoid (bool): Whether sigmoid operation is conducted in QFL.
            Defaults to True.
        beta (float): The beta parameter for calculating the modulating factor.
            Defaults to 2.0.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Loss weight of current loss.
    """

    def __init__(self,
                 use_sigmoid=True,
                 beta=2.0,
                 reduction='mean',
                 loss_weight=1.0):
        super(QualityFocalLoss, self).__init__()
        # Only the sigmoid formulation is implemented; anything else is
        # rejected at construction time.
        assert use_sigmoid is True, 'Only sigmoid in QFL supported now.'
        self.use_sigmoid = use_sigmoid
        self.beta = beta
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): Predicted joint representation of
                classification and quality (IoU) estimation with shape (N, C),
                C is the number of classes.
            target (tuple([torch.Tensor])): Target category label with shape
                (N,) and target quality label with shape (N,).
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The result of ``quality_focal_loss`` scaled by
            ``self.loss_weight``.

        Raises:
            NotImplementedError: If ``self.use_sigmoid`` is falsy.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A per-call override takes precedence over the configured reduction.
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            loss_cls = self.loss_weight * quality_focal_loss(
                pred,
                target,
                weight,
                beta=self.beta,
                reduction=reduction,
                avg_factor=avg_factor)
        else:
            raise NotImplementedError
        return loss_cls
@LOSSES.register_module()
class DistributionFocalLoss(nn.Module):
    r"""Distribution Focal Loss (DFL) is a variant of `Generalized Focal Loss:
    Learning Qualified and Distributed Bounding Boxes for Dense Object
    Detection <https://arxiv.org/abs/2006.04388>`_.

    Args:
        reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
        loss_weight (float): Loss weight of current loss.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(DistributionFocalLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): Predicted general distribution of bounding
                boxes (before softmax) with shape (N, n+1), n is the max value
                of the integral set `{0, ..., n}` in paper.
            target (torch.Tensor): Target distance label for bounding boxes
                with shape (N,).
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The result of ``distribution_focal_loss`` scaled by
            ``self.loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A per-call override takes precedence over the configured reduction.
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_cls = self.loss_weight * distribution_focal_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_cls
Read the Docs v: v2.19.1
Versions
latest
stable
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.