Shortcuts

mmdet.models.necks.channel_mapper 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from ..builder import NECKS


@NECKS.register_module()
class ChannelMapper(BaseModule):
    r"""Neck that maps every backbone feature level to a common channel count.

    One ``ConvModule`` is built per input level to project that level from
    its own channel count to ``out_channels``. When more outputs than inputs
    are requested, additional stride-2 convolutions are appended to produce
    the extra levels.

    Args:
        in_channels (List[int]): Number of input channels per scale. Must be
            a ``list`` (checked with an assertion).
        out_channels (int): Number of output channels (used at each scale).
        kernel_size (int, optional): Kernel size of the per-level mapping
            convolution; the padding is ``(kernel_size - 1) // 2``.
            Default: 3.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU').
        num_outs (int, optional): Number of output feature maps. When it is
            larger than ``len(in_channels)``, extra 3x3 stride-2 convs are
            created for the additional levels. ``None`` means
            ``len(in_channels)``. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: dict(type='Xavier', layer='Conv2d',
            distribution='uniform').

    Example:
        >>> import torch
        >>> in_channels = [2, 3, 5, 7]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = ChannelMapper(in_channels, 11, 3).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 11, 340, 340])
        outputs[1].shape = torch.Size([1, 11, 170, 170])
        outputs[2].shape = torch.Size([1, 11, 84, 84])
        outputs[3].shape = torch.Size([1, 11, 43, 43])
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 num_outs=None,
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super(ChannelMapper, self).__init__(init_cfg)
        assert isinstance(in_channels, list)

        num_ins = len(in_channels)
        if num_outs is None:
            num_outs = num_ins

        # Layer configs shared by every ConvModule created below.
        layer_cfgs = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # Stays None unless more outputs than inputs are requested.
        self.extra_convs = None

        # One mapping conv per input level.
        self.convs = nn.ModuleList([
            ConvModule(
                level_channels,
                out_channels,
                kernel_size,
                padding=(kernel_size - 1) // 2,
                **layer_cfgs) for level_channels in in_channels
        ])

        if num_outs > num_ins:
            self.extra_convs = nn.ModuleList()
            for level in range(num_ins, num_outs):
                # The first extra conv is fed the last raw input (see
                # ``forward``), so it takes that input's channel count; every
                # later one is fed a previous output.
                if level == num_ins:
                    source_channels = in_channels[-1]
                else:
                    source_channels = out_channels
                self.extra_convs.append(
                    ConvModule(
                        source_channels,
                        out_channels,
                        3,
                        stride=2,
                        padding=1,
                        **layer_cfgs))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Sequence of feature maps, one per entry of ``self.convs``
                (the lengths are asserted to match).

        Returns:
            tuple: The mapped feature maps in input order, followed by the
            outputs of the extra convs, if any.
        """
        assert len(inputs) == len(self.convs)
        outs = [conv(feat) for conv, feat in zip(self.convs, inputs)]
        if self.extra_convs:
            for idx, extra_conv in enumerate(self.extra_convs):
                # First extra level comes from the last raw input; the rest
                # chain from the most recently produced output.
                source = inputs[-1] if idx == 0 else outs[-1]
                outs.append(extra_conv(source))
        return tuple(outs)
Read the Docs v: dev
Versions
latest
stable
3.x
v2.28.2
v2.28.1
v2.28.0
v2.27.0
v2.26.0
v2.25.3
v2.25.2
v2.25.1
v2.25.0
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
dev-3.x
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.