Shortcuts

mmdet.models.necks.hrfpn 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch.utils.checkpoint import checkpoint

from ..builder import NECKS


@NECKS.register_module()
class HRFPN(BaseModule):
    """HRFPN (High Resolution Feature Pyramids).

    paper: `High-Resolution Representations for Labeling Pixels and Regions
    <https://arxiv.org/abs/1904.04514>`_.

    The input branches are bilinearly upsampled (branch ``i`` by a factor of
    ``2**i``), concatenated along the channel dimension and fused by a 1x1
    convolution. The fused map is then pooled with kernel/stride ``2**i`` to
    form output level ``i`` and every level goes through its own 3x3
    convolution.

    Args:
        in_channels (list): number of channels for each branch.
        out_channels (int): output channels of feature pyramids.
        num_outs (int): number of output stages.
        pooling_type (str): pooling for generating feature pyramids
            from {MAX, AVG}. Any value other than ``'MAX'`` selects average
            pooling; a warning is emitted when the value is neither
            ``'MAX'`` nor ``'AVG'``.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
            It is stored on the module but is not forwarded to the conv
            layers built here.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
        stride (int): stride of 3x3 convolutional layers
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs=5,
                 pooling_type='AVG',
                 conv_cfg=None,
                 norm_cfg=None,
                 with_cp=False,
                 stride=1,
                 init_cfg=dict(type='Caffe2Xavier', layer='Conv2d')):
        super().__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        # NOTE(review): norm_cfg is kept as an attribute only; neither
        # ConvModule below receives it. Forwarding it would change the layer
        # structure, so this is left as-is -- confirm whether it is intended.
        self.norm_cfg = norm_cfg

        # 1x1 conv that fuses the concatenated branches into out_channels.
        self.reduction_conv = ConvModule(
            sum(in_channels),
            out_channels,
            kernel_size=1,
            conv_cfg=self.conv_cfg,
            act_cfg=None)

        # One 3x3 conv per output level.
        self.fpn_convs = nn.ModuleList()
        for _ in range(self.num_outs):
            self.fpn_convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    padding=1,
                    stride=stride,
                    conv_cfg=self.conv_cfg,
                    act_cfg=None))

        if pooling_type == 'MAX':
            self.pooling = F.max_pool2d
        else:
            if pooling_type != 'AVG':
                # Keep the historical fallback to average pooling, but make
                # a likely typo visible instead of swallowing it silently.
                import warnings
                warnings.warn(
                    f'Unrecognized pooling_type {pooling_type!r}; expected '
                    "'MAX' or 'AVG'. Falling back to average pooling.")
            self.pooling = F.avg_pool2d

    def _maybe_checkpoint(self, module, x):
        """Apply ``module`` to ``x``, through ``checkpoint`` when enabled.

        Checkpointing is used only when ``with_cp`` is set and ``x``
        requires grad; otherwise the module is called directly.
        """
        if self.with_cp and x.requires_grad:
            return checkpoint(module, x)
        return module(x)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (Sequence[Tensor]): one feature map per input branch; its
                length must equal ``len(in_channels)``. After branch ``i``
                is upsampled by ``2**i`` all maps are concatenated along
                dim 1, so their remaining dimensions must then match
                ``inputs[0]``.

        Returns:
            tuple[Tensor]: ``num_outs`` feature maps, one per output level.
        """
        assert len(inputs) == self.num_ins

        # Bring every branch to the resolution of the first one.
        upsampled = [inputs[0]]
        for i, feat in enumerate(inputs[1:], start=1):
            upsampled.append(
                F.interpolate(
                    feat,
                    scale_factor=2**i,
                    mode='bilinear',
                    # Explicitly the documented default for bilinear mode.
                    align_corners=False))

        fused = self._maybe_checkpoint(self.reduction_conv,
                                       torch.cat(upsampled, dim=1))

        # Level 0 is the fused map itself; level i is pooled by 2**i.
        pyramid = [fused]
        for i in range(1, self.num_outs):
            pyramid.append(self.pooling(fused, kernel_size=2**i, stride=2**i))

        # fpn_convs holds exactly num_outs convs, so zip yields num_outs maps.
        return tuple(
            self._maybe_checkpoint(conv, feat)
            for conv, feat in zip(self.fpn_convs, pyramid))
Read the Docs v: latest
Versions
latest
stable
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.