Shortcuts

Source code for mmdet.models.losses.ghm_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss


def _expand_onehot_labels(labels, label_weights, label_channels):
    """Turn integer class labels into one-hot rows with matching weights.

    Args:
        labels (Tensor): Class index per sample. Entries outside
            ``[0, label_channels)`` yield an all-zero row.
        label_weights (Tensor): One weight per sample.
        label_channels (int): Number of columns in the one-hot output.

    Returns:
        tuple[Tensor, Tensor]: The one-hot labels with shape
        ``(labels.size(0), label_channels)`` (same dtype and device as
        ``labels``), and the weights expanded to
        ``(label_weights.size(0), label_channels)``.
    """
    num_samples = labels.size(0)
    onehot = labels.new_full((num_samples, label_channels), 0)

    # Only labels that name a real channel get a 1; the rest stay all-zero.
    in_range = (labels >= 0) & (labels < label_channels)
    rows = torch.nonzero(in_range, as_tuple=False).squeeze()
    if rows.numel() > 0:
        onehot[rows, labels[rows]] = 1

    # Repeat each per-sample weight across every channel (a view, no copy).
    expanded_weights = label_weights.view(-1, 1).expand(
        label_weights.size(0), label_channels)
    return onehot, expanded_weights


# TODO: code refactoring to make it consistent with other losses
@LOSSES.register_module()
class GHMC(nn.Module):
    """GHM Classification Loss.

    Details of the theorem can be viewed in the paper
    `Gradient Harmonized Single-stage Detector
    <https://arxiv.org/abs/1811.05181>`_.

    Each valid sample is weighted inversely to how many valid samples share
    its gradient-length bin, and the weighted binary cross entropy is then
    reduced.

    Args:
        bins (int): Number of the unit regions for distribution calculation.
        momentum (float): The parameter for moving average. When it is
            greater than 0, a running per-bin count (``acc_sum``) is kept as
            a buffer and used in place of the per-batch count.
        use_sigmoid (bool): Can only be true for BCE based loss now;
            ``False`` raises ``NotImplementedError``.
        loss_weight (float): The weight of the total GHM-C loss.
        reduction (str): Options are "none", "mean" and "sum".
            Defaults to "mean"
    """

    def __init__(self,
                 bins=10,
                 momentum=0,
                 use_sigmoid=True,
                 loss_weight=1.0,
                 reduction='mean'):
        super(GHMC, self).__init__()
        self.bins = bins
        self.momentum = momentum
        # Bin boundaries 0, 1/bins, ..., 1 over the gradient length.
        edges = torch.arange(bins + 1).float() / bins
        self.register_buffer('edges', edges)
        # The upper bound of each bin is exclusive (g < edges[i + 1]), so
        # nudge the last edge up to let g == 1 land in the final bin.
        self.edges[-1] += 1e-6
        if momentum > 0:
            # Running per-bin sample counts, only needed with momentum.
            acc_sum = torch.zeros(bins)
            self.register_buffer('acc_sum', acc_sum)
        self.use_sigmoid = use_sigmoid
        if not self.use_sigmoid:
            raise NotImplementedError
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self,
                pred,
                target,
                label_weight,
                reduction_override=None,
                **kwargs):
        """Calculate the GHM-C loss.

        Args:
            pred (float tensor of size [batch_num, class_num]):
                The direct prediction of classification fc layer.
            target (float tensor of size [batch_num, class_num]):
                Binary class target for each sample. If its number of
                dimensions differs from ``pred``, it is treated as class
                indices and expanded to one-hot form.
            label_weight (float tensor of size [batch_num, class_num]):
                the value is 1 if the sample is valid and 0 if ignored.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
            **kwargs: Accepted and ignored.

        Returns:
            The gradient harmonized loss.

        Note:
            With ``momentum > 0`` this call updates the ``acc_sum`` buffer
            in place.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        # the target should be binary class label
        if pred.dim() != target.dim():
            target, label_weight = _expand_onehot_labels(
                target, label_weight, pred.size(-1))
        target, label_weight = target.float(), label_weight.float()
        edges = self.edges
        mmt = self.momentum
        weights = torch.zeros_like(pred)

        # gradient length (detached: the weights carry no gradient)
        g = torch.abs(pred.sigmoid().detach() - target)

        valid = label_weight > 0
        # Number of valid entries, floored at 1 to avoid dividing by zero.
        tot = max(valid.float().sum().item(), 1.0)
        n = 0  # n valid bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                if mmt > 0:
                    # Moving average of the bin population; bins that are
                    # empty in this batch keep their previous value.
                    self.acc_sum[i] = (
                        mmt * self.acc_sum[i] + (1 - mmt) * num_in_bin)
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
                n += 1
        if n > 0:
            # Normalise by the number of non-empty bins.
            weights = weights / n

        loss = F.binary_cross_entropy_with_logits(
            pred, target, reduction='none')
        loss = weight_reduce_loss(
            loss, weights, reduction=reduction, avg_factor=tot)
        return loss * self.loss_weight


# TODO: code refactoring to make it consistent with other losses
@LOSSES.register_module()
class GHMR(nn.Module):
    """GHM Regression Loss.

    Details of the theorem can be viewed in the paper
    `Gradient Harmonized Single-stage Detector
    <https://arxiv.org/abs/1811.05181>`_.

    Each valid element is weighted inversely to how many valid elements
    share its gradient-length bin, and the weighted Authentic Smooth L1
    loss is then reduced.

    Args:
        mu (float): The parameter for the Authentic Smooth L1 loss.
        bins (int): Number of the unit regions for distribution calculation.
        momentum (float): The parameter for moving average. When it is
            greater than 0, a running per-bin count (``acc_sum``) is kept as
            a buffer and used in place of the per-batch count.
        loss_weight (float): The weight of the total GHM-R loss.
        reduction (str): Options are "none", "mean" and "sum".
            Defaults to "mean"
    """

    def __init__(self,
                 mu=0.02,
                 bins=10,
                 momentum=0,
                 loss_weight=1.0,
                 reduction='mean'):
        super(GHMR, self).__init__()
        self.mu = mu
        self.bins = bins
        # Bin boundaries 0, 1/bins, ..., 1 over the gradient length.
        edges = torch.arange(bins + 1).float() / bins
        self.register_buffer('edges', edges)
        # Large last edge: the final bin effectively has no upper limit.
        self.edges[-1] = 1e3
        self.momentum = momentum
        if momentum > 0:
            # Running per-bin sample counts, only needed with momentum.
            acc_sum = torch.zeros(bins)
            self.register_buffer('acc_sum', acc_sum)
        self.loss_weight = loss_weight
        self.reduction = reduction

    # TODO: support reduction parameter
    def forward(self,
                pred,
                target,
                label_weight,
                avg_factor=None,
                reduction_override=None):
        """Calculate the GHM-R loss.

        Args:
            pred (float tensor of size [batch_num, 4 (* class_num)]):
                The prediction of box regression layer. Channel number can be
                4 or 4 * class_num depending on whether it is class-agnostic.
            target (float tensor of size [batch_num, 4 (* class_num)]):
                The target regression values with the same size of pred.
            label_weight (float tensor of size [batch_num, 4 (* class_num)]):
                The weight of each sample, 0 if ignored.
            avg_factor (optional): Currently unused by this method; the
                reduction always receives ``avg_factor=tot`` (the sum of
                ``label_weight``, floored at 1.0).
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            The gradient harmonized loss.

        Note:
            With ``momentum > 0`` this call updates the ``acc_sum`` buffer
            in place.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        mu = self.mu
        edges = self.edges
        mmt = self.momentum

        # ASL1 loss
        diff = pred - target
        loss = torch.sqrt(diff * diff + mu * mu) - mu

        # gradient length (detached: the weights carry no gradient)
        g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach()
        weights = torch.zeros_like(g)

        valid = label_weight > 0
        # Sum of the weights, floored at 1 to avoid dividing by zero.
        tot = max(label_weight.float().sum().item(), 1.0)
        n = 0  # n: valid bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                n += 1
                if mmt > 0:
                    # Moving average of the bin population; bins that are
                    # empty in this batch keep their previous value.
                    self.acc_sum[i] = (
                        mmt * self.acc_sum[i] + (1 - mmt) * num_in_bin)
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
        if n > 0:
            # Normalise by the number of non-empty bins.
            weights /= n
        loss = weight_reduce_loss(
            loss, weights, reduction=reduction, avg_factor=tot)
        return loss * self.loss_weight
Read the Docs v: v2.18.1
Versions
latest
stable
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.