Shortcuts

Source code for mmdet.models.losses.seesaw_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .accuracy import accuracy
from .cross_entropy_loss import cross_entropy
from .utils import weight_reduce_loss


def seesaw_ce_loss(cls_score,
                   labels,
                   label_weights,
                   cum_samples,
                   num_classes,
                   p,
                   q,
                   eps,
                   reduction='mean',
                   avg_factor=None):
    """Calculate the Seesaw CrossEntropy loss.

    The logits of the non-target classes are shifted by the log of a
    per-sample, per-class "seesaw" weight before a plain cross entropy is
    applied. The weight is the product of two factors:

    - mitigation factor (enabled when ``p > 0``): for a sample of class
      ``i``, class ``j`` is weighted by ``(n_j / n_i) ** p`` when
      ``n_j < n_i`` and by 1 otherwise, where ``n`` is ``cum_samples``
      clamped to a minimum of 1.
    - compensation factor (enabled when ``q > 0``): class ``j`` is weighted
      by ``(s_j / s_i) ** q`` when ``s_j > s_i`` and by 1 otherwise, where
      ``s`` is the detached softmax of ``cls_score`` and the divisor ``s_i``
      is clamped to a minimum of ``eps``.

    Args:
        cls_score (torch.Tensor): The prediction with shape (N, C),
             C is the number of classes.
        labels (torch.Tensor): The learning label of the prediction.
            Any integer dtype is accepted; it is cast to int64 internally.
        label_weights (torch.Tensor): Sample-wise loss weight. May be None.
        cum_samples (torch.Tensor): Cumulative samples for each category.
        num_classes (int): The number of classes.
        p (float): The ``p`` in the mitigation factor.
        q (float): The ``q`` in the compensation factor.
        eps (float): The minimal value of divisor to smooth
             the computation of compensation factor
        reduction (str, optional): The method used to reduce the loss.
            Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The calculated loss
    """
    assert cls_score.size(-1) == num_classes
    assert len(cum_samples) == num_classes

    # ``F.one_hot``, tensor indexing and ``F.cross_entropy`` all expect int64
    # labels, so cast once here instead of only at the indexing sites.
    labels = labels.long()

    onehot_labels = F.one_hot(labels, num_classes)
    seesaw_weights = cls_score.new_ones(onehot_labels.size())

    # mitigation factor
    if p > 0:
        # sample_ratio_matrix[i, j] = n_j / n_i
        sample_ratio_matrix = cum_samples[None, :].clamp(
            min=1) / cum_samples[:, None].clamp(min=1)
        index = (sample_ratio_matrix < 1.0).float()
        sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index)
        # pick, for every sample, the row that belongs to its label
        mitigation_factor = sample_weights[labels, :]
        seesaw_weights = seesaw_weights * mitigation_factor

    # compensation factor
    if q > 0:
        # detached: the factor rescales the loss but receives no gradient
        scores = F.softmax(cls_score.detach(), dim=1)
        # create the row index directly on the target device
        row_inds = torch.arange(len(scores), device=scores.device)
        self_scores = scores[row_inds, labels]
        score_matrix = scores / self_scores[:, None].clamp(min=eps)
        index = (score_matrix > 1.0).float()
        compensation_factor = score_matrix.pow(q) * index + (1 - index)
        seesaw_weights = seesaw_weights * compensation_factor

    # only the non-target logits are shifted; the target logit is untouched
    cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels))

    loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none')

    if label_weights is not None:
        label_weights = label_weights.float()
    loss = weight_reduce_loss(
        loss, weight=label_weights, reduction=reduction, avg_factor=avg_factor)
    return loss


@LOSSES.register_module()
class SeesawLoss(nn.Module):
    """
    Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021)
    arXiv: https://arxiv.org/abs/2008.10032

    The classifier is expected to output ``num_classes + 2`` channels: the
    first ``num_classes`` channels are class logits and the last two are
    objectness logits (see ``_split_cls_score``).

    Args:
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
             or softmax. Only False is supported.
        p (float, optional): The ``p`` in the mitigation factor.
             Defaults to 0.8.
        q (float, optional): The ``q`` in the compensation factor.
             Defaults to 2.0.
        num_classes (int, optional): The number of classes.
             Default to 1203 for LVIS v1 dataset.
        eps (float, optional): The minimal value of divisor to smooth
             the computation of compensation factor. Defaults to 1e-2.
        reduction (str, optional): The method that reduces the loss to a
             scalar. Options are "none", "mean" and "sum".
             Defaults to "mean".
        loss_weight (float, optional): The weight of the loss. Defaults to 1.0
        return_dict (bool, optional): Whether return the losses as a dict.
             Default to True.
    """

    def __init__(self,
                 use_sigmoid=False,
                 p=0.8,
                 q=2.0,
                 num_classes=1203,
                 eps=1e-2,
                 reduction='mean',
                 loss_weight=1.0,
                 return_dict=True):
        super().__init__()
        assert not use_sigmoid
        self.use_sigmoid = False
        self.p = p
        self.q = q
        self.num_classes = num_classes
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.return_dict = return_dict

        # criterion applied to the class logits of the positive samples
        self.cls_criterion = seesaw_ce_loss

        # cumulative samples for each category; the extra last slot is
        # indexed by the label value ``num_classes``
        self.register_buffer(
            'cum_samples',
            torch.zeros(self.num_classes + 1, dtype=torch.float))

        # custom output channels of the classifier
        self.custom_cls_channels = True
        # custom activation of cls_score
        self.custom_activation = True
        # custom accuracy of the classifier
        self.custom_accuracy = True

    def _split_cls_score(self, cls_score):
        """Split ``cls_score`` into class logits and objectness logits.

        Args:
            cls_score (torch.Tensor): The prediction whose last dimension
                has ``num_classes + 2`` channels.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The first ``num_classes``
            channels (class logits) and the last two channels
            (objectness logits).
        """
        assert cls_score.size(-1) == self.num_classes + 2
        cls_score_classes = cls_score[..., :-2]
        cls_score_objectness = cls_score[..., -2:]
        return cls_score_classes, cls_score_objectness

    def get_cls_channels(self, num_classes):
        """Get custom classification channels.

        Args:
            num_classes (int): The number of classes.

        Returns:
            int: The custom classification channels.
        """
        assert num_classes == self.num_classes
        return num_classes + 2

    def get_activation(self, cls_score):
        """Get custom activation of cls_score.

        Args:
            cls_score (torch.Tensor): The prediction with shape (N, C + 2).

        Returns:
            torch.Tensor: The custom activation of cls_score with shape
                 (N, C + 1).
        """
        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        score_classes = F.softmax(cls_score_classes, dim=-1)
        score_objectness = F.softmax(cls_score_objectness, dim=-1)
        # objectness channel 0 is "positive", channel 1 is "negative"
        score_pos = score_objectness[..., [0]]
        score_neg = score_objectness[..., [1]]
        # class probabilities are scaled by the positive objectness score
        score_classes = score_classes * score_pos
        scores = torch.cat([score_classes, score_neg], dim=-1)
        return scores

    def get_accuracy(self, cls_score, labels):
        """Get custom accuracy w.r.t. cls_score and labels.

        Args:
            cls_score (torch.Tensor): The prediction with shape (N, C + 2).
            labels (torch.Tensor): The learning label of the prediction.

        Returns:
            Dict [str, torch.Tensor]: The accuracy for objectness and classes,
                 respectively.
        """
        pos_inds = labels < self.num_classes
        # 0 for pos, 1 for neg
        obj_labels = (labels == self.num_classes).long()
        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        acc_objectness = accuracy(cls_score_objectness, obj_labels)
        # class accuracy is measured on the positive samples only
        acc_classes = accuracy(cls_score_classes[pos_inds], labels[pos_inds])
        acc = dict()
        acc['acc_objectness'] = acc_objectness
        acc['acc_classes'] = acc_classes
        return acc

    def forward(self,
                cls_score,
                labels,
                label_weights=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Besides computing the loss, every call adds the per-label sample
        counts of ``labels`` to the ``cum_samples`` buffer.

        Args:
            cls_score (torch.Tensor): The prediction with shape (N, C + 2).
            labels (torch.Tensor): The learning label of the prediction.
            label_weights (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to average
                 the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce the
                 loss; when given it takes precedence over the ``reduction``
                 set at construction. Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor | Dict [str, torch.Tensor]:
                 if return_dict == False: The calculated loss |
                 if return_dict == True: The dict of calculated losses
                 for objectness and classes, respectively.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        assert cls_score.size(-1) == self.num_classes + 2
        pos_inds = labels < self.num_classes
        # 0 for pos, 1 for neg
        obj_labels = (labels == self.num_classes).long()

        # accumulate the samples for each category; a single
        # ``unique(return_counts=True)`` avoids a Python loop with one
        # ``.item()`` call per distinct label
        unique_labels, counts = labels.unique(return_counts=True)
        self.cum_samples[unique_labels.long()] += counts.to(
            self.cum_samples.dtype)

        if label_weights is not None:
            label_weights = label_weights.float()
        else:
            label_weights = labels.new_ones(labels.size(), dtype=torch.float)

        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        # calculate loss_cls_classes (only need pos samples)
        if pos_inds.sum() > 0:
            loss_cls_classes = self.loss_weight * self.cls_criterion(
                cls_score_classes[pos_inds], labels[pos_inds],
                label_weights[pos_inds], self.cum_samples[:self.num_classes],
                self.num_classes, self.p, self.q, self.eps, reduction,
                avg_factor)
        else:
            # no positive sample: sum over an empty selection, which still
            # yields a tensor derived from ``cls_score``
            loss_cls_classes = cls_score_classes[pos_inds].sum()
        # calculate loss_cls_objectness
        loss_cls_objectness = self.loss_weight * cross_entropy(
            cls_score_objectness, obj_labels, label_weights, reduction,
            avg_factor)

        if self.return_dict:
            loss_cls = dict()
            loss_cls['loss_cls_objectness'] = loss_cls_objectness
            loss_cls['loss_cls_classes'] = loss_cls_classes
        else:
            loss_cls = loss_cls_classes + loss_cls_objectness
        return loss_cls
Read the Docs v: v2.18.1
Versions
latest
stable
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.