Source code for mmdet.models.dense_heads.pisa_retinanet_head

import torch
from mmcv.runner import force_fp32

from mmdet.core import images_to_levels
from ..builder import HEADS
from ..losses import carl_loss, isr_p
from .retina_head import RetinaHead


[docs]@HEADS.register_module() class PISARetinaHead(RetinaHead): """PISA Retinanet Head. The head owns the same structure with Retinanet Head, but differs in two aspects: 1. Importance-based Sample Reweighting Positive (ISR-P) is applied to change the positive loss weights. 2. Classification-aware regression loss is adopted as a third loss. """
[docs] @force_fp32(apply_to=('cls_scores', 'bbox_preds')) def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas, gt_bboxes_ignore=None): """Compute losses of the head. Args: cls_scores (list[Tensor]): Box scores for each scale level Has shape (N, num_anchors * num_classes, H, W) bbox_preds (list[Tensor]): Box energies / deltas for each scale level with shape (N, num_anchors * 4, H, W) gt_bboxes (list[Tensor]): Ground truth bboxes of each image with shape (num_obj, 4). gt_labels (list[Tensor]): Ground truth labels of each image with shape (num_obj, 4). img_metas (list[dict]): Meta information of each image, e.g., image size, scaling factor, etc. gt_bboxes_ignore (list[Tensor]): Ignored gt bboxes of each image. Default: None. Returns: dict: Loss dict, comprise classification loss, regression loss and carl loss. """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] assert len(featmap_sizes) == self.anchor_generator.num_levels device = cls_scores[0].device anchor_list, valid_flag_list = self.get_anchors( featmap_sizes, img_metas, device=device) label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 cls_reg_targets = self.get_targets( anchor_list, valid_flag_list, gt_bboxes, img_metas, gt_bboxes_ignore_list=gt_bboxes_ignore, gt_labels_list=gt_labels, label_channels=label_channels, return_sampling_results=True) if cls_reg_targets is None: return None (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, num_total_pos, num_total_neg, sampling_results_list) = cls_reg_targets num_total_samples = ( num_total_pos + num_total_neg if self.sampling else num_total_pos) # anchor number of multi levels num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] # concat all level anchors and flags to a single tensor concat_anchor_list = [] for i in range(len(anchor_list)): concat_anchor_list.append(torch.cat(anchor_list[i])) all_anchor_list = images_to_levels(concat_anchor_list, num_level_anchors) num_imgs = len(img_metas) flatten_cls_scores = [ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, label_channels) for cls_score in cls_scores ] flatten_cls_scores = torch.cat( flatten_cls_scores, dim=1).reshape(-1, flatten_cls_scores[0].size(-1)) flatten_bbox_preds = [ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) for bbox_pred in bbox_preds ] flatten_bbox_preds = torch.cat( flatten_bbox_preds, dim=1).view(-1, flatten_bbox_preds[0].size(-1)) flatten_labels = torch.cat(labels_list, dim=1).reshape(-1) flatten_label_weights = torch.cat( label_weights_list, dim=1).reshape(-1) flatten_anchors = torch.cat(all_anchor_list, dim=1).reshape(-1, 4) flatten_bbox_targets = torch.cat( bbox_targets_list, dim=1).reshape(-1, 4) flatten_bbox_weights = torch.cat( bbox_weights_list, dim=1).reshape(-1, 4) # Apply ISR-P isr_cfg = self.train_cfg.get('isr', None) if isr_cfg is not None: all_targets = (flatten_labels, flatten_label_weights, flatten_bbox_targets, flatten_bbox_weights) with torch.no_grad(): all_targets = isr_p( flatten_cls_scores, flatten_bbox_preds, all_targets, flatten_anchors, sampling_results_list, bbox_coder=self.bbox_coder, loss_cls=self.loss_cls, num_class=self.num_classes, **self.train_cfg.isr) (flatten_labels, flatten_label_weights, flatten_bbox_targets, flatten_bbox_weights) = all_targets # For convenience we compute loss once instead separating by fpn level, # so that we don't need to separate the weights by level again. # The result should be the same losses_cls = self.loss_cls( flatten_cls_scores, flatten_labels, flatten_label_weights, avg_factor=num_total_samples) losses_bbox = self.loss_bbox( flatten_bbox_preds, flatten_bbox_targets, flatten_bbox_weights, avg_factor=num_total_samples) loss_dict = dict(loss_cls=losses_cls, loss_bbox=losses_bbox) # CARL Loss carl_cfg = self.train_cfg.get('carl', None) if carl_cfg is not None: loss_carl = carl_loss( flatten_cls_scores, flatten_labels, flatten_bbox_preds, flatten_bbox_targets, self.loss_bbox, **self.train_cfg.carl, avg_factor=num_total_pos, sigmoid=True, num_class=self.num_classes) loss_dict.update(loss_carl) return loss_dict