Shortcuts

mmdet.models.backbones.hourglass 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from ..builder import BACKBONES
from ..utils import ResLayer
from .resnet import BasicBlock


class HourglassModule(BaseModule):
    """Hourglass Module for HourglassNet backbone.

    Generate module recursively and use BasicBlock as the base unit.

    The module has two branches: ``up1`` processes the input at its own
    resolution, while ``low1 -> low2 -> low3`` processes a stride-2 copy.
    ``low2`` is another ``HourglassModule`` while ``depth > 1`` and a plain
    ``ResLayer`` at the innermost level. The low branch is interpolated back
    and added to ``up1``.

    Args:
        depth (int): Depth of current HourglassModule.
        stage_channels (list[int]): Feature channels of sub-modules in current
            and follow-up HourglassModule.
        stage_blocks (list[int]): Number of sub-modules stacked in current and
            follow-up HourglassModule.
        norm_cfg (dict): Dictionary to construct and config norm layer. It is
            also forwarded to every nested HourglassModule.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
        upsample_cfg (dict, optional): Config dict for interpolate layer. It
            is also forwarded to every nested HourglassModule.
            Default: `dict(mode='nearest')`
    """

    def __init__(self,
                 depth,
                 stage_channels,
                 stage_blocks,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 init_cfg=None,
                 upsample_cfg=dict(mode='nearest')):
        super(HourglassModule, self).__init__(init_cfg)

        self.depth = depth

        # Index 0 describes this level, index 1 the next (inner) level.
        cur_block = stage_blocks[0]
        next_block = stage_blocks[1]

        cur_channel = stage_channels[0]
        next_channel = stage_channels[1]

        # Skip branch: keeps the channel count of the current level.
        self.up1 = ResLayer(
            BasicBlock, cur_channel, cur_channel, cur_block, norm_cfg=norm_cfg)

        # Entry of the low branch: stride 2 and switch to the inner channels.
        self.low1 = ResLayer(
            BasicBlock,
            cur_channel,
            next_channel,
            cur_block,
            stride=2,
            norm_cfg=norm_cfg)

        if self.depth > 1:
            # Forward norm_cfg and upsample_cfg explicitly; otherwise the
            # nested modules would silently fall back to the default values
            # and ignore what the caller configured for the outer module.
            self.low2 = HourglassModule(
                depth - 1,
                stage_channels[1:],
                stage_blocks[1:],
                norm_cfg=norm_cfg,
                upsample_cfg=upsample_cfg)
        else:
            # Innermost level: a plain residual stack ends the recursion.
            self.low2 = ResLayer(
                BasicBlock,
                next_channel,
                next_channel,
                next_block,
                norm_cfg=norm_cfg)

        # Exit of the low branch: map back to the current level's channels.
        self.low3 = ResLayer(
            BasicBlock,
            next_channel,
            cur_channel,
            cur_block,
            norm_cfg=norm_cfg,
            downsample_first=False)

        self.up2 = F.interpolate
        self.upsample_cfg = upsample_cfg

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: Sum of the skip branch output and the upsampled low
            branch output.
        """
        up1 = self.up1(x)
        low1 = self.low1(x)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        # Fixing `scale factor` (e.g. 2) is common for upsampling, but
        # in some cases the spatial size is mismatched and error will arise.
        if 'scale_factor' in self.upsample_cfg:
            up2 = self.up2(low3, **self.upsample_cfg)
        else:
            # Without an explicit scale factor, resize to the skip branch's
            # spatial size so that the addition below is always well-defined.
            shape = up1.shape[2:]
            up2 = self.up2(low3, size=shape, **self.upsample_cfg)
        return up1 + up2


@BACKBONES.register_module()
class HourglassNet(BaseModule):
    """HourglassNet backbone.

    Stacked Hourglass Networks for Human Pose Estimation.
    More details can be found in the `paper
    <https://arxiv.org/abs/1603.06937>`_ .

    Args:
        downsample_times (int): Downsample times in a HourglassModule.
        num_stacks (int): Number of HourglassModule modules stacked,
            1 for Hourglass-52, 2 for Hourglass-104.
        stage_channels (list[int]): Feature channel of each sub-module in a
            HourglassModule.
        stage_blocks (list[int]): Number of sub-modules stacked in a
            HourglassModule.
        feat_channel (int): Feature channel of conv after a HourglassModule.
        norm_cfg (dict): Dictionary to construct and config norm layer. It is
            used for the stem, the stacked HourglassModules and all
            intermediate/output convs.
        pretrained (str, optional): model pretrained path. Accepted for
            interface compatibility; it is not read by this class.
            Default: None
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Must be None. Default: None

    Example:
        >>> from mmdet.models import HourglassNet
        >>> import torch
        >>> self = HourglassNet()
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 511, 511)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     print(tuple(level_output.shape))
        (1, 256, 128, 128)
        (1, 256, 128, 128)
    """

    def __init__(self,
                 downsample_times=5,
                 num_stacks=2,
                 stage_channels=(256, 256, 384, 384, 384, 512),
                 stage_blocks=(2, 2, 2, 2, 2, 4),
                 feat_channel=256,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 pretrained=None,
                 init_cfg=None):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super(HourglassNet, self).__init__(init_cfg)

        self.num_stacks = num_stacks
        assert self.num_stacks >= 1
        assert len(stage_channels) == len(stage_blocks)
        # Each nesting level reads entries [0] and [1] of the remaining
        # lists, so one more entry than `downsample_times` is required.
        assert len(stage_channels) > downsample_times

        cur_channel = stage_channels[0]

        # Stem: two stride-2 stages applied to the 3-channel input image.
        self.stem = nn.Sequential(
            ConvModule(
                3, cur_channel // 2, 7, padding=3, stride=2,
                norm_cfg=norm_cfg),
            ResLayer(
                BasicBlock,
                cur_channel // 2,
                cur_channel,
                1,
                stride=2,
                norm_cfg=norm_cfg))

        # Pass norm_cfg through so the hourglass stacks use the same norm
        # layer configuration as every other layer built in this backbone.
        self.hourglass_modules = nn.ModuleList([
            HourglassModule(
                downsample_times,
                stage_channels,
                stage_blocks,
                norm_cfg=norm_cfg) for _ in range(num_stacks)
        ])

        # One residual block per transition between consecutive stacks;
        # indexed per stack in ``forward``.
        self.inters = ResLayer(
            BasicBlock,
            cur_channel,
            cur_channel,
            num_stacks - 1,
            norm_cfg=norm_cfg)

        # 1x1 projection of the previous stack's input (no activation).
        self.conv1x1s = nn.ModuleList([
            ConvModule(
                cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        # Output conv applied after every hourglass stack.
        self.out_convs = nn.ModuleList([
            ConvModule(
                cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg)
            for _ in range(num_stacks)
        ])

        # 1x1 projection of a stack's output back to the hourglass channels
        # (no activation), merged into the input of the next stack.
        self.remap_convs = nn.ModuleList([
            ConvModule(
                feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        self.relu = nn.ReLU(inplace=True)

    def init_weights(self):
        """Init module weights."""
        # Training Centripetal Model needs to reset parameters for Conv2d
        super(HourglassNet, self).init_weights()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.reset_parameters()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input passed to the stem.

        Returns:
            list[Tensor]: One output feature map per hourglass stack, in
            stack order.
        """
        inter_feat = self.stem(x)
        out_feats = []

        stacks = zip(self.hourglass_modules, self.out_convs)
        for ind, (single_hourglass, out_conv) in enumerate(stacks):
            hourglass_feat = single_hourglass(inter_feat)
            out_feat = out_conv(hourglass_feat)
            out_feats.append(out_feat)

            # Every stack except the last feeds the next one: merge the
            # projected stack input with the re-mapped stack output, then
            # apply ReLU and the matching intermediate residual block.
            if ind < self.num_stacks - 1:
                inter_feat = self.conv1x1s[ind](
                    inter_feat) + self.remap_convs[ind](out_feat)
                inter_feat = self.inters[ind](self.relu(inter_feat))

        return out_feats
Read the Docs v: v2.21.0
Versions
latest
stable
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.