mmdet.models.utils.transformer 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (build_activation_layer, build_conv_layer,
                      build_norm_layer, xavier_init)
from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,
                                      TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
                                         TransformerLayerSequence,
                                         build_transformer_layer_sequence)
from mmcv.runner.base_module import BaseModule
from mmcv.utils import to_2tuple
from torch.nn.init import normal_

from mmdet.models.utils.builder import TRANSFORMER

try:
    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention

except ImportError:
    warnings.warn(
        '`MultiScaleDeformableAttention` in MMCV has been moved to '
        '`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
    from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention


class AdaptivePadding(nn.Module):
    """Applies padding to input (if needed) so that input can get fully covered
    by filter you specified. It support two modes "same" and "corner". The
    "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
    input. The "corner"  mode would pad zero to bottom right.

    Args:
        kernel_size (int | tuple): Size of the kernel:
        stride (int | tuple): Stride of the filter. Default: 1:
        dilation (int | tuple): Spacing between kernel elements.
            Default: 1
        padding (str): Support "same" and "corner", "corner" mode
            would pad zero to bottom right, and "same" mode would
            pad zero around input. Default: "corner".
    Example:
        >>> kernel_size = 16
        >>> stride = 16
        >>> dilation = 1
        >>> input = torch.rand(1, 1, 15, 17)
        >>> adap_pad = AdaptivePadding(
        >>>     kernel_size=kernel_size,
        >>>     stride=stride,
        >>>     dilation=dilation,
        >>>     padding="corner")
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
        >>> input = torch.rand(1, 1, 16, 17)
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
    """

    def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):

        super(AdaptivePadding, self).__init__()

        assert padding in ('same', 'corner')

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        dilation = to_2tuple(dilation)

        self.padding = padding
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

    def get_pad_shape(self, input_shape):
        input_h, input_w = input_shape
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.stride
        output_h = math.ceil(input_h / stride_h)
        output_w = math.ceil(input_w / stride_w)
        pad_h = max((output_h - 1) * stride_h +
                    (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
        pad_w = max((output_w - 1) * stride_w +
                    (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
        return pad_h, pad_w

    def forward(self, x):
        pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
        if pad_h > 0 or pad_w > 0:
            if self.padding == 'corner':
                x = F.pad(x, [0, pad_w, 0, pad_h])
            elif self.padding == 'same':
                x = F.pad(x, [
                    pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                    pad_h - pad_h // 2
                ])
        return x


class PatchEmbed(BaseModule):
    """Image to Patch Embedding.

    We use a conv layer to implement PatchEmbed.

    Args:
        in_channels (int): The num of input channels. Default: 3
        embed_dims (int): The dimensions of embedding. Default: 768
        conv_type (str): The config dict for embedding
            conv layer type selection. Default: "Conv2d.
        kernel_size (int): The kernel_size of embedding conv. Default: 16.
        stride (int): The slide stride of embedding conv.
            Default: None (Would be set as `kernel_size`).
        padding (int | tuple | string ): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int): The dilation rate of embedding conv. Default: 1.
        bias (bool): Bias of embed conv. Default: True.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        input_size (int | tuple | None): The size of input, which will be
            used to calculate the out size. Only work when `dynamic_size`
            is False. Default: None.
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None.
    """

    def __init__(
        self,
        in_channels=3,
        embed_dims=768,
        conv_type='Conv2d',
        kernel_size=16,
        stride=16,
        padding='corner',
        dilation=1,
        bias=True,
        norm_cfg=None,
        input_size=None,
        init_cfg=None,
    ):
        super(PatchEmbed, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        if stride is None:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of conv
            padding = 0
        else:
            self.adap_padding = None
        padding = to_2tuple(padding)

        self.projection = build_conv_layer(
            dict(type=conv_type),
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

        if input_size:
            input_size = to_2tuple(input_size)
            # `init_out_size` would be used outside to
            # calculate the num_patches
            # when `use_abs_pos_embed` outside
            self.init_input_size = input_size
            if self.adap_padding:
                pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
                input_h, input_w = input_size
                input_h = input_h + pad_h
                input_w = input_w + pad_w
                input_size = (input_h, input_w)

            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
            h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
                     (kernel_size[0] - 1) - 1) // stride[0] + 1
            w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
                     (kernel_size[1] - 1) - 1) // stride[1] + 1
            self.init_out_size = (h_out, w_out)
        else:
            self.init_input_size = None
            self.init_out_size = None

    def forward(self, x):
        """
        Args:
            x (Tensor): Has shape (B, C, H, W). In most case, C is 3.

        Returns:
            tuple: Contains merged results and its spatial shape.

                - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
                - out_size (tuple[int]): Spatial shape of x, arrange as
                    (out_h, out_w).
        """

        if self.adap_padding:
            x = self.adap_padding(x)

        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size


class PatchMerging(BaseModule):
    """Merge patch feature map.

    This layer groups feature map by kernel_size, and applies norm and linear
    layers to the grouped feature map. Our implementation uses `nn.Unfold` to
    merge patch, which is about 25% faster than original implementation.
    Instead, we need to modify pretrained models for compatibility.

    Args:
        in_channels (int): The num of input channels.
            to gets fully covered by filter and stride you specified..
            Default: True.
        out_channels (int): The num of output channels.
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in the
            unfold layer. Default: None. (Would be set as `kernel_size`)
        padding (int | tuple | string ): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Default: 1.
        bias (bool, optional): Whether to add bias in linear layer or not.
            Defaults: False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=2,
                 stride=None,
                 padding='corner',
                 dilation=1,
                 bias=False,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        if stride:
            stride = stride
        else:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of unfold
            padding = 0
        else:
            self.adap_padding = None

        padding = to_2tuple(padding)
        self.sampler = nn.Unfold(
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride)

        sample_dim = kernel_size[0] * kernel_size[1] * in_channels

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
        else:
            self.norm = None

        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)

    def forward(self, x, input_size):
        """
        Args:
            x (Tensor): Has shape (B, H*W, C_in).
            input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
                Default: None.

        Returns:
            tuple: Contains merged results and its spatial shape.

                - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
                - out_size (tuple[int]): Spatial shape of x, arrange as
                    (Merged_H, Merged_W).
        """
        B, L, C = x.shape
        assert isinstance(input_size, Sequence), f'Expect ' \
                                                 f'input_size is ' \
                                                 f'`Sequence` ' \
                                                 f'but get {input_size}'

        H, W = input_size
        assert L == H * W, 'input feature has wrong size'

        x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W
        # Use nn.Unfold to merge patch. About 25% faster than original method,
        # but need to modify pretrained model for compatibility

        if self.adap_padding:
            x = self.adap_padding(x)
            H, W = x.shape[-2:]

        x = self.sampler(x)
        # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)

        out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
                 (self.sampler.kernel_size[0] - 1) -
                 1) // self.sampler.stride[0] + 1
        out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
                 (self.sampler.kernel_size[1] - 1) -
                 1) // self.sampler.stride[1] + 1

        output_size = (out_h, out_w)
        x = x.transpose(1, 2)  # B, H/2*W/2, 4*C
        x = self.norm(x) if self.norm else x
        x = self.reduction(x)
        return x, output_size


def inverse_sigmoid(x, eps=1e-5):
    """Inverse function of sigmoid.

    Args:
        x (Tensor): The tensor to do the
            inverse.
        eps (float): EPS avoid numerical
            overflow. Defaults 1e-5.
    Returns:
        Tensor: The x has passed the inverse
            function of sigmoid, has same
            shape with input.
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


[文档]@TRANSFORMER_LAYER.register_module() class DetrTransformerDecoderLayer(BaseTransformerLayer): """Implements decoder layer in DETR transformer. Args: attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )): Configs for self_attention or cross_attention, the order should be consistent with it in `operation_order`. If it is a dict, it would be expand to the number of attention in `operation_order`. feedforward_channels (int): The hidden dimension for FFNs. ffn_dropout (float): Probability of an element to be zeroed in ffn. Default 0.0. operation_order (tuple[str]): The execution order of operation in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). Default:None act_cfg (dict): The activation config for FFNs. Default: `LN` norm_cfg (dict): Config dict for normalization layer. Default: `LN`. ffn_num_fcs (int): The number of fully-connected layers in FFNs. Default:2. """ def __init__(self, attn_cfgs, feedforward_channels, ffn_dropout=0.0, operation_order=None, act_cfg=dict(type='ReLU', inplace=True), norm_cfg=dict(type='LN'), ffn_num_fcs=2, **kwargs): super(DetrTransformerDecoderLayer, self).__init__( attn_cfgs=attn_cfgs, feedforward_channels=feedforward_channels, ffn_dropout=ffn_dropout, operation_order=operation_order, act_cfg=act_cfg, norm_cfg=norm_cfg, ffn_num_fcs=ffn_num_fcs, **kwargs) assert len(operation_order) == 6 assert set(operation_order) == set( ['self_attn', 'norm', 'cross_attn', 'ffn'])
@TRANSFORMER_LAYER_SEQUENCE.register_module() class DetrTransformerEncoder(TransformerLayerSequence): """TransformerEncoder of DETR. Args: post_norm_cfg (dict): Config of last normalization layer. Default: `LN`. Only used when `self.pre_norm` is `True` """ def __init__(self, *args, post_norm_cfg=dict(type='LN'), **kwargs): super(DetrTransformerEncoder, self).__init__(*args, **kwargs) if post_norm_cfg is not None: self.post_norm = build_norm_layer( post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None else: assert not self.pre_norm, f'Use prenorm in ' \ f'{self.__class__.__name__},' \ f'Please specify post_norm_cfg' self.post_norm = None def forward(self, *args, **kwargs): """Forward function for `TransformerCoder`. Returns: Tensor: forwarded results with shape [num_query, bs, embed_dims]. """ x = super(DetrTransformerEncoder, self).forward(*args, **kwargs) if self.post_norm is not None: x = self.post_norm(x) return x
[文档]@TRANSFORMER_LAYER_SEQUENCE.register_module() class DetrTransformerDecoder(TransformerLayerSequence): """Implements the decoder in DETR transformer. Args: return_intermediate (bool): Whether to return intermediate outputs. post_norm_cfg (dict): Config of last normalization layer. Default: `LN`. """ def __init__(self, *args, post_norm_cfg=dict(type='LN'), return_intermediate=False, **kwargs): super(DetrTransformerDecoder, self).__init__(*args, **kwargs) self.return_intermediate = return_intermediate if post_norm_cfg is not None: self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1] else: self.post_norm = None
[文档] def forward(self, query, *args, **kwargs): """Forward function for `TransformerDecoder`. Args: query (Tensor): Input query with shape `(num_query, bs, embed_dims)`. Returns: Tensor: Results with shape [1, num_query, bs, embed_dims] when return_intermediate is `False`, otherwise it has shape [num_layers, num_query, bs, embed_dims]. """ if not self.return_intermediate: x = super().forward(query, *args, **kwargs) if self.post_norm: x = self.post_norm(x)[None] return x intermediate = [] for layer in self.layers: query = layer(query, *args, **kwargs) if self.return_intermediate: if self.post_norm is not None: intermediate.append(self.post_norm(query)) else: intermediate.append(query) return torch.stack(intermediate)
[文档]@TRANSFORMER.register_module() class Transformer(BaseModule): """Implements the DETR transformer. Following the official DETR implementation, this module copy-paste from torch.nn.Transformer with modifications: * positional encodings are passed in MultiheadAttention * extra LN at the end of encoder is removed * decoder returns a stack of activations from all decoding layers See `paper: End-to-End Object Detection with Transformers <https://arxiv.org/pdf/2005.12872>`_ for details. Args: encoder (`mmcv.ConfigDict` | Dict): Config of TransformerEncoder. Defaults to None. decoder ((`mmcv.ConfigDict` | Dict)): Config of TransformerDecoder. Defaults to None init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. Defaults to None. """ def __init__(self, encoder=None, decoder=None, init_cfg=None): super(Transformer, self).__init__(init_cfg=init_cfg) self.encoder = build_transformer_layer_sequence(encoder) self.decoder = build_transformer_layer_sequence(decoder) self.embed_dims = self.encoder.embed_dims
[文档] def init_weights(self): # follow the official DETR to init parameters for m in self.modules(): if hasattr(m, 'weight') and m.weight.dim() > 1: xavier_init(m, distribution='uniform') self._is_init = True
[文档] def forward(self, x, mask, query_embed, pos_embed): """Forward function for `Transformer`. Args: x (Tensor): Input query with shape [bs, c, h, w] where c = embed_dims. mask (Tensor): The key_padding_mask used for encoder and decoder, with shape [bs, h, w]. query_embed (Tensor): The query embedding for decoder, with shape [num_query, c]. pos_embed (Tensor): The positional encoding for encoder and decoder, with the same shape as `x`. Returns: tuple[Tensor]: results of decoder containing the following tensor. - out_dec: Output from decoder. If return_intermediate_dec \ is True output has shape [num_dec_layers, bs, num_query, embed_dims], else has shape [1, bs, \ num_query, embed_dims]. - memory: Output results from encoder, with shape \ [bs, embed_dims, h, w]. """ bs, c, h, w = x.shape # use `view` instead of `flatten` for dynamically exporting to ONNX x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c] pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1) query_embed = query_embed.unsqueeze(1).repeat( 1, bs, 1) # [num_query, dim] -> [num_query, bs, dim] mask = mask.view(bs, -1) # [bs, h, w] -> [bs, h*w] memory = self.encoder( query=x, key=None, value=None, query_pos=pos_embed, query_key_padding_mask=mask) target = torch.zeros_like(query_embed) # out_dec: [num_layers, num_query, bs, dim] out_dec = self.decoder( query=target, key=memory, value=memory, key_pos=pos_embed, query_pos=query_embed, key_padding_mask=mask) out_dec = out_dec.transpose(1, 2) memory = memory.permute(1, 2, 0).reshape(bs, c, h, w) return out_dec, memory
@TRANSFORMER_LAYER_SEQUENCE.register_module() class DeformableDetrTransformerDecoder(TransformerLayerSequence): """Implements the decoder in DETR transformer. Args: return_intermediate (bool): Whether to return intermediate outputs. coder_norm_cfg (dict): Config of last normalization layer. Default: `LN`. """ def __init__(self, *args, return_intermediate=False, **kwargs): super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs) self.return_intermediate = return_intermediate def forward(self, query, *args, reference_points=None, valid_ratios=None, reg_branches=None, **kwargs): """Forward function for `TransformerDecoder`. Args: query (Tensor): Input query with shape `(num_query, bs, embed_dims)`. reference_points (Tensor): The reference points of offset. has shape (bs, num_query, 4) when as_two_stage, otherwise has shape ((bs, num_query, 2). valid_ratios (Tensor): The radios of valid points on the feature map, has shape (bs, num_levels, 2) reg_branch: (obj:`nn.ModuleList`): Used for refining the regression results. Only would be passed when with_box_refine is True, otherwise would be passed a `None`. Returns: Tensor: Results with shape [1, num_query, bs, embed_dims] when return_intermediate is `False`, otherwise it has shape [num_layers, num_query, bs, embed_dims]. """ output = query intermediate = [] intermediate_reference_points = [] for lid, layer in enumerate(self.layers): if reference_points.shape[-1] == 4: reference_points_input = reference_points[:, :, None] * \ torch.cat([valid_ratios, valid_ratios], -1)[:, None] else: assert reference_points.shape[-1] == 2 reference_points_input = reference_points[:, :, None] * \ valid_ratios[:, None] output = layer( output, *args, reference_points=reference_points_input, **kwargs) output = output.permute(1, 0, 2) if reg_branches is not None: tmp = reg_branches[lid](output) if reference_points.shape[-1] == 4: new_reference_points = tmp + inverse_sigmoid( reference_points) new_reference_points = new_reference_points.sigmoid() else: assert reference_points.shape[-1] == 2 new_reference_points = tmp new_reference_points[..., :2] = tmp[ ..., :2] + inverse_sigmoid(reference_points) new_reference_points = new_reference_points.sigmoid() reference_points = new_reference_points.detach() output = output.permute(1, 0, 2) if self.return_intermediate: intermediate.append(output) intermediate_reference_points.append(reference_points) if self.return_intermediate: return torch.stack(intermediate), torch.stack( intermediate_reference_points) return output, reference_points @TRANSFORMER.register_module() class DeformableDetrTransformer(Transformer): """Implements the DeformableDETR transformer. Args: as_two_stage (bool): Generate query from encoder features. Default: False. num_feature_levels (int): Number of feature maps from FPN: Default: 4. two_stage_num_proposals (int): Number of proposals when set `as_two_stage` as True. Default: 300. """ def __init__(self, as_two_stage=False, num_feature_levels=4, two_stage_num_proposals=300, **kwargs): super(DeformableDetrTransformer, self).__init__(**kwargs) self.as_two_stage = as_two_stage self.num_feature_levels = num_feature_levels self.two_stage_num_proposals = two_stage_num_proposals self.embed_dims = self.encoder.embed_dims self.init_layers() def init_layers(self): """Initialize layers of the DeformableDetrTransformer.""" self.level_embeds = nn.Parameter( torch.Tensor(self.num_feature_levels, self.embed_dims)) if self.as_two_stage: self.enc_output = nn.Linear(self.embed_dims, self.embed_dims) self.enc_output_norm = nn.LayerNorm(self.embed_dims) self.pos_trans = nn.Linear(self.embed_dims * 2, self.embed_dims * 2) self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2) else: self.reference_points = nn.Linear(self.embed_dims, 2) def init_weights(self): """Initialize the transformer weights.""" for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) for m in self.modules(): if isinstance(m, MultiScaleDeformableAttention): m.init_weights() if not self.as_two_stage: xavier_init(self.reference_points, distribution='uniform', bias=0.) normal_(self.level_embeds) def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): """Generate proposals from encoded memory. Args: memory (Tensor) : The output of encoder, has shape (bs, num_key, embed_dim). num_key is equal the number of points on feature map from all level. memory_padding_mask (Tensor): Padding mask for memory. has shape (bs, num_key). spatial_shapes (Tensor): The shape of all feature maps. has shape (num_level, 2). Returns: tuple: A tuple of feature map and bbox prediction. - output_memory (Tensor): The input of decoder, \ has shape (bs, num_key, embed_dim). num_key is \ equal the number of points on feature map from \ all levels. - output_proposals (Tensor): The normalized proposal \ after a inverse sigmoid, has shape \ (bs, num_keys, 4). """ N, S, C = memory.shape proposals = [] _cur = 0 for lvl, (H, W) in enumerate(spatial_shapes): mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H * W)].view( N, H, W, 1) valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) grid_y, grid_x = torch.meshgrid( torch.linspace( 0, H - 1, H, dtype=torch.float32, device=memory.device), torch.linspace( 0, W - 1, W, dtype=torch.float32, device=memory.device)) grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2) grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) proposal = torch.cat((grid, wh), -1).view(N, -1, 4) proposals.append(proposal) _cur += (H * W) output_proposals = torch.cat(proposals, 1) output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all( -1, keepdim=True) output_proposals = torch.log(output_proposals / (1 - output_proposals)) output_proposals = output_proposals.masked_fill( memory_padding_mask.unsqueeze(-1), float('inf')) output_proposals = output_proposals.masked_fill( ~output_proposals_valid, float('inf')) output_memory = memory output_memory = output_memory.masked_fill( memory_padding_mask.unsqueeze(-1), float(0)) output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) output_memory = self.enc_output_norm(self.enc_output(output_memory)) return output_memory, output_proposals @staticmethod def get_reference_points(spatial_shapes, valid_ratios, device): """Get the reference points used in decoder. Args: spatial_shapes (Tensor): The shape of all feature maps, has shape (num_level, 2). valid_ratios (Tensor): The radios of valid points on the feature map, has shape (bs, num_levels, 2) device (obj:`device`): The device where reference_points should be. Returns: Tensor: reference points used in decoder, has \ shape (bs, num_keys, num_levels, 2). """ reference_points_list = [] for lvl, (H, W) in enumerate(spatial_shapes): # TODO check this 0.5 ref_y, ref_x = torch.meshgrid( torch.linspace( 0.5, H - 0.5, H, dtype=torch.float32, device=device), torch.linspace( 0.5, W - 0.5, W, dtype=torch.float32, device=device)) ref_y = ref_y.reshape(-1)[None] / ( valid_ratios[:, None, lvl, 1] * H) ref_x = ref_x.reshape(-1)[None] / ( valid_ratios[:, None, lvl, 0] * W) ref = torch.stack((ref_x, ref_y), -1) reference_points_list.append(ref) reference_points = torch.cat(reference_points_list, 1) reference_points = reference_points[:, :, None] * valid_ratios[:, None] return reference_points def get_valid_ratio(self, mask): """Get the valid radios of feature maps of all level.""" _, H, W = mask.shape valid_H = torch.sum(~mask[:, :, 0], 1) valid_W = torch.sum(~mask[:, 0, :], 1) valid_ratio_h = valid_H.float() / H valid_ratio_w = valid_W.float() / W valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) return valid_ratio def get_proposal_pos_embed(self, proposals, num_pos_feats=128, temperature=10000): """Get the position embedding of proposal.""" scale = 2 * math.pi dim_t = torch.arange( num_pos_feats, dtype=torch.float32, device=proposals.device) dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) # N, L, 4 proposals = proposals.sigmoid() * scale # N, L, 4, 128 pos = proposals[:, :, :, None] / dim_t # N, L, 4, 64, 2 pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) return pos def forward(self, mlvl_feats, mlvl_masks, query_embed, mlvl_pos_embeds, reg_branches=None, cls_branches=None, **kwargs): """Forward function for `Transformer`. Args: mlvl_feats (list(Tensor)): Input queries from different level. Each element has shape [bs, embed_dims, h, w]. mlvl_masks (list(Tensor)): The key_padding_mask from different level used for encoder and decoder, each element has shape [bs, h, w]. query_embed (Tensor): The query embedding for decoder, with shape [num_query, c]. mlvl_pos_embeds (list(Tensor)): The positional encoding of feats from different level, has the shape [bs, embed_dims, h, w]. reg_branches (obj:`nn.ModuleList`): Regression heads for feature maps from each decoder layer. Only would be passed when `with_box_refine` is True. Default to None. cls_branches (obj:`nn.ModuleList`): Classification heads for feature maps from each decoder layer. Only would be passed when `as_two_stage` is True. Default to None. Returns: tuple[Tensor]: results of decoder containing the following tensor. - inter_states: Outputs from decoder. If return_intermediate_dec is True output has shape \ (num_dec_layers, bs, num_query, embed_dims), else has \ shape (1, bs, num_query, embed_dims). - init_reference_out: The initial value of reference \ points, has shape (bs, num_queries, 4). - inter_references_out: The internal value of reference \ points in decoder, has shape \ (num_dec_layers, bs,num_query, embed_dims) - enc_outputs_class: The classification score of \ proposals generated from \ encoder's feature maps, has shape \ (batch, h*w, num_classes). \ Only would be returned when `as_two_stage` is True, \ otherwise None. - enc_outputs_coord_unact: The regression results \ generated from encoder's feature maps., has shape \ (batch, h*w, 4). Only would \ be returned when `as_two_stage` is True, \ otherwise None. """ assert self.as_two_stage or query_embed is not None feat_flatten = [] mask_flatten = [] lvl_pos_embed_flatten = [] spatial_shapes = [] for lvl, (feat, mask, pos_embed) in enumerate( zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): bs, c, h, w = feat.shape spatial_shape = (h, w) spatial_shapes.append(spatial_shape) feat = feat.flatten(2).transpose(1, 2) mask = mask.flatten(1) pos_embed = pos_embed.flatten(2).transpose(1, 2) lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) lvl_pos_embed_flatten.append(lvl_pos_embed) feat_flatten.append(feat) mask_flatten.append(mask) feat_flatten = torch.cat(feat_flatten, 1) mask_flatten = torch.cat(mask_flatten, 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) spatial_shapes = torch.as_tensor( spatial_shapes, dtype=torch.long, device=feat_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros( (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack( [self.get_valid_ratio(m) for m in mlvl_masks], 1) reference_points = \ self.get_reference_points(spatial_shapes, valid_ratios, device=feat.device) feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute( 1, 0, 2) # (H*W, bs, embed_dims) memory = self.encoder( query=feat_flatten, key=None, value=None, query_pos=lvl_pos_embed_flatten, query_key_padding_mask=mask_flatten, spatial_shapes=spatial_shapes, reference_points=reference_points, level_start_index=level_start_index, valid_ratios=valid_ratios, **kwargs) memory = memory.permute(1, 0, 2) bs, _, c = memory.shape if self.as_two_stage: output_memory, output_proposals = \ self.gen_encoder_output_proposals( memory, mask_flatten, spatial_shapes) enc_outputs_class = cls_branches[self.decoder.num_layers]( output_memory) enc_outputs_coord_unact = \ reg_branches[ self.decoder.num_layers](output_memory) + output_proposals topk = self.two_stage_num_proposals topk_proposals = torch.topk( enc_outputs_class[..., 0], topk, dim=1)[1] topk_coords_unact = torch.gather( enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) topk_coords_unact = topk_coords_unact.detach() reference_points = topk_coords_unact.sigmoid() init_reference_out = reference_points pos_trans_out = self.pos_trans_norm( self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) query_pos, query = torch.split(pos_trans_out, c, dim=2) else: query_pos, query = torch.split(query_embed, c, dim=1) query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) query = query.unsqueeze(0).expand(bs, -1, -1) reference_points = self.reference_points(query_pos).sigmoid() init_reference_out = reference_points # decoder query = query.permute(1, 0, 2) memory = memory.permute(1, 0, 2) query_pos = query_pos.permute(1, 0, 2) inter_states, inter_references = self.decoder( query=query, key=None, value=memory, query_pos=query_pos, key_padding_mask=mask_flatten, reference_points=reference_points, spatial_shapes=spatial_shapes, level_start_index=level_start_index, valid_ratios=valid_ratios, reg_branches=reg_branches, **kwargs) inter_references_out = inter_references if self.as_two_stage: return inter_states, init_reference_out,\ inter_references_out, enc_outputs_class,\ enc_outputs_coord_unact return inter_states, init_reference_out, \ inter_references_out, None, None
[文档]@TRANSFORMER.register_module() class DynamicConv(BaseModule): """Implements Dynamic Convolution. This module generate parameters for each sample and use bmm to implement 1*1 convolution. Code is modified from the `official github repo <https://github.com/PeizeSun/ SparseR-CNN/blob/main/projects/SparseRCNN/sparsercnn/head.py#L258>`_ . Args: in_channels (int): The input feature channel. Defaults to 256. feat_channels (int): The inner feature channel. Defaults to 64. out_channels (int, optional): The output feature channel. When not specified, it will be set to `in_channels` by default input_feat_shape (int): The shape of input feature. Defaults to 7. act_cfg (dict): The activation config for DynamicConv. norm_cfg (dict): Config dict for normalization layer. Default layer normalization. init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. Default: None. """ def __init__(self, in_channels=256, feat_channels=64, out_channels=None, input_feat_shape=7, act_cfg=dict(type='ReLU', inplace=True), norm_cfg=dict(type='LN'), init_cfg=None): super(DynamicConv, self).__init__(init_cfg) self.in_channels = in_channels self.feat_channels = feat_channels self.out_channels_raw = out_channels self.input_feat_shape = input_feat_shape self.act_cfg = act_cfg self.norm_cfg = norm_cfg self.out_channels = out_channels if out_channels else in_channels self.num_params_in = self.in_channels * self.feat_channels self.num_params_out = self.out_channels * self.feat_channels self.dynamic_layer = nn.Linear( self.in_channels, self.num_params_in + self.num_params_out) self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1] self.activation = build_activation_layer(act_cfg) num_output = self.out_channels * input_feat_shape**2 self.fc_layer = nn.Linear(num_output, self.out_channels) self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
[文档] def forward(self, param_feature, input_feature): """Forward function for `DynamicConv`. Args: param_feature (Tensor): The feature can be used to generate the parameter, has shape (num_all_proposals, in_channels). input_feature (Tensor): Feature that interact with parameters, has shape (num_all_proposals, in_channels, H, W). Returns: Tensor: The output feature has shape (num_all_proposals, out_channels). """ num_proposals = param_feature.size(0) input_feature = input_feature.view(num_proposals, self.in_channels, -1).permute(2, 0, 1) input_feature = input_feature.permute(1, 0, 2) parameters = self.dynamic_layer(param_feature) param_in = parameters[:, :self.num_params_in].view( -1, self.in_channels, self.feat_channels) param_out = parameters[:, -self.num_params_out:].view( -1, self.feat_channels, self.out_channels) # input_feature has shape (num_all_proposals, H*W, in_channels) # param_in has shape (num_all_proposals, in_channels, feat_channels) # feature has shape (num_all_proposals, H*W, feat_channels) features = torch.bmm(input_feature, param_in) features = self.norm_in(features) features = self.activation(features) # param_out has shape (batch_size, feat_channels, out_channels) features = torch.bmm(features, param_out) features = self.norm_out(features) features = self.activation(features) features = features.flatten(1) features = self.fc_layer(features) features = self.fc_norm(features) features = self.activation(features) return features