Shortcuts

mmdet.models.necks.bfp 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import NonLocal2d
from mmcv.runner import BaseModule

from ..builder import NECKS


@NECKS.register_module()
class BFP(BaseModule):
    """Balanced Feature Pyramids (BFP) neck.

    The module works in three stages: every input level is resized to the
    resolution of one chosen level and the results are averaged into a single
    map; that map is optionally refined; finally the map is resized back to
    each level's resolution and added to the corresponding input.

    BFP is used in Libra R-CNN (CVPR 2019); see `Libra R-CNN: Towards
    Balanced Learning for Object Detection
    <https://arxiv.org/abs/1904.02701>`_ for details.

    Args:
        in_channels (int): Number of input channels (feature maps of all
            levels should have the same number of channels).
        num_levels (int): Number of input feature levels.
        refine_level (int): Index, counted from bottom to top, of the level
            whose resolution is used for gathering and refining. Must
            satisfy ``0 <= refine_level < num_levels``. Default: 2.
        refine_type (str, optional): Type of the refine op; one of
            ``None``, ``'conv'`` or ``'non_local'``. With ``None`` no refine
            op is built and the averaged map is used as is. Default: None.
        conv_cfg (dict, optional): Config dict for convolution layers,
            forwarded to the refine op. Default: None.
        norm_cfg (dict, optional): Config dict for normalization layers,
            forwarded to the refine op. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config
            dict, passed to the base class constructor.
    """

    def __init__(self,
                 in_channels,
                 num_levels,
                 refine_level=2,
                 refine_type=None,
                 conv_cfg=None,
                 norm_cfg=None,
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super().__init__(init_cfg)
        assert refine_type in [None, 'conv', 'non_local']

        self.in_channels = in_channels
        self.num_levels = num_levels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.refine_level = refine_level
        self.refine_type = refine_type
        assert 0 <= refine_level < num_levels

        # ``self.refine`` only exists when a refine op was requested;
        # ``forward`` checks ``refine_type`` before using it.
        if refine_type == 'conv':
            # Convolution with kernel size 3 and padding 1.
            self.refine = ConvModule(
                in_channels,
                in_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg)
        elif refine_type == 'non_local':
            self.refine = NonLocal2d(
                in_channels,
                reduction=1,
                use_scale=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg)

    @staticmethod
    def _resize(feat, size, use_pool):
        """Resize ``feat`` to ``size`` by max pooling or nearest sampling.

        Adaptive max pooling is used when ``use_pool`` is true, otherwise
        nearest-neighbour interpolation.
        """
        if use_pool:
            return F.adaptive_max_pool2d(feat, output_size=size)
        return F.interpolate(feat, size=size, mode='nearest')

    def forward(self, inputs):
        """Gather, refine and scatter the multi-level features.

        Args:
            inputs (Sequence[Tensor]): One feature map per level; its length
                must equal ``num_levels``. Dimensions from index 2 onwards
                are treated as the spatial size (presumably the maps are
                (N, C, H, W) -- confirm against the caller).

        Returns:
            tuple[Tensor]: One tensor per level, each the sum of the input
            of that level and the (refined) balanced feature resized to
            that input's spatial size.
        """
        assert len(inputs) == self.num_levels
        pivot = self.refine_level
        levels = range(self.num_levels)

        # Step 1: bring every level to the pivot resolution and average.
        # Levels below the pivot index are max-pooled, the others are
        # interpolated.
        target_size = inputs[pivot].size()[2:]
        gathered = [
            self._resize(inputs[lvl], target_size, use_pool=lvl < pivot)
            for lvl in levels
        ]
        bsf = sum(gathered) / len(gathered)

        # Step 2: optional refinement of the balanced feature.
        if self.refine_type is not None:
            bsf = self.refine(bsf)

        # Step 3: scatter the balanced feature back through a residual path,
        # using the opposite resize op to the one used when gathering.
        return tuple(
            self._resize(
                bsf, inputs[lvl].size()[2:], use_pool=lvl >= pivot) +
            inputs[lvl] for lvl in levels)
Read the Docs v: latest
Versions
latest
stable
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.