Shortcuts

Source code for mmdet.models.dense_heads.free_anchor_retina_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F

from mmdet.core import bbox_overlaps
from ..builder import HEADS
from .retina_head import RetinaHead

EPS = 1e-12


[docs]@HEADS.register_module() class FreeAnchorRetinaHead(RetinaHead): """FreeAnchor RetinaHead used in https://arxiv.org/abs/1909.02466. Args: num_classes (int): Number of categories excluding the background category. in_channels (int): Number of channels in the input feature map. stacked_convs (int): Number of conv layers in cls and reg tower. Default: 4. conv_cfg (dict): dictionary to construct and config conv layer. Default: None. norm_cfg (dict): dictionary to construct and config norm layer. Default: norm_cfg=dict(type='GN', num_groups=32, requires_grad=True). pre_anchor_topk (int): Number of boxes that be token in each bag. bbox_thr (float): The threshold of the saturated linear function. It is usually the same with the IoU threshold used in NMS. gamma (float): Gamma parameter in focal loss. alpha (float): Alpha parameter in focal loss. """ # noqa: W605 def __init__(self, num_classes, in_channels, stacked_convs=4, conv_cfg=None, norm_cfg=None, pre_anchor_topk=50, bbox_thr=0.6, gamma=2.0, alpha=0.5, **kwargs): super(FreeAnchorRetinaHead, self).__init__(num_classes, in_channels, stacked_convs, conv_cfg, norm_cfg, **kwargs) self.pre_anchor_topk = pre_anchor_topk self.bbox_thr = bbox_thr self.gamma = gamma self.alpha = alpha
[docs] def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas, gt_bboxes_ignore=None): """Compute losses of the head. Args: cls_scores (list[Tensor]): Box scores for each scale level Has shape (N, num_anchors * num_classes, H, W) bbox_preds (list[Tensor]): Box energies / deltas for each scale level with shape (N, num_anchors * 4, H, W) gt_bboxes (list[Tensor]): each item are the truth boxes for each image in [tl_x, tl_y, br_x, br_y] format. gt_labels (list[Tensor]): class indices corresponding to each box img_metas (list[dict]): Meta information of each image, e.g., image size, scaling factor, etc. gt_bboxes_ignore (None | list[Tensor]): specify which bounding boxes can be ignored when computing the loss. Returns: dict[str, Tensor]: A dictionary of loss components. """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] assert len(featmap_sizes) == self.prior_generator.num_levels anchor_list, _ = self.get_anchors(featmap_sizes, img_metas) anchors = [torch.cat(anchor) for anchor in anchor_list] # concatenate each level cls_scores = [ cls.permute(0, 2, 3, 1).reshape(cls.size(0), -1, self.cls_out_channels) for cls in cls_scores ] bbox_preds = [ bbox_pred.permute(0, 2, 3, 1).reshape(bbox_pred.size(0), -1, 4) for bbox_pred in bbox_preds ] cls_scores = torch.cat(cls_scores, dim=1) bbox_preds = torch.cat(bbox_preds, dim=1) cls_prob = torch.sigmoid(cls_scores) box_prob = [] num_pos = 0 positive_losses = [] for _, (anchors_, gt_labels_, gt_bboxes_, cls_prob_, bbox_preds_) in enumerate( zip(anchors, gt_labels, gt_bboxes, cls_prob, bbox_preds)): with torch.no_grad(): if len(gt_bboxes_) == 0: image_box_prob = torch.zeros( anchors_.size(0), self.cls_out_channels).type_as(bbox_preds_) else: # box_localization: a_{j}^{loc}, shape: [j, 4] pred_boxes = self.bbox_coder.decode(anchors_, bbox_preds_) # object_box_iou: IoU_{ij}^{loc}, shape: [i, j] object_box_iou = bbox_overlaps(gt_bboxes_, pred_boxes) # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j] t1 = self.bbox_thr t2 = object_box_iou.max( dim=1, keepdim=True).values.clamp(min=t1 + 1e-12) object_box_prob = ((object_box_iou - t1) / (t2 - t1)).clamp( min=0, max=1) # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j] num_obj = gt_labels_.size(0) indices = torch.stack([ torch.arange(num_obj).type_as(gt_labels_), gt_labels_ ], dim=0) object_cls_box_prob = torch.sparse_coo_tensor( indices, object_box_prob) # image_box_iou: P{a_{j} \in A_{+}}, shape: [c, j] """ from "start" to "end" implement: image_box_iou = torch.sparse.max(object_cls_box_prob, dim=0).t() """ # start box_cls_prob = torch.sparse.sum( object_cls_box_prob, dim=0).to_dense() indices = torch.nonzero(box_cls_prob, as_tuple=False).t_() if indices.numel() == 0: image_box_prob = torch.zeros( anchors_.size(0), self.cls_out_channels).type_as(object_box_prob) else: nonzero_box_prob = torch.where( (gt_labels_.unsqueeze(dim=-1) == indices[0]), object_box_prob[:, indices[1]], torch.tensor([ 0 ]).type_as(object_box_prob)).max(dim=0).values # upmap to shape [j, c] image_box_prob = torch.sparse_coo_tensor( indices.flip([0]), nonzero_box_prob, size=(anchors_.size(0), self.cls_out_channels)).to_dense() # end box_prob.append(image_box_prob) # construct bags for objects match_quality_matrix = bbox_overlaps(gt_bboxes_, anchors_) _, matched = torch.topk( match_quality_matrix, self.pre_anchor_topk, dim=1, sorted=False) del match_quality_matrix # matched_cls_prob: P_{ij}^{cls} matched_cls_prob = torch.gather( cls_prob_[matched], 2, gt_labels_.view(-1, 1, 1).repeat(1, self.pre_anchor_topk, 1)).squeeze(2) # matched_box_prob: P_{ij}^{loc} matched_anchors = anchors_[matched] matched_object_targets = self.bbox_coder.encode( matched_anchors, gt_bboxes_.unsqueeze(dim=1).expand_as(matched_anchors)) loss_bbox = self.loss_bbox( bbox_preds_[matched], matched_object_targets, reduction_override='none').sum(-1) matched_box_prob = torch.exp(-loss_bbox) # positive_losses: {-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )} num_pos += len(gt_bboxes_) positive_losses.append( self.positive_bag_loss(matched_cls_prob, matched_box_prob)) positive_loss = torch.cat(positive_losses).sum() / max(1, num_pos) # box_prob: P{a_{j} \in A_{+}} box_prob = torch.stack(box_prob, dim=0) # negative_loss: # \sum_{j}{ FL((1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg})) } / n||B|| negative_loss = self.negative_bag_loss(cls_prob, box_prob).sum() / max( 1, num_pos * self.pre_anchor_topk) # avoid the absence of gradients in regression subnet # when no ground-truth in a batch if num_pos == 0: positive_loss = bbox_preds.sum() * 0 losses = { 'positive_bag_loss': positive_loss, 'negative_bag_loss': negative_loss } return losses
[docs] def positive_bag_loss(self, matched_cls_prob, matched_box_prob): """Compute positive bag loss. :math:`-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )`. :math:`P_{ij}^{cls}`: matched_cls_prob, classification probability of matched samples. :math:`P_{ij}^{loc}`: matched_box_prob, box probability of matched samples. Args: matched_cls_prob (Tensor): Classification probability of matched samples in shape (num_gt, pre_anchor_topk). matched_box_prob (Tensor): BBox probability of matched samples, in shape (num_gt, pre_anchor_topk). Returns: Tensor: Positive bag loss in shape (num_gt,). """ # noqa: E501, W605 # bag_prob = Mean-max(matched_prob) matched_prob = matched_cls_prob * matched_box_prob weight = 1 / torch.clamp(1 - matched_prob, 1e-12, None) weight /= weight.sum(dim=1).unsqueeze(dim=-1) bag_prob = (weight * matched_prob).sum(dim=1) # positive_bag_loss = -self.alpha * log(bag_prob) return self.alpha * F.binary_cross_entropy( bag_prob, torch.ones_like(bag_prob), reduction='none')
[docs] def negative_bag_loss(self, cls_prob, box_prob): """Compute negative bag loss. :math:`FL((1 - P_{a_{j} \in A_{+}}) * (1 - P_{j}^{bg}))`. :math:`P_{a_{j} \in A_{+}}`: Box_probability of matched samples. :math:`P_{j}^{bg}`: Classification probability of negative samples. Args: cls_prob (Tensor): Classification probability, in shape (num_img, num_anchors, num_classes). box_prob (Tensor): Box probability, in shape (num_img, num_anchors, num_classes). Returns: Tensor: Negative bag loss in shape (num_img, num_anchors, num_classes). """ # noqa: E501, W605 prob = cls_prob * (1 - box_prob) # There are some cases when neg_prob = 0. # This will cause the neg_prob.log() to be inf without clamp. prob = prob.clamp(min=EPS, max=1 - EPS) negative_bag_loss = prob**self.gamma * F.binary_cross_entropy( prob, torch.zeros_like(prob), reduction='none') return (1 - self.alpha) * negative_bag_loss
Read the Docs v: v2.18.1
Versions
latest
stable
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.