# Shortcuts
#
# Source code for mmdet.models.necks.fpg

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from ..builder import NECKS


class Transition(BaseModule):
    """Base class for transition.

    Subclasses must override :meth:`forward`.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        init_cfg (dict or list[dict], optional): Initialization config dict,
            forwarded to ``BaseModule``. Default: None.
    """

    def __init__(self, in_channels, out_channels, init_cfg=None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x):
        """Transform ``x``; must be implemented by subclasses.

        The previous signature ``forward(x)`` lacked ``self``, so calling it
        failed with an unrelated ``TypeError``; an explicit error is raised
        instead.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement forward()')


class UpInterpolationConv(Transition):
    """A transition used for up-sampling.

    Up-sample the input by interpolation then refines the feature by
    a convolution layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Up-sampling factor passed to
            ``F.interpolate``. Default: 2.
        mode (str): Interpolation mode passed to ``F.interpolate``.
            Default: 'nearest'.
        align_corners (bool, optional): Whether to align corners when
            interpolating; passed to ``F.interpolate``. Default: None.
        kernel_size (int): Kernel size for the conv. Default: 3.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
        **kwargs: Extra keyword arguments forwarded to ``ConvModule``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=2,
                 mode='nearest',
                 align_corners=None,
                 kernel_size=3,
                 init_cfg=None,
                 **kwargs):
        super().__init__(in_channels, out_channels, init_cfg)
        self.mode = mode
        self.scale_factor = scale_factor
        self.align_corners = align_corners
        # (kernel_size - 1) // 2 padding preserves spatial size for odd
        # kernel sizes, assuming the default stride of 1 (not overridden
        # through kwargs).
        self.conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            **kwargs)

    def forward(self, x):
        """Up-sample ``x`` by interpolation, then apply ``self.conv``."""
        x = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners)
        x = self.conv(x)
        return x


class LastConv(Transition):
    """Transition that refines the output of the last stage.

    The whole sequence of inputs is accepted (so every stage output of a
    level can be passed in), its length is checked against ``num_inputs``,
    and only the final element is convolved.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_inputs (int): Expected number of elements in the input sequence.
        kernel_size (int): Kernel size for the conv. Default: 3.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
        **kwargs: Extra keyword arguments forwarded to ``ConvModule``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_inputs,
                 kernel_size=3,
                 init_cfg=None,
                 **kwargs):
        super().__init__(in_channels, out_channels, init_cfg)
        self.num_inputs = num_inputs
        half_kernel = (kernel_size - 1) // 2
        self.conv_out = ConvModule(
            in_channels, out_channels, kernel_size, padding=half_kernel,
            **kwargs)

    def forward(self, inputs):
        """Apply ``conv_out`` to the last element of ``inputs``."""
        given = len(inputs)
        assert given == self.num_inputs
        last = inputs[-1]
        return self.conv_out(last)


@NECKS.register_module()
class FPG(BaseModule):
    """FPG.

    Implementation of `Feature Pyramid Grids (FPG)
    <https://arxiv.org/abs/2004.03580>`_.
    This implementation only gives the basic structure stated in the paper.
    But users can implement different type of transitions to fully explore
    the potential power of the structure of FPG.

    Args:
        in_channels (list[int]): Number of input channels of each backbone
            level.
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        stack_times (int): The number of times the pyramid architecture will
            be stacked.
        paths (list[str]): Specify the path order of each stack level.
            Each element in the list should be either 'bu' (bottom-up) or
            'td' (top-down).
        inter_channels (int | list[int], optional): Number of inter channels,
            either shared by all levels or given per output level. Defaults
            to ``out_channels`` for every level.
        same_down_trans (dict, optional): Same-stage transition from level
            ``i + 1`` to level ``i`` (used in top-down paths).
        same_up_trans (dict, optional): Same-stage transition from level
            ``i - 1`` to level ``i`` (used in bottom-up paths).
        across_lateral_trans (dict, optional): Across-pathway transition from
            the same level of the previous stage.
        across_down_trans (dict, optional): Across-pathway transition from
            level ``i + 1`` of the previous stage.
        across_up_trans (dict, optional): Across-pathway transition from
            level ``i - 1`` of the previous stage.
        across_skip_trans (dict, optional): Across-pathway skip connection
            from the initial (lateral / extra) feature of the same level.
        output_trans (dict): Transition that maps the list of all stage
            outputs of a level to the final output of that level.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last level.
        add_extra_convs (bool): If True, extra output levels are produced by
            stride-2 3x3 convs; otherwise by stride-2 max pooling with kernel
            size 1. Default: False.
        norm_cfg (dict, optional): Only used to set ``self.with_bias``
            (``norm_cfg is None``); normalization of the transitions is
            configured inside each transition dict. Default: None.
        skip_inds (list[Sequence[int]], optional): ``skip_inds[i]`` holds the
            stack indices at which level ``i`` is not updated and its previous
            feature is passed through. Required when ``across_skip_trans`` is
            not None; otherwise defaults to never skipping.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    transition_types = {
        'conv': ConvModule,
        'interpolation_conv': UpInterpolationConv,
        'last_conv': LastConv,
        # ``nn.Identity`` ignores its constructor arguments, so it can serve
        # the default ``across_skip_trans=dict(type='identity')``, which
        # previously raised a KeyError in ``build_trans``.
        'identity': nn.Identity,
    }

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 stack_times,
                 paths,
                 inter_channels=None,
                 same_down_trans=None,
                 same_up_trans=dict(
                     type='conv', kernel_size=3, stride=2, padding=1),
                 across_lateral_trans=dict(type='conv', kernel_size=1),
                 across_down_trans=dict(type='conv', kernel_size=3),
                 across_up_trans=None,
                 across_skip_trans=dict(type='identity'),
                 output_trans=dict(type='last_conv', kernel_size=3),
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 norm_cfg=None,
                 skip_inds=None,
                 init_cfg=[
                     dict(type='Caffe2Xavier', layer='Conv2d'),
                     dict(
                         type='Constant',
                         layer=[
                             '_BatchNorm', '_InstanceNorm', 'GroupNorm',
                             'LayerNorm'
                         ],
                         val=1.0)
                 ]):
        super(FPG, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        if inter_channels is None:
            self.inter_channels = [out_channels for _ in range(num_outs)]
        elif isinstance(inter_channels, int):
            self.inter_channels = [inter_channels for _ in range(num_outs)]
        else:
            assert isinstance(inter_channels, list)
            assert len(inter_channels) == num_outs
            self.inter_channels = inter_channels
        self.stack_times = stack_times
        self.paths = paths
        assert isinstance(paths, list) and len(paths) == stack_times
        for d in paths:
            assert d in ('bu', 'td')

        self.same_down_trans = same_down_trans
        self.same_up_trans = same_up_trans
        self.across_lateral_trans = across_lateral_trans
        self.across_down_trans = across_down_trans
        self.across_up_trans = across_up_trans
        self.output_trans = output_trans
        self.across_skip_trans = across_skip_trans

        self.with_bias = norm_cfg is None
        # skip inds must be specified if across skip trans is not None
        if self.across_skip_trans is not None:
            assert skip_inds is not None, \
                'skip_inds must be specified when across_skip_trans is set'
        elif skip_inds is None:
            # No skip connections and no skip indices: never skip a level,
            # instead of crashing on ``None[0]`` below.
            skip_inds = [() for _ in range(num_outs)]
        self.skip_inds = skip_inds
        assert len(self.skip_inds[0]) <= self.stack_times

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

        # build lateral 1x1 convs to reduce channels
        self.lateral_convs = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            l_conv = nn.Conv2d(self.in_channels[i],
                               self.inter_channels[i - self.start_level], 1)
            self.lateral_convs.append(l_conv)

        # levels beyond the backbone are produced from the previous level
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        self.extra_downsamples = nn.ModuleList()
        for i in range(extra_levels):
            if self.add_extra_convs:
                fpn_idx = self.backbone_end_level - self.start_level + i
                extra_conv = nn.Conv2d(
                    self.inter_channels[fpn_idx - 1],
                    self.inter_channels[fpn_idx],
                    3,
                    stride=2,
                    padding=1)
                self.extra_downsamples.append(extra_conv)
            else:
                self.extra_downsamples.append(nn.MaxPool2d(1, stride=2))

        # fpn_transitions[stage][level] is a ModuleDict of transitions; it is
        # left empty for (stage, level) pairs listed in skip_inds.
        self.fpn_transitions = nn.ModuleList()
        for s in range(self.stack_times):
            stage_trans = nn.ModuleList()
            for i in range(self.num_outs):
                trans = nn.ModuleDict()
                if s in self.skip_inds[i]:
                    stage_trans.append(trans)
                    continue
                # build same-stage up trans (level i-1 -> i, bottom-up paths)
                if i == 0 or self.same_up_trans is None:
                    same_up_trans = None
                else:
                    same_up_trans = self.build_trans(
                        self.same_up_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['same_up'] = same_up_trans
                # build same-stage down trans (level i+1 -> i, top-down paths)
                if i == self.num_outs - 1 or self.same_down_trans is None:
                    same_down_trans = None
                else:
                    same_down_trans = self.build_trans(
                        self.same_down_trans, self.inter_channels[i + 1],
                        self.inter_channels[i])
                trans['same_down'] = same_down_trans
                # build across lateral trans
                across_lateral_trans = self.build_trans(
                    self.across_lateral_trans, self.inter_channels[i],
                    self.inter_channels[i])
                trans['across_lateral'] = across_lateral_trans
                # build across down trans
                if i == self.num_outs - 1 or self.across_down_trans is None:
                    across_down_trans = None
                else:
                    across_down_trans = self.build_trans(
                        self.across_down_trans, self.inter_channels[i + 1],
                        self.inter_channels[i])
                trans['across_down'] = across_down_trans
                # build across up trans
                if i == 0 or self.across_up_trans is None:
                    across_up_trans = None
                else:
                    across_up_trans = self.build_trans(
                        self.across_up_trans, self.inter_channels[i - 1],
                        self.inter_channels[i])
                trans['across_up'] = across_up_trans
                # build across skip trans; it is applied to the initial
                # feature of this same level, so its input has
                # inter_channels[i] channels (not inter_channels[i - 1]).
                if self.across_skip_trans is None:
                    across_skip_trans = None
                else:
                    across_skip_trans = self.build_trans(
                        self.across_skip_trans, self.inter_channels[i],
                        self.inter_channels[i])
                trans['across_skip'] = across_skip_trans
                stage_trans.append(trans)
            self.fpn_transitions.append(stage_trans)

        # output levels: each consumes the initial feature plus one feature
        # per stage, hence stack_times + 1 inputs.
        self.output_transition = nn.ModuleList()
        for i in range(self.num_outs):
            trans = self.build_trans(
                self.output_trans,
                self.inter_channels[i],
                self.out_channels,
                num_inputs=self.stack_times + 1)
            self.output_transition.append(trans)

        self.relu = nn.ReLU(inplace=True)

    def build_trans(self, cfg, in_channels, out_channels, **extra_args):
        """Instantiate the transition class named by ``cfg['type']``.

        ``cfg`` is copied before ``'type'`` is popped, so the caller's dict
        is not mutated. Remaining keys and ``extra_args`` are passed as
        keyword arguments after ``in_channels`` and ``out_channels``.
        """
        cfg_ = cfg.copy()
        trans_type = cfg_.pop('type')
        trans_cls = self.transition_types[trans_type]
        return trans_cls(in_channels, out_channels, **cfg_, **extra_args)

    def fuse(self, fuse_dict):
        """Sum the non-None values of ``fuse_dict`` in insertion order.

        Returns None if every value is None.
        """
        out = None
        for item in fuse_dict.values():
            if item is not None:
                if out is None:
                    out = item
                else:
                    out = out + item
        return out

    def forward(self, inputs):
        """Run the pyramid grid and return one output per level.

        Args:
            inputs (Sequence): Backbone feature maps, one per entry of
                ``in_channels``.

        Returns:
            list: ``num_outs`` outputs, one per level, each produced by the
            level's output transition from all of its stage features.
        """
        assert len(inputs) == len(self.in_channels)

        # build all levels from original feature maps
        feats = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        for downsample in self.extra_downsamples:
            feats.append(downsample(feats[-1]))

        outs = [feats]

        for i in range(self.stack_times):
            current_outs = outs[-1]
            next_outs = []
            direction = self.paths[i]
            for j in range(self.num_outs):
                # feature level; top-down paths visit levels in reverse order
                if direction == 'td':
                    lvl = self.num_outs - j - 1
                else:
                    lvl = j
                # skip_inds and fpn_transitions are indexed by level, not by
                # visiting order (the two only coincide for bottom-up paths).
                if i in self.skip_inds[lvl]:
                    next_outs.append(current_outs[lvl])
                    continue

                # get transitions
                level_trans = self.fpn_transitions[i][lvl]
                if direction == 'td':
                    same_trans = level_trans['same_down']
                else:
                    same_trans = level_trans['same_up']
                across_lateral_trans = level_trans['across_lateral']
                across_down_trans = level_trans['across_down']
                across_up_trans = level_trans['across_up']
                across_skip_trans = level_trans['across_skip']

                # init output (key order fixes the summation order in fuse)
                to_fuse = dict(
                    same=None, lateral=None, across_up=None, across_down=None)
                # same-stage neighbour processed just before this level
                if same_trans is not None:
                    to_fuse['same'] = same_trans(next_outs[-1])
                # across lateral
                if across_lateral_trans is not None:
                    to_fuse['lateral'] = across_lateral_trans(
                        current_outs[lvl])
                # from the lower level of the previous stage
                if lvl > 0 and across_up_trans is not None:
                    to_fuse['across_up'] = across_up_trans(current_outs[lvl - 1])
                # from the higher level of the previous stage
                if (lvl < self.num_outs - 1 and across_down_trans is not None):
                    to_fuse['across_down'] = across_down_trans(
                        current_outs[lvl + 1])
                # skip connection from the initial feature of this level
                if across_skip_trans is not None:
                    to_fuse['across_skip'] = across_skip_trans(outs[0][lvl])
                x = self.fuse(to_fuse)
                next_outs.append(x)

            # restore low-to-high level order for top-down stages
            if direction == 'td':
                outs.append(next_outs[::-1])
            else:
                outs.append(next_outs)

        # output trans
        final_outs = []
        for i in range(self.num_outs):
            lvl_out_list = [stage_outs[i] for stage_outs in outs]
            lvl_out = self.output_transition[i](lvl_out_list)
            final_outs.append(lvl_out)
        return final_outs
# Read the Docs v: v2.17.0
# Versions
# latest
# stable
# v2.17.0
# v2.16.0
# v2.15.1
# v2.15.0
# v2.14.0
# v2.13.0
# v2.12.0
# v2.11.0
# v2.10.0
# v2.9.0
# v2.8.0
# v2.7.0
# v2.6.0
# v2.5.0
# v2.4.0
# v2.3.0
# v2.2.1
# v2.2.0
# v2.1.0
# v2.0.0
# v1.2.0
# Downloads
# On Read the Docs
# Project Home
# Builds
#
# Free document hosting provided by Read the Docs.