Shortcuts

Source code for mmdet.core.utils.misc

# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial

import numpy as np
import torch
from six.moves import map, zip

from ..mask.structures import BitmapMasks, PolygonMasks


def multi_apply(func, *args, **kwargs):
    """Map ``func`` over several argument lists and regroup its outputs.

    ``func`` is called once per position of the iterables in ``args`` (the
    same way ``map(func, a, b, ...)`` would call it), with ``kwargs`` bound
    to every call. The per-call results are then transposed, so that the
    i-th returned list gathers the i-th output of every call.

    Args:
        func (Function): Callable applied to each set of arguments. Each
            call must return an iterable of outputs.
        *args: Iterables that are walked in lockstep; one element of each
            is passed positionally to ``func`` per call.
        **kwargs: Keyword arguments forwarded unchanged to every call.

    Returns:
        tuple(list): One list per output of ``func``; each list holds that
            output for every call, in call order. Empty when ``func`` was
            never called.
    """
    # Only wrap the callable when there is something to bind.
    bound_func = func
    if kwargs:
        bound_func = partial(func, **kwargs)

    per_call_outputs = map(bound_func, *args)
    # zip(*...) turns "one tuple per call" into "one group per output".
    return tuple(list(group) for group in zip(*per_call_outputs))
def unmap(data, count, inds, fill=0):
    """Scatter a subset of items back into a tensor of ``count`` items.

    A new tensor with ``count`` entries along the first dimension is created
    from ``data`` (via ``new_full``) and pre-filled with ``fill``; the rows
    selected by ``inds`` (converted to a boolean mask) are then overwritten
    with ``data``.

    Args:
        data (Tensor): The subset of items to place back.
        count (int): Size of the first dimension of the result.
        inds (Tensor): Selector over the ``count`` positions; it is cast to
            ``torch.bool`` and used as a mask on the first dimension.
        fill (int | float): Value for the positions not selected.
            Default: 0.

    Returns:
        Tensor: Tensor whose first dimension is ``count`` and whose trailing
            dimensions are those of ``data``.
    """
    if data.dim() == 1:
        # Flat case: the result is a plain vector of length ``count``.
        restored = data.new_full((count, ), fill)
        restored[inds.type(torch.bool)] = data
        return restored

    # General case: keep every dimension of ``data`` after the first.
    full_shape = (count, ) + data.size()[1:]
    restored = data.new_full(full_shape, fill)
    restored[inds.type(torch.bool), :] = data
    return restored
def mask2ndarray(mask):
    """Convert a mask object to a numpy array.

    Args:
        mask (:obj:`BitmapMasks` or :obj:`PolygonMasks` or torch.Tensor or
            np.ndarray): The mask to be converted.

    Returns:
        np.ndarray: The converted mask. Mask structures are converted with
            their ``to_ndarray()`` method, tensors are detached and moved to
            CPU first, and an ``np.ndarray`` input is returned unchanged.

    Raises:
        TypeError: If ``mask`` is none of the supported types.
    """
    if isinstance(mask, (BitmapMasks, PolygonMasks)):
        return mask.to_ndarray()
    if isinstance(mask, torch.Tensor):
        # detach + cpu so tensors on any device / in any graph can be exported
        return mask.detach().cpu().numpy()
    if isinstance(mask, np.ndarray):
        return mask
    raise TypeError(f'Unsupported {type(mask)} data type')
def flip_tensor(src_tensor, flip_direction):
    """Flip a 4-D tensor along the axes selected by ``flip_direction``.

    Args:
        src_tensor (Tensor): Input feature map, shape (B, C, H, W).
        flip_direction (str): The flipping direction. Options are
            'horizontal' (flips dim 3), 'vertical' (flips dim 2) and
            'diagonal' (flips dims 2 and 3).

    Returns:
        out_tensor (Tensor): Flipped tensor.

    Raises:
        AssertionError: If ``src_tensor`` is not 4-dimensional or
            ``flip_direction`` is not one of the supported options.
    """
    # Direction name -> dimensions handed to ``torch.flip``.
    dims_by_direction = {
        'horizontal': [3],
        'vertical': [2],
        'diagonal': [2, 3],
    }
    assert src_tensor.ndim == 4
    assert flip_direction in list(dims_by_direction)
    return torch.flip(src_tensor, dims_by_direction[flip_direction])
def center_of_mass(mask, esp=1e-6):
    """Compute the centroid of a 2-D mask.

    Each coordinate is the mask-weighted sum of the row (or column) indices
    divided by the total mask mass.

    Args:
        mask (Tensor): The mask to be calculated, shape (h, w).
        esp (float): Lower bound applied to the total mass to avoid dividing
            by zero. Default: 1e-6.

    Returns:
        tuple[Tensor]: the coordinates of the center point of the mask.

            - center_h (Tensor): the center point of the height.
            - center_w (Tensor): the center point of the width.
    """
    num_rows, num_cols = mask.shape
    # Column vector of row indices / row vector of column indices; both
    # broadcast against the (h, w) mask.
    row_index = torch.arange(num_rows, device=mask.device)[:, None]
    col_index = torch.arange(num_cols, device=mask.device)

    # Clamp the mass so an all-zero mask does not divide by zero.
    total_mass = mask.sum().float().clamp(min=esp)

    weighted_rows = (mask * row_index).sum()
    weighted_cols = (mask * col_index).sum()
    return weighted_rows / total_mass, weighted_cols / total_mass
def generate_coordinate(featmap_sizes, device='cuda'):
    """Generate a normalized coordinate feature map.

    Two evenly spaced ranges over [-1, 1] are built, one with
    ``featmap_sizes[-1]`` steps (x) and one with ``featmap_sizes[-2]`` steps
    (y), turned into a grid and repeated for every sample in the batch.

    Args:
        featmap_sizes (tuple): Size of the feature map; index 0 is used as
            the batch size and the last two entries as the spatial size.
        device (str): The device where the feature will be put on.

    Returns:
        coord_feat (Tensor): The coordinate feature, of shape
            (featmap_sizes[0], 2, featmap_sizes[-2], featmap_sizes[-1]).
            Channel 0 holds x (varies along the last dimension) and
            channel 1 holds y (varies along the second-to-last dimension).
    """
    batch_size = featmap_sizes[0]
    xs = torch.linspace(-1, 1, featmap_sizes[-1], device=device)
    ys = torch.linspace(-1, 1, featmap_sizes[-2], device=device)

    grid_y, grid_x = torch.meshgrid(ys, xs)

    # Add batch and channel dimensions without copying the grid data.
    target_shape = [batch_size, 1, -1, -1]
    grid_y = grid_y.expand(target_shape)
    grid_x = grid_x.expand(target_shape)

    return torch.cat([grid_x, grid_y], 1)
Read the Docs v: v2.17.0
Versions
latest
stable
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.