Shortcuts

mmdet.models.necks.rfp 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import constant_init, xavier_init
from mmcv.runner import BaseModule, ModuleList

from ..builder import NECKS, build_backbone
from .fpn import FPN


class ASPP(BaseModule):
    """Atrous Spatial Pyramid Pooling (ASPP) block.

    Implements the ASPP module used in DetectoRS
    (https://arxiv.org/pdf/2006.02334.pdf).

    One convolution is created per entry of ``dilations``. Entries greater
    than 1 produce a 3x3 dilated convolution padded by the dilation rate;
    other entries produce an unpadded 1x1 convolution. The last branch is fed
    the globally average-pooled input instead of the full feature map.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of channels produced by each branch.
        dilations (tuple[int]): Dilation of every branch; the last entry
            must be 1. Default: (1, 3, 6, 1)
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 dilations=(1, 3, 6, 1),
                 init_cfg=dict(type='Kaiming', layer='Conv2d')):
        super().__init__(init_cfg)
        # The final branch consumes a 1x1 pooled map, so it has to be the
        # plain (non-dilated) 1x1 convolution.
        assert dilations[-1] == 1
        self.aspp = nn.ModuleList()
        for rate in dilations:
            is_atrous = rate > 1
            branch = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3 if is_atrous else 1,
                stride=1,
                dilation=rate,
                padding=rate if is_atrous else 0,
                bias=True)
            self.aspp.append(branch)
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: ReLU-activated branch outputs concatenated along
            dimension 1.
        """
        pooled = self.gap(x)
        last_idx = len(self.aspp) - 1
        # Every branch sees the full map except the last one, which sees the
        # globally pooled descriptor.
        branches = [
            F.relu_(conv(pooled if idx == last_idx else x))
            for idx, conv in enumerate(self.aspp)
        ]
        # Expand the pooled branch to the shape of the preceding branch so
        # that all outputs can be concatenated.
        branches[-1] = branches[-1].expand_as(branches[-2])
        return torch.cat(branches, dim=1)


@NECKS.register_module()
class RFP(FPN):
    """RFP (Recursive Feature Pyramid).

    This is an implementation of RFP in `DetectoRS
    <https://arxiv.org/pdf/2006.02334.pdf>`_. Different from standard FPN, the
    input of RFP should be multi level features along with origin input image
    of backbone.

    Args:
        rfp_steps (int): Number of unrolled steps of RFP.
        rfp_backbone (dict): Configuration of the backbone for RFP.
        aspp_out_channels (int): Number of output channels of ASPP module.
        aspp_dilations (tuple[int]): Dilation rates of four branches.
            Default: (1, 3, 6, 1)
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Must be left as None. Default: None
        **kwargs: Forwarded unchanged to the ``FPN`` constructor.
    """

    def __init__(self,
                 rfp_steps,
                 rfp_backbone,
                 aspp_out_channels,
                 aspp_dilations=(1, 3, 6, 1),
                 init_cfg=None,
                 **kwargs):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.rfp_steps = rfp_steps
        # Be careful! Pretrained weights cannot be loaded when use
        # nn.ModuleList
        self.rfp_modules = ModuleList()
        # One extra backbone per recursive step after the first, i.e.
        # ``rfp_steps - 1`` modules in total.
        for _ in range(1, rfp_steps):
            self.rfp_modules.append(build_backbone(rfp_backbone))
        self.rfp_aspp = ASPP(self.out_channels, aspp_out_channels,
                             aspp_dilations)
        # 1x1 convolution producing a single-channel map that is turned into
        # the fusion gate in ``forward``.
        self.rfp_weight = nn.Conv2d(
            self.out_channels,
            1,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

    def init_weights(self):
        """Initialize the FPN convolutions, RFP backbones and fusion gate."""
        # Avoid using super().init_weights(), which may alter the default
        # initialization of the modules in self.rfp_modules that have missing
        # keys in the pretrained checkpoint.
        for convs in (self.lateral_convs, self.fpn_convs):
            for m in convs.modules():
                if isinstance(m, nn.Conv2d):
                    xavier_init(m, distribution='uniform')
        for rfp_idx in range(self.rfp_steps - 1):
            self.rfp_modules[rfp_idx].init_weights()
        # Constant-initialize the gate convolution with 0.
        constant_init(self.rfp_weight, 0)

    def forward(self, inputs):
        """Run the FPN and then ``rfp_steps - 1`` recursive refinements.

        Args:
            inputs (Sequence): The input image followed by one feature map
                per entry of ``self.in_channels``.

        Returns:
            The parent FPN output when no refinement step runs, otherwise a
            list with one fused feature map per pyramid level.
        """
        inputs = list(inputs)
        assert len(inputs) == len(self.in_channels) + 1  # +1 for input image
        img = inputs.pop(0)
        # FPN forward
        x = super().forward(tuple(inputs))
        for rfp_idx in range(self.rfp_steps - 1):
            # The first pyramid level is fed back as-is; every other level
            # goes through the shared ASPP module first.
            rfp_feats = [x[0]] + [self.rfp_aspp(feat) for feat in x[1:]]
            x_idx = self.rfp_modules[rfp_idx].rfp_forward(img, rfp_feats)
            # FPN forward
            # NOTE: keep the zero-argument super() call outside of any
            # comprehension; it is not available there before Python 3.12.
            x_idx = super().forward(x_idx)
            x_new = []
            for new_feat, old_feat in zip(x_idx, x):
                # Per-pixel gate in (0, 1) blending the refined feature with
                # the one from the previous step.
                add_weight = torch.sigmoid(self.rfp_weight(new_feat))
                x_new.append(add_weight * new_feat +
                             (1 - add_weight) * old_feat)
            x = x_new
        return x
Read the Docs v: latest
Versions
latest
stable
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.