Shortcuts

Source code for mmdet.models.detectors.sparse_rcnn

# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import DETECTORS
from .two_stage import TwoStageDetector


@DETECTORS.register_module()
class SparseRCNN(TwoStageDetector):
    r"""Sparse R-CNN detector.

    Implementation of `Sparse R-CNN: End-to-End Object Detection with
    Learnable Proposals <https://arxiv.org/abs/2011.12450>`_.

    The detector requires an RPN head (``self.with_rpn`` must be truthy);
    externally supplied proposals are rejected.
    """

    def __init__(self, *args, **kwargs):
        """Build the detector and check that an RPN head is configured.

        All positional and keyword arguments are forwarded unchanged to the
        parent class constructor.
        """
        super().__init__(*args, **kwargs)
        assert self.with_rpn, (
            'Sparse R-CNN and QueryInst '
            'do not support external proposals')

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None,
                      **kwargs):
        """Forward function of SparseR-CNN and QueryInst in train stage.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (List[Tensor], optional): Segmentation masks for
                each box. This is required to train QueryInst.
            proposals (List[Tensor], optional): must be left as ``None``;
                any other value fails the assertion below because external
                proposals are not supported.
            **kwargs: accepted but not used by this method.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        assert proposals is None, (
            'Sparse R-CNN and QueryInst '
            'do not support external proposals')

        feats = self.extract_feat(img)

        # The RPN head supplies the initial boxes and features plus the
        # per-image ``imgs_whwh`` value that the RoI head takes as a keyword.
        rpn_outs = self.rpn_head.forward_train(feats, img_metas)
        init_boxes, init_features, whwh = rpn_outs

        return self.roi_head.forward_train(
            feats,
            init_boxes,
            init_features,
            img_metas,
            gt_bboxes,
            gt_labels,
            gt_bboxes_ignore=gt_bboxes_ignore,
            gt_masks=gt_masks,
            imgs_whwh=whwh)

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test time augmentation.

        Args:
            img (list[torch.Tensor]): List of multiple images; passed as-is
                to ``extract_feat``.
                NOTE(review): ``forward_train`` documents ``img`` as a single
                (N, C, H, W) Tensor - confirm which form callers pass here.
            img_metas (list[dict]): List of image information.
            rescale (bool): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        feats = self.extract_feat(img)

        rpn_outs = self.rpn_head.simple_test_rpn(feats, img_metas)
        init_boxes, init_features, whwh = rpn_outs

        return self.roi_head.simple_test(
            feats,
            init_boxes,
            init_features,
            img_metas,
            imgs_whwh=whwh,
            rescale=rescale)

    def forward_dummy(self, img):
        """Used for computing network flops.

        See `mmdetection/tools/analysis_tools/get_flops.py`

        Image metas are synthesised here: one dict per entry of ``img``, each
        holding only a fixed ``img_shape`` of ``(800, 1333, 3)``.
        """
        # backbone
        feats = self.extract_feat(img)

        # rpn
        dummy_img_metas = [{'img_shape': (800, 1333, 3)} for _ in range(len(img))]
        # The third RPN output is unpacked but not forwarded to the RoI head.
        init_boxes, init_features, _ = self.rpn_head.simple_test_rpn(
            feats, dummy_img_metas)

        # roi_head
        return self.roi_head.forward_dummy(feats, init_boxes, init_features,
                                           dummy_img_metas)
Read the Docs v: v2.19.0
Versions
latest
stable
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.