Source code for mmdet.models.detectors.rpn

import warnings

import mmcv
import torch
from mmcv.image import tensor2imgs

from mmdet.core import bbox_mapping
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import BaseDetector


@DETECTORS.register_module()
class RPN(BaseDetector):
    """Implementation of Region Proposal Network.

    Features are extracted by ``backbone`` (optionally refined by ``neck``)
    and passed to ``rpn_head``, whose outputs are the proposals / losses.
    """

    def __init__(self,
                 backbone,
                 neck,
                 rpn_head,
                 train_cfg,
                 test_cfg,
                 pretrained=None,
                 init_cfg=None):
        """Build the backbone, optional neck and RPN head from configs.

        Args:
            backbone (dict): Config of the backbone.
            neck (dict | None): Config of the neck; ``None`` means no neck.
            rpn_head (dict): Config of the RPN head. It is copied before the
                ``train_cfg`` / ``test_cfg`` entries are injected, so the
                caller's config object is left untouched.
            train_cfg (ConfigDict | None): Training config; its ``rpn``
                entry is forwarded to the head.
            test_cfg (ConfigDict | None): Testing config; its ``rpn`` entry
                is forwarded to the head.
            pretrained (str | None): Deprecated, use ``init_cfg`` instead.
            init_cfg (dict | None): Initialization config.
        """
        super(RPN, self).__init__(init_cfg)
        if pretrained:
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            backbone.pretrained = pretrained
        self.backbone = build_backbone(backbone)
        self.neck = build_neck(neck) if neck is not None else None
        # Both configs are optional; treat a missing one the same way.
        rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
        rpn_test_cfg = test_cfg.rpn if test_cfg is not None else None
        # Copy so that injecting train/test cfgs does not mutate the config
        # the caller passed in (it may be reused to build other models).
        rpn_head_ = rpn_head.copy()
        rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=rpn_test_cfg)
        self.rpn_head = build_head(rpn_head_)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, img):
        """Extract features.

        Args:
            img (torch.Tensor): Image tensor with shape (n, c, h ,w).

        Returns:
            list[torch.Tensor]: Multi-level features that may have
                different resolutions.
        """
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_dummy(self, img):
        """Dummy forward function: run backbone/neck and the RPN head.

        Returns:
            The raw outputs of ``self.rpn_head`` on the extracted features.
        """
        x = self.extract_feat(img)
        rpn_outs = self.rpn_head(x)
        return rpn_outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes=None,
                      gt_bboxes_ignore=None):
        """Compute the RPN training losses for a batch of images.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # ``train_cfg`` may legitimately be None (see ``__init__``).
        rpn_train_cfg = (
            self.train_cfg.rpn if self.train_cfg is not None else None)
        if (isinstance(rpn_train_cfg, dict)
                and rpn_train_cfg.get('debug', False)):
            self.rpn_head.debug_imgs = tensor2imgs(img)

        x = self.extract_feat(img)
        # The 4th positional argument is passed as None; presumably this is
        # gt_labels, which a class-agnostic RPN does not need (TODO confirm
        # against rpn_head.forward_train's signature).
        losses = self.rpn_head.forward_train(x, img_metas, gt_bboxes, None,
                                             gt_bboxes_ignore)
        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test time augmentation.

        Args:
            img (torch.Tensor): Batch of input images, shape (N, C, H, W).
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results by
                dividing the box coordinates by each image's
                ``scale_factor``. Defaults to False.

        Returns:
            list[np.ndarray]: proposals, one array per image. During ONNX
                export the per-image tensors are returned instead.
        """
        x = self.extract_feat(img)
        # get origin input shape to onnx dynamic input shape
        if torch.onnx.is_in_onnx_export():
            img_shape = torch._shape_as_tensor(img)[2:]
            img_metas[0]['img_shape_for_onnx'] = img_shape
        proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        if rescale:
            for proposals, meta in zip(proposal_list, img_metas):
                # Only the first 4 columns are box coordinates.
                proposals[:, :4] /= proposals.new_tensor(meta['scale_factor'])
        if torch.onnx.is_in_onnx_export():
            return proposal_list

        return [proposal.cpu().numpy() for proposal in proposal_list]

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test function with test time augmentation.

        Args:
            imgs (list[torch.Tensor]): List of multiple images, one batch
                per augmentation.
            img_metas (list[list[dict]]): Image information, indexed first
                by augmentation then by image.
            rescale (bool, optional): Whether to rescale the results.
                When False, the proposals are mapped with ``bbox_mapping``
                using each image's first-augmentation meta.
                Defaults to False.

        Returns:
            list[np.ndarray]: proposals
        """
        proposal_list = self.rpn_head.aug_test_rpn(
            self.extract_feats(imgs), img_metas)
        if not rescale:
            # img_metas[0] holds the metas of the first augmentation, one per
            # image; presumably bbox_mapping maps boxes from the original
            # image frame into that augmented frame (TODO confirm).
            for proposals, img_meta in zip(proposal_list, img_metas[0]):
                img_shape = img_meta['img_shape']
                scale_factor = img_meta['scale_factor']
                flip = img_meta['flip']
                flip_direction = img_meta['flip_direction']
                proposals[:, :4] = bbox_mapping(proposals[:, :4], img_shape,
                                                scale_factor, flip,
                                                flip_direction)
        return [proposal.cpu().numpy() for proposal in proposal_list]

    def show_result(self, data, result, top_k=20, **kwargs):
        """Show RPN proposals on the image.

        Args:
            data (str or np.ndarray): Image filename or loaded image.
            result: The proposals to draw over ``data``; passed to
                ``mmcv.imshow_bboxes`` as its ``bboxes`` argument.
            top_k (int): Plot the first k bboxes only
               if set positive. Default: 20
            **kwargs: Only ``show``, ``win_name``, ``wait_time`` and
                ``out_file`` are forwarded to ``mmcv.imshow_bboxes``; other
                keys that callers pass to detectors' ``show_result`` (e.g.
                ``score_thr``, ``bbox_color``) are ignored.

        Returns:
            np.ndarray: The image with bboxes drawn on it.
        """
        # Whitelist the display options understood by mmcv.imshow_bboxes so
        # that e.g. ``show=False`` / ``out_file`` are honored, while
        # detector-specific options that it does not accept are dropped.
        show_kwargs = {
            key: kwargs[key]
            for key in ('show', 'win_name', 'wait_time', 'out_file')
            if key in kwargs
        }
        return mmcv.imshow_bboxes(data, result, top_k=top_k, **show_kwargs)