Source code for mmdet.models.roi_heads.roi_extractors.single_level_roi_extractor

import torch

from mmdet.core import force_fp32
from mmdet.models.builder import ROI_EXTRACTORS
from .base_roi_extractor import BaseRoIExtractor


@ROI_EXTRACTORS.register_module()
class SingleRoIExtractor(BaseRoIExtractor):
    """Extract RoI features from a single level feature map.

    If there are multiple input feature levels, each RoI is mapped to a level
    according to its scale. The mapping rule is proposed in
    `FPN <https://arxiv.org/abs/1612.03144>`_.

    Args:
        roi_layer (dict): Specify RoI layer type and arguments.
        out_channels (int): Output channels of RoI layers.
        featmap_strides (List[int]): Strides of input feature maps.
        finest_scale (int): Scale threshold of mapping to level 0.
            Default: 56.
    """

    def __init__(self,
                 roi_layer,
                 out_channels,
                 featmap_strides,
                 finest_scale=56):
        super(SingleRoIExtractor, self).__init__(roi_layer, out_channels,
                                                 featmap_strides)
        self.finest_scale = finest_scale

    def map_roi_levels(self, rois, num_levels):
        """Map rois to corresponding feature levels by scales.

        The scale of a RoI is ``sqrt(w * h)``, where ``w = rois[:, 3] -
        rois[:, 1]`` and ``h = rois[:, 4] - rois[:, 2]``. For
        ``num_levels == 4`` the mapping is:

        - scale < finest_scale * 2: level 0
        - finest_scale * 2 <= scale < finest_scale * 4: level 1
        - finest_scale * 4 <= scale < finest_scale * 8: level 2
        - scale >= finest_scale * 8: level 3

        In general the level index is clamped to ``[0, num_levels - 1]``.

        Args:
            rois (Tensor): Input RoIs, shape (k, 5). Columns 1-4 are used as
                box corners; column 0 is presumably the batch index.
            num_levels (int): Total level number.

        Returns:
            Tensor: Level index (0-based) of each RoI, shape (k, ), dtype long.
        """
        scale = torch.sqrt(
            (rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]))
        # The small epsilon keeps exact powers of two (e.g. scale ==
        # 2 * finest_scale) from flooring one level too low due to float
        # rounding in log2.
        target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6))
        target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
        return target_lvls

    @force_fp32(apply_to=('feats', ), out_fp16=True)
    def forward(self, feats, rois, roi_scale_factor=None):
        """Pool a fixed-size feature for every RoI.

        Args:
            feats (Sequence[Tensor]): Input feature maps; ``feats[i]`` is
                pooled with ``self.roi_layers[i]``.
            rois (Tensor): Input RoIs, shape (k, 5).
            roi_scale_factor (optional): When not None, RoIs are passed
                through ``self.roi_rescale(rois, roi_scale_factor)`` before
                pooling. This applies to both the single- and multi-level
                paths. Level assignment always uses the original RoIs.

        Returns:
            Tensor: RoI features. In the multi-level (and empty single-level)
            case the shape is (k, out_channels, *out_size). With a single
            non-empty level, the output of ``self.roi_layers[0]`` is returned
            directly.
        """
        out_size = self.roi_layers[0].out_size
        num_levels = len(feats)
        roi_feats = feats[0].new_zeros(
            rois.size(0), self.out_channels, *out_size)

        if num_levels == 1:
            if len(rois) == 0:
                return roi_feats
            # Honour roi_scale_factor here as well; it used to be silently
            # ignored for single-level inputs, unlike the multi-level path.
            if roi_scale_factor is not None:
                rois = self.roi_rescale(rois, roi_scale_factor)
            return self.roi_layers[0](feats[0], rois)

        # Levels are assigned from the un-rescaled RoIs on purpose, so that
        # enlarging RoIs does not move them to a coarser level.
        target_lvls = self.map_roi_levels(rois, num_levels)
        if roi_scale_factor is not None:
            rois = self.roi_rescale(rois, roi_scale_factor)

        for i, feat in enumerate(feats):
            inds = target_lvls == i
            if inds.any():
                roi_feats[inds] = self.roi_layers[i](feat, rois[inds, :])
            else:
                # No RoI falls on this level. Add zero-valued terms that
                # depend on this module's parameters and on feats[i], so that
                # everything stays in the autograd graph. Otherwise some
                # parameters may get no gradient on this replica, which
                # distributed training can presumably choke on -- confirm.
                roi_feats += sum(
                    x.view(-1)[0]
                    for x in self.parameters()) * 0. + feat.sum() * 0.
        return roi_feats