Shortcuts

Source code for mmdet.models.detectors.two_stage

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch

from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import BaseDetector


[docs]@DETECTORS.register_module() class TwoStageDetector(BaseDetector): """Base class for two-stage detectors. Two-stage detectors typically consisting of a region proposal network and a task-specific regression head. """ def __init__(self, backbone, neck=None, rpn_head=None, roi_head=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None): super(TwoStageDetector, self).__init__(init_cfg) if pretrained: warnings.warn('DeprecationWarning: pretrained is deprecated, ' 'please use "init_cfg" instead') backbone.pretrained = pretrained self.backbone = build_backbone(backbone) if neck is not None: self.neck = build_neck(neck) if rpn_head is not None: rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None rpn_head_ = rpn_head.copy() rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn) self.rpn_head = build_head(rpn_head_) if roi_head is not None: # update train and test cfg here for now # TODO: refactor assigner & sampler rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None roi_head.update(train_cfg=rcnn_train_cfg) roi_head.update(test_cfg=test_cfg.rcnn) roi_head.pretrained = pretrained self.roi_head = build_head(roi_head) self.train_cfg = train_cfg self.test_cfg = test_cfg @property def with_rpn(self): """bool: whether the detector has RPN""" return hasattr(self, 'rpn_head') and self.rpn_head is not None @property def with_roi_head(self): """bool: whether the detector has a RoI head""" return hasattr(self, 'roi_head') and self.roi_head is not None
[docs] def extract_feat(self, img): """Directly extract features from the backbone+neck.""" x = self.backbone(img) if self.with_neck: x = self.neck(x) return x
[docs] def forward_dummy(self, img): """Used for computing network flops. See `mmdetection/tools/analysis_tools/get_flops.py` """ outs = () # backbone x = self.extract_feat(img) # rpn if self.with_rpn: rpn_outs = self.rpn_head(x) outs = outs + (rpn_outs, ) proposals = torch.randn(1000, 4).to(img.device) # roi_head roi_outs = self.roi_head.forward_dummy(x, proposals) outs = outs + (roi_outs, ) return outs
[docs] def forward_train(self, img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore=None, gt_masks=None, proposals=None, **kwargs): """ Args: img (Tensor): of shape (N, C, H, W) encoding input images. Typically these should be mean centered and std scaled. img_metas (list[dict]): list of image info dict where each dict has: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details on the values of these keys see `mmdet/datasets/pipelines/formatting.py:Collect`. gt_bboxes (list[Tensor]): Ground truth bboxes for each image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. gt_labels (list[Tensor]): class indices corresponding to each box gt_bboxes_ignore (None | list[Tensor]): specify which bounding boxes can be ignored when computing the loss. gt_masks (None | Tensor) : true segmentation masks for each box used if the architecture supports a segmentation task. proposals : override rpn proposals with custom proposals. Use when `with_rpn` is False. Returns: dict[str, Tensor]: a dictionary of loss components """ x = self.extract_feat(img) losses = dict() # RPN forward and loss if self.with_rpn: proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn) rpn_losses, proposal_list = self.rpn_head.forward_train( x, img_metas, gt_bboxes, gt_labels=None, gt_bboxes_ignore=gt_bboxes_ignore, proposal_cfg=proposal_cfg, **kwargs) losses.update(rpn_losses) else: proposal_list = proposals roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list, gt_bboxes, gt_labels, gt_bboxes_ignore, gt_masks, **kwargs) losses.update(roi_losses) return losses
[docs] async def async_simple_test(self, img, img_meta, proposals=None, rescale=False): """Async test without augmentation.""" assert self.with_bbox, 'Bbox head must be implemented.' x = self.extract_feat(img) if proposals is None: proposal_list = await self.rpn_head.async_simple_test_rpn( x, img_meta) else: proposal_list = proposals return await self.roi_head.async_simple_test( x, proposal_list, img_meta, rescale=rescale)
[docs] def simple_test(self, img, img_metas, proposals=None, rescale=False): """Test without augmentation.""" assert self.with_bbox, 'Bbox head must be implemented.' x = self.extract_feat(img) if proposals is None: proposal_list = self.rpn_head.simple_test_rpn(x, img_metas) else: proposal_list = proposals return self.roi_head.simple_test( x, proposal_list, img_metas, rescale=rescale)
[docs] def aug_test(self, imgs, img_metas, rescale=False): """Test with augmentations. If rescale is False, then returned bboxes and masks will fit the scale of imgs[0]. """ x = self.extract_feats(imgs) proposal_list = self.rpn_head.aug_test_rpn(x, img_metas) return self.roi_head.aug_test( x, proposal_list, img_metas, rescale=rescale)
def onnx_export(self, img, img_metas): img_shape = torch._shape_as_tensor(img)[2:] img_metas[0]['img_shape_for_onnx'] = img_shape x = self.extract_feat(img) proposals = self.rpn_head.onnx_export(x, img_metas) if hasattr(self.roi_head, 'onnx_export'): return self.roi_head.onnx_export(x, proposals, img_metas) else: raise NotImplementedError( f'{self.__class__.__name__} can not ' f'be exported to ONNX. Please refer to the ' f'list of supported models,' f'https://mmdetection.readthedocs.io/en/latest/tutorials/pytorch2onnx.html#list-of-supported-models-exportable-to-onnx' # noqa E501 )
Read the Docs v: v2.19.1
Versions
latest
stable
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.