Shortcuts

Source code for mmdet.models.dense_heads.pisa_retinanet_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import force_fp32

from mmdet.core import images_to_levels
from ..builder import HEADS
from ..losses import carl_loss, isr_p
from .retina_head import RetinaHead


[docs]@HEADS.register_module() class PISARetinaHead(RetinaHead): """PISA Retinanet Head. The head owns the same structure with Retinanet Head, but differs in two aspects: 1. Importance-based Sample Reweighting Positive (ISR-P) is applied to change the positive loss weights. 2. Classification-aware regression loss is adopted as a third loss. """
[docs] @force_fp32(apply_to=('cls_scores', 'bbox_preds')) def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas, gt_bboxes_ignore=None): """Compute losses of the head. Args: cls_scores (list[Tensor]): Box scores for each scale level Has shape (N, num_anchors * num_classes, H, W) bbox_preds (list[Tensor]): Box energies / deltas for each scale level with shape (N, num_anchors * 4, H, W) gt_bboxes (list[Tensor]): Ground truth bboxes of each image with shape (num_obj, 4). gt_labels (list[Tensor]): Ground truth labels of each image with shape (num_obj, 4). img_metas (list[dict]): Meta information of each image, e.g., image size, scaling factor, etc. gt_bboxes_ignore (list[Tensor]): Ignored gt bboxes of each image. Default: None. Returns: dict: Loss dict, comprise classification loss, regression loss and carl loss. """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] assert len(featmap_sizes) == self.prior_generator.num_levels device = cls_scores[0].device anchor_list, valid_flag_list = self.get_anchors( featmap_sizes, img_metas, device=device) label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 cls_reg_targets = self.get_targets( anchor_list, valid_flag_list, gt_bboxes, img_metas, gt_bboxes_ignore_list=gt_bboxes_ignore, gt_labels_list=gt_labels, label_channels=label_channels, return_sampling_results=True) if cls_reg_targets is None: return None (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, num_total_pos, num_total_neg, sampling_results_list) = cls_reg_targets num_total_samples = ( num_total_pos + num_total_neg if self.sampling else num_total_pos) # anchor number of multi levels num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] # concat all level anchors and flags to a single tensor concat_anchor_list = [] for i in range(len(anchor_list)): concat_anchor_list.append(torch.cat(anchor_list[i])) all_anchor_list = images_to_levels(concat_anchor_list, num_level_anchors) num_imgs = len(img_metas) flatten_cls_scores = [ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, label_channels) for cls_score in cls_scores ] flatten_cls_scores = torch.cat( flatten_cls_scores, dim=1).reshape(-1, flatten_cls_scores[0].size(-1)) flatten_bbox_preds = [ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) for bbox_pred in bbox_preds ] flatten_bbox_preds = torch.cat( flatten_bbox_preds, dim=1).view(-1, flatten_bbox_preds[0].size(-1)) flatten_labels = torch.cat(labels_list, dim=1).reshape(-1) flatten_label_weights = torch.cat( label_weights_list, dim=1).reshape(-1) flatten_anchors = torch.cat(all_anchor_list, dim=1).reshape(-1, 4) flatten_bbox_targets = torch.cat( bbox_targets_list, dim=1).reshape(-1, 4) flatten_bbox_weights = torch.cat( bbox_weights_list, dim=1).reshape(-1, 4) # Apply ISR-P isr_cfg = self.train_cfg.get('isr', None) if isr_cfg is not None: all_targets = (flatten_labels, flatten_label_weights, flatten_bbox_targets, flatten_bbox_weights) with torch.no_grad(): all_targets = isr_p( flatten_cls_scores, flatten_bbox_preds, all_targets, flatten_anchors, sampling_results_list, bbox_coder=self.bbox_coder, loss_cls=self.loss_cls, num_class=self.num_classes, **self.train_cfg.isr) (flatten_labels, flatten_label_weights, flatten_bbox_targets, flatten_bbox_weights) = all_targets # For convenience we compute loss once instead separating by fpn level, # so that we don't need to separate the weights by level again. # The result should be the same losses_cls = self.loss_cls( flatten_cls_scores, flatten_labels, flatten_label_weights, avg_factor=num_total_samples) losses_bbox = self.loss_bbox( flatten_bbox_preds, flatten_bbox_targets, flatten_bbox_weights, avg_factor=num_total_samples) loss_dict = dict(loss_cls=losses_cls, loss_bbox=losses_bbox) # CARL Loss carl_cfg = self.train_cfg.get('carl', None) if carl_cfg is not None: loss_carl = carl_loss( flatten_cls_scores, flatten_labels, flatten_bbox_preds, flatten_bbox_targets, self.loss_bbox, **self.train_cfg.carl, avg_factor=num_total_pos, sigmoid=True, num_class=self.num_classes) loss_dict.update(loss_carl) return loss_dict
Read the Docs v: v2.18.1
Versions
latest
stable
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.