Shortcuts

Source code for mmdet.models.roi_heads.pisa_roi_head

# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.core import bbox2roi
from ..builder import HEADS
from ..losses.pisa_loss import carl_loss, isr_p
from .standard_roi_head import StandardRoIHead


@HEADS.register_module()
class PISARoIHead(StandardRoIHead):
    r"""The RoI head for `Prime Sample Attention in Object Detection
    <https://arxiv.org/abs/1904.04821>`_."""

    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        """Forward function for training.

        Args:
            x (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            proposal_list (list[Tensors]): List of region proposals.
            gt_bboxes (list[Tensor]): Each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box
            gt_bboxes_ignore (list[Tensor], optional): Specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (None | Tensor) : True segmentation masks for each box
                used if the architecture supports a segmentation task.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # assign gts and sample proposals
        if self.with_bbox or self.with_mask:
            num_imgs = len(img_metas)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []
            neg_label_weights = []
            for i in range(num_imgs):
                assign_result = self.bbox_assigner.assign(
                    proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
                    gt_labels[i])
                sampling_result = self.bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_bboxes[i],
                    gt_labels[i],
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                # neg label weight is obtained by sampling when using ISR-N;
                # in that case the sampler returns a
                # (sampling_result, neg_label_weight) tuple
                neg_label_weight = None
                if isinstance(sampling_result, tuple):
                    sampling_result, neg_label_weight = sampling_result
                sampling_results.append(sampling_result)
                neg_label_weights.append(neg_label_weight)

        losses = dict()
        # Stays None when there is no bbox head, so that the mask branch
        # below does not hit an unbound local.
        bbox_results = None
        # bbox head forward and loss
        if self.with_bbox:
            bbox_results = self._bbox_forward_train(
                x,
                sampling_results,
                gt_bboxes,
                gt_labels,
                img_metas,
                neg_label_weights=neg_label_weights)
            losses.update(bbox_results['loss_bbox'])

        # mask head forward and loss
        if self.with_mask:
            # NOTE(review): without a bbox head there are no bbox features
            # to hand over, so None is passed; the inherited
            # ``_mask_forward_train`` is not visible here -- confirm it
            # accepts None in that configuration.
            bbox_feats = None
            if bbox_results is not None:
                bbox_feats = bbox_results['bbox_feats']
            mask_results = self._mask_forward_train(x, sampling_results,
                                                    bbox_feats, gt_masks,
                                                    img_metas)
            losses.update(mask_results['loss_mask'])

        return losses

    def _bbox_forward(self, x, rois):
        """Box forward function used in both training and testing.

        Args:
            x (list[Tensor]): Multi-level img features; only the first
                ``self.bbox_roi_extractor.num_inputs`` levels are used.
            rois (Tensor): RoIs passed to the RoI extractor.

        Returns:
            dict: ``cls_score`` and ``bbox_pred`` produced by the bbox head
            and the ``bbox_feats`` they were computed from.
        """
        # TODO: a more flexible way to decide which feature maps to use
        bbox_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        if self.with_shared_head:
            bbox_feats = self.shared_head(bbox_feats)
        cls_score, bbox_pred = self.bbox_head(bbox_feats)

        bbox_results = dict(
            cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
        return bbox_results

    def _bbox_forward_train(self,
                            x,
                            sampling_results,
                            gt_bboxes,
                            gt_labels,
                            img_metas,
                            neg_label_weights=None):
        """Run forward function and calculate loss for box head in training.

        Args:
            x (list[Tensor]): List of multi-level img features.
            sampling_results (list): Per-image sampling results; each item
                must expose ``bboxes``, ``pos_inds`` and ``neg_inds``.
            gt_bboxes (list[Tensor]): Ground-truth boxes, forwarded to
                ``self.bbox_head.get_targets``.
            gt_labels (list[Tensor]): Ground-truth labels, forwarded to
                ``self.bbox_head.get_targets``.
            img_metas (list[dict]): Image meta info (unused in this method).
            neg_label_weights (list | None): Per-image label weights for the
                negative samples. They are written into the label weights
                only when the list is given and its first entry is not
                None. Defaults to None, which leaves the label weights
                untouched.

        Returns:
            dict: The output of :meth:`_bbox_forward` extended with a
            ``loss_bbox`` entry holding the bbox head losses (plus the CARL
            loss when ``carl`` is configured in ``self.train_cfg``).
        """
        rois = bbox2roi([res.bboxes for res in sampling_results])

        bbox_results = self._bbox_forward(x, rois)

        bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
                                                  gt_labels, self.train_cfg)

        # neg_label_weights obtained by sampler is image-wise, mapping back to
        # the corresponding location in label weights.
        # The explicit None check makes the documented default usable
        # instead of failing on ``None[0]``.
        if neg_label_weights is not None and neg_label_weights[0] is not None:
            # Updated in place, so bbox_targets sees the new weights.
            label_weights = bbox_targets[1]
            cur_num_rois = 0
            for i, sampling_result in enumerate(sampling_results):
                num_pos = sampling_result.pos_inds.size(0)
                num_neg = sampling_result.neg_inds.size(0)
                # The slice targets the num_neg entries that follow the
                # num_pos positive entries of image i.
                neg_start = cur_num_rois + num_pos
                label_weights[neg_start:neg_start +
                              num_neg] = neg_label_weights[i]
                cur_num_rois = neg_start + num_neg

        cls_score = bbox_results['cls_score']
        bbox_pred = bbox_results['bbox_pred']

        # Apply ISR-P
        isr_cfg = self.train_cfg.get('isr', None)
        if isr_cfg is not None:
            bbox_targets = isr_p(
                cls_score,
                bbox_pred,
                bbox_targets,
                rois,
                sampling_results,
                self.bbox_head.loss_cls,
                self.bbox_head.bbox_coder,
                **isr_cfg,
                num_class=self.bbox_head.num_classes)
        loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, rois,
                                        *bbox_targets)

        # Add CARL Loss
        carl_cfg = self.train_cfg.get('carl', None)
        if carl_cfg is not None:
            loss_carl = carl_loss(
                cls_score,
                bbox_targets[0],
                bbox_pred,
                bbox_targets[2],
                self.bbox_head.loss_bbox,
                **carl_cfg,
                num_class=self.bbox_head.num_classes)
            loss_bbox.update(loss_carl)

        bbox_results.update(loss_bbox=loss_bbox)
        return bbox_results
Read the Docs v: v2.19.1
Versions
latest
stable
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.