Shortcuts

mmdet.models.utils.positional_encoding 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING
from mmcv.runner import BaseModule


@POSITIONAL_ENCODING.register_module()
class SinePositionalEncoding(BaseModule):
    """Position encoding with sine and cosine functions.

    See `End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        num_feats (int): The feature dimension for each position
            along x-axis or y-axis. Note the final returned dimension
            for each position is 2 times of this value.
        temperature (int, optional): The temperature used for scaling
            the position embedding. Defaults to 10000.
        normalize (bool, optional): Whether to normalize the position
            embedding. Defaults to False.
        scale (float, optional): A scale factor that scales the position
            embedding. The scale will be used only when `normalize` is True.
            Defaults to 2*pi.
        eps (float, optional): A value added to the denominator for
            numerical stability. Defaults to 1e-6.
        offset (float): Offset added to the embedding when normalization
            is performed. Defaults to 0.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 num_feats,
                 temperature=10000,
                 normalize=False,
                 scale=2 * math.pi,
                 eps=1e-6,
                 offset=0.,
                 init_cfg=None):
        super(SinePositionalEncoding, self).__init__(init_cfg)
        if normalize:
            # `scale` is only consumed on the normalize path, so it is only
            # type-checked there.
            assert isinstance(scale, (float, int)), 'when normalize is set,' \
                'scale should be provided and in float or int type, ' \
                f'found {type(scale)}'
        self.num_feats = num_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale
        self.eps = eps
        self.offset = offset

    def forward(self, mask):
        """Forward function for `SinePositionalEncoding`.

        Args:
            mask (Tensor): ByteTensor mask. Non-zero values representing
                ignored positions, while zero values means valid positions
                for this image. Shape [bs, h, w].

        Returns:
            pos (Tensor): Returned position embedding with shape
                [bs, num_feats*2, h, w].
        """
        # For convenience of exporting to ONNX, it's required to convert
        # `masks` from bool to int.
        mask = mask.to(torch.int)
        not_mask = 1 - mask  # logical_not
        # Running count of valid positions along each spatial axis.
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            # Divide by the last cumulative value along the axis (plus eps),
            # then multiply by `scale`.
            y_embed = (y_embed + self.offset) / \
                      (y_embed[:, -1:, :] + self.eps) * self.scale
            x_embed = (x_embed + self.offset) / \
                      (x_embed[:, :, -1:] + self.eps) * self.scale
        dim_t = torch.arange(
            self.num_feats, dtype=torch.float32, device=mask.device)
        # `dim_t // 2` makes each consecutive pair of channels share one
        # divisor; the pair is later split into a sin and a cos channel.
        dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # use `view` instead of `flatten` for dynamically exporting to ONNX
        B, H, W = mask.size()
        # sin on even channels, cos on odd channels; stacking on a new last
        # axis and flattening it interleaves them back into one axis.
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
            dim=4).view(B, H, W, -1)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
            dim=4).view(B, H, W, -1)
        # y-encoding first, then x-encoding, moved to the channel axis.
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(num_feats={self.num_feats}, '
        repr_str += f'temperature={self.temperature}, '
        repr_str += f'normalize={self.normalize}, '
        repr_str += f'scale={self.scale}, '
        repr_str += f'eps={self.eps}, '
        # `offset` changes the output of `forward`, so it belongs in the
        # description alongside the other hyper-parameters.
        repr_str += f'offset={self.offset})'
        return repr_str
@POSITIONAL_ENCODING.register_module()
class LearnedPositionalEncoding(BaseModule):
    """Position embedding backed by learnable row/column lookup tables.

    Args:
        num_feats (int): The feature dimension for each position
            along x-axis or y-axis. The final returned dimension for
            each position is 2 times of this value.
        row_num_embed (int, optional): The dictionary size of row embeddings.
            Default 50.
        col_num_embed (int, optional): The dictionary size of col embeddings.
            Default 50.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 num_feats,
                 row_num_embed=50,
                 col_num_embed=50,
                 init_cfg=dict(type='Uniform', layer='Embedding')):
        super().__init__(init_cfg)
        self.num_feats = num_feats
        self.row_num_embed = row_num_embed
        self.col_num_embed = col_num_embed
        # One table per spatial axis; row table is created first, then the
        # column table.
        self.row_embed = nn.Embedding(row_num_embed, num_feats)
        self.col_embed = nn.Embedding(col_num_embed, num_feats)

    def forward(self, mask):
        """Forward function for `LearnedPositionalEncoding`.

        Args:
            mask (Tensor): ByteTensor mask. Non-zero values representing
                ignored positions, while zero values means valid positions
                for this image. Shape [bs, h, w].

        Returns:
            pos (Tensor): Returned position embedding with shape
                [bs, num_feats*2, h, w].
        """
        batch_size = mask.shape[0]
        height, width = mask.shape[-2:]
        device = mask.device

        # Look up one embedding per column index and one per row index.
        col_codes = self.col_embed(torch.arange(width, device=device))
        row_codes = self.row_embed(torch.arange(height, device=device))

        # Tile each set of codes over the other axis to form (h, w, C) grids.
        col_grid = col_codes.unsqueeze(0).repeat(height, 1, 1)
        row_grid = row_codes.unsqueeze(1).repeat(1, width, 1)

        # Column codes come first on the channel axis, then row codes.
        grid = torch.cat((col_grid, row_grid), dim=-1).permute(2, 0, 1)

        # The same grid is replicated for every sample in the batch.
        return grid.unsqueeze(0).repeat(batch_size, 1, 1, 1)

    def __repr__(self):
        """str: a string that describes the module"""
        return (f'{self.__class__.__name__}('
                f'num_feats={self.num_feats}, '
                f'row_num_embed={self.row_num_embed}, '
                f'col_num_embed={self.col_num_embed})')
Read the Docs v: dev
Versions
latest
stable
3.x
v2.28.2
v2.28.1
v2.28.0
v2.27.0
v2.26.0
v2.25.3
v2.25.2
v2.25.1
v2.25.0
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
dev-3.x
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.