Shortcuts

Source code for mmdet.models.detectors.rpn

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import torch
from mmcv.image import tensor2imgs

from mmdet.core import bbox_mapping
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import BaseDetector


@DETECTORS.register_module()
class RPN(BaseDetector):
    """Implementation of Region Proposal Network.

    The detector is composed of a backbone, an optional neck and an RPN head.
    Its test-time outputs are region proposals rather than class-wise
    detections.
    """

    def __init__(self,
                 backbone,
                 neck,
                 rpn_head,
                 train_cfg,
                 test_cfg,
                 pretrained=None,
                 init_cfg=None):
        """Build the backbone, the optional neck and the RPN head.

        Args:
            backbone: Config passed to ``build_backbone``.
            neck: Config passed to ``build_neck``; ``None`` disables the neck.
            rpn_head: Config passed to ``build_head``. It is updated in place
                with the ``train_cfg`` / ``test_cfg`` entries for the RPN.
            train_cfg: Training config exposing an ``rpn`` attribute, or
                ``None``.
            test_cfg: Testing config exposing an ``rpn`` attribute, or
                ``None``.
            pretrained: Deprecated. When truthy it is stored on
                ``backbone.pretrained`` and a warning is emitted.
            init_cfg: Forwarded to the base class constructor.
        """
        super(RPN, self).__init__(init_cfg)
        if pretrained:
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            backbone.pretrained = pretrained
        self.backbone = build_backbone(backbone)
        self.neck = build_neck(neck) if neck is not None else None
        rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
        # Mirror the train_cfg handling so that a missing test_cfg does not
        # raise AttributeError while the model is being constructed.
        rpn_test_cfg = test_cfg.rpn if test_cfg is not None else None
        rpn_head.update(train_cfg=rpn_train_cfg)
        rpn_head.update(test_cfg=rpn_test_cfg)
        self.rpn_head = build_head(rpn_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, img):
        """Extract features.

        Args:
            img (torch.Tensor): Image tensor with shape (n, c, h ,w).

        Returns:
            list[torch.Tensor]: Multi-level features that may have
                different resolutions.
        """
        x = self.backbone(img)
        # NOTE(review): `with_neck` is not defined in this class; presumably
        # it is provided by BaseDetector - confirm.
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_dummy(self, img):
        """Dummy forward function.

        Runs feature extraction followed by the RPN head and returns the raw
        head outputs.
        """
        x = self.extract_feat(img)
        rpn_outs = self.rpn_head(x)
        return rpn_outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes=None,
                      gt_bboxes_ignore=None):
        """Compute the RPN losses for a batch of images.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # When the RPN training config enables 'debug', hand the input images
        # (converted by tensor2imgs) to the head.
        if (isinstance(self.train_cfg.rpn, dict)
                and self.train_cfg.rpn.get('debug', False)):
            self.rpn_head.debug_imgs = tensor2imgs(img)

        x = self.extract_feat(img)
        # The fourth positional argument is passed as None.
        # NOTE(review): presumably the gt_labels slot, unused for RPN -
        # confirm against the head's forward_train signature.
        losses = self.rpn_head.forward_train(x, img_metas, gt_bboxes, None,
                                             gt_bboxes_ignore)
        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test time augmentation.

        Args:
            img (torch.Tensor): Batched image tensor; its dimensions from
                index 2 onwards are recorded as the image shape during ONNX
                export.
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[np.ndarray]: proposals. During ONNX export the proposals
                are returned as produced by the head, without conversion
                to numpy.
        """
        x = self.extract_feat(img)
        # get origin input shape to onnx dynamic input shape
        if torch.onnx.is_in_onnx_export():
            img_shape = torch._shape_as_tensor(img)[2:]
            img_metas[0]['img_shape_for_onnx'] = img_shape
        proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        if rescale:
            for proposals, meta in zip(proposal_list, img_metas):
                # Divide the box coordinates in place by the scale factor.
                proposals[:, :4] /= proposals.new_tensor(meta['scale_factor'])
        if torch.onnx.is_in_onnx_export():
            return proposal_list

        return [proposal.cpu().numpy() for proposal in proposal_list]

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test function with test time augmentation.

        Args:
            imgs (list[torch.Tensor]): List of multiple images
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[np.ndarray]: proposals
        """
        # NOTE(review): `extract_feats` is not defined in this class;
        # presumably inherited from BaseDetector - confirm.
        proposal_list = self.rpn_head.aug_test_rpn(
            self.extract_feats(imgs), img_metas)
        if not rescale:
            # NOTE(review): presumably aug_test_rpn returns proposals in
            # original-image coordinates, so without rescaling they are
            # mapped onto the first augmented view (img_metas[0]) - confirm.
            for proposals, img_meta in zip(proposal_list, img_metas[0]):
                img_shape = img_meta['img_shape']
                scale_factor = img_meta['scale_factor']
                flip = img_meta['flip']
                flip_direction = img_meta['flip_direction']
                proposals[:, :4] = bbox_mapping(proposals[:, :4], img_shape,
                                                scale_factor, flip,
                                                flip_direction)
        return [proposal.cpu().numpy() for proposal in proposal_list]

    def show_result(self, data, result, top_k=20, **kwargs):
        """Show RPN proposals on the image.

        Args:
            data (str or np.ndarray): Image filename or loaded image.
            result (Tensor or tuple): The results to draw over `img`
                bbox_result or (bbox_result, segm_result).
            top_k (int): Plot the first k bboxes only if set positive.
                Default: 20
            **kwargs: Forwarded to ``mmcv.imshow_bboxes``. 'score_thr' and
                'text_color' are dropped, and 'bbox_color' (default
                'green') is passed on as 'colors'.

        Returns:
            np.ndarray: The image with bboxes drawn on it, as returned by
                ``mmcv.imshow_bboxes``.
        """
        # `kwargs` is always a dict, so no None check is required.
        kwargs.pop('score_thr', None)
        kwargs.pop('text_color', None)
        kwargs['colors'] = kwargs.pop('bbox_color', 'green')
        # Return the drawn image so the documented return value is honoured.
        return mmcv.imshow_bboxes(data, result, top_k=top_k, **kwargs)
Read the Docs v: v2.21.0
Versions
latest
stable
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.