
Source code for mmdet.models.roi_heads.roi_extractors.single_level_roi_extractor

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import force_fp32

from mmdet.models.builder import ROI_EXTRACTORS
from .base_roi_extractor import BaseRoIExtractor


@ROI_EXTRACTORS.register_module()
class SingleRoIExtractor(BaseRoIExtractor):
    """Extract RoI features from a single level feature map.

    If there are multiple input feature levels, each RoI is mapped to a level
    according to its scale. The mapping rule is proposed in
    `FPN <https://arxiv.org/abs/1612.03144>`_.

    Args:
        roi_layer (dict): Specify RoI layer type and arguments.
        out_channels (int): Output channels of RoI layers.
        featmap_strides (List[int]): Strides of input feature maps.
        finest_scale (int): Scale threshold of mapping to level 0.
            Default: 56.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 roi_layer,
                 out_channels,
                 featmap_strides,
                 finest_scale=56,
                 init_cfg=None):
        super(SingleRoIExtractor, self).__init__(roi_layer, out_channels,
                                                 featmap_strides, init_cfg)
        self.finest_scale = finest_scale
    def map_roi_levels(self, rois, num_levels):
        """Map rois to corresponding feature levels by scales.

        - scale < finest_scale * 2: level 0
        - finest_scale * 2 <= scale < finest_scale * 4: level 1
        - finest_scale * 4 <= scale < finest_scale * 8: level 2
        - scale >= finest_scale * 8: level 3

        Args:
            rois (Tensor): Input RoIs, shape (k, 5).
            num_levels (int): Total level number.

        Returns:
            Tensor: Level index (0-based) of each RoI, shape (k, )
        """
        scale = torch.sqrt(
            (rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]))
        target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6))
        target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
        return target_lvls
    @force_fp32(apply_to=('feats', ), out_fp16=True)
    def forward(self, feats, rois, roi_scale_factor=None):
        """Forward function."""
        out_size = self.roi_layers[0].output_size
        num_levels = len(feats)
        expand_dims = (-1, self.out_channels * out_size[0] * out_size[1])
        if torch.onnx.is_in_onnx_export():
            # Work around to export mask-rcnn to onnx
            roi_feats = rois[:, :1].clone().detach()
            roi_feats = roi_feats.expand(*expand_dims)
            roi_feats = roi_feats.reshape(-1, self.out_channels, *out_size)
            roi_feats = roi_feats * 0
        else:
            roi_feats = feats[0].new_zeros(
                rois.size(0), self.out_channels, *out_size)
        # TODO: remove this when parrots supports
        if torch.__version__ == 'parrots':
            roi_feats.requires_grad = True

        if num_levels == 1:
            if len(rois) == 0:
                return roi_feats
            return self.roi_layers[0](feats[0], rois)

        target_lvls = self.map_roi_levels(rois, num_levels)

        if roi_scale_factor is not None:
            rois = self.roi_rescale(rois, roi_scale_factor)

        for i in range(num_levels):
            mask = target_lvls == i
            if torch.onnx.is_in_onnx_export():
                # To keep all roi_align nodes exported to onnx
                # and skip nonzero op
                mask = mask.float().unsqueeze(-1)
                # select target level rois and reset the rest rois to zero.
                rois_i = rois.clone().detach()
                rois_i *= mask
                mask_exp = mask.expand(*expand_dims).reshape(roi_feats.shape)
                roi_feats_t = self.roi_layers[i](feats[i], rois_i)
                roi_feats_t *= mask_exp
                roi_feats += roi_feats_t
                continue
            inds = mask.nonzero(as_tuple=False).squeeze(1)
            if inds.numel() > 0:
                rois_ = rois[inds]
                roi_feats_t = self.roi_layers[i](feats[i], rois_)
                roi_feats[inds] = roi_feats_t
            else:
                # Sometimes some pyramid levels will not be used for RoI
                # feature extraction and this will cause an incomplete
                # computation graph in one GPU, which is different from those
                # in other GPUs and will cause a hanging error.
                # Therefore, we add it to ensure each feature pyramid is
                # included in the computation graph to avoid runtime bugs.
                roi_feats += sum(
                    x.view(-1)[0]
                    for x in self.parameters()) * 0. + feats[i].sum() * 0.
        return roi_feats
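
The level-mapping rule documented in map_roi_levels can be checked with a few hand-picked boxes. The snippet below is a minimal sketch, not part of the mmdet source: the RoI coordinates, the 4-level pyramid and finest_scale=56 are illustrative assumptions, and it simply re-applies the same formula the method uses.

import torch

# Illustrative RoIs in (batch_idx, x1, y1, x2, y2) format; the box sizes give
# sqrt-area scales of roughly 56, 168 and 500 pixels.
rois = torch.tensor([
    [0., 0., 0., 56., 56.],    # scale ~ 56  -> below 2 * 56,  level 0
    [0., 0., 0., 168., 168.],  # scale ~ 168 -> in [112, 224), level 1
    [0., 0., 0., 500., 500.],  # scale ~ 500 -> at least 8 * 56, level 3
])
finest_scale = 56   # default threshold used by SingleRoIExtractor
num_levels = 4      # a standard 4-level FPN is assumed here

# Same computation as SingleRoIExtractor.map_roi_levels.
scale = torch.sqrt((rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]))
target_lvls = torch.floor(torch.log2(scale / finest_scale + 1e-6))
target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
print(target_lvls)  # tensor([0, 1, 3])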
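
End to end, the extractor is normally built from a config through the ROI_EXTRACTORS registry, but it can also be instantiated directly. The following is a usage sketch under stated assumptions, not a canonical recipe: it assumes mmcv-full is installed so the RoIAlign op is available, and the channel count, strides and tensor shapes are illustrative.

import torch
from mmdet.models.roi_heads.roi_extractors import SingleRoIExtractor

# An RoIAlign layer with a 7x7 output over a 4-level FPN
# (strides 4/8/16/32, 256 channels), mirroring common detector configs.
roi_extractor = SingleRoIExtractor(
    roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
    out_channels=256,
    featmap_strides=[4, 8, 16, 32])

# Dummy FPN features for a single 256x256 image.
feats = tuple(
    torch.rand(1, 256, 256 // s, 256 // s) for s in (4, 8, 16, 32))

# Two RoIs in (batch_idx, x1, y1, x2, y2) format; the 50x50 box maps to
# level 0 and the 200x200 box to level 1 under finest_scale=56.
rois = torch.tensor([
    [0., 10., 10., 60., 60.],
    [0., 20., 20., 220., 220.],
])

roi_feats = roi_extractor(feats, rois)
print(roi_feats.shape)  # torch.Size([2, 256, 7, 7])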