Source code for mmdet.core.utils.misc

from functools import partial

import numpy as np
import torch
from six.moves import map, zip

from ..mask.structures import BitmapMasks, PolygonMasks

def multi_apply(func, *args, **kwargs):
    """Call ``func`` over parallel argument lists and regroup its outputs.

    ``func`` is invoked once per position, receiving one element from each
    iterable in ``args``; ``kwargs`` are forwarded unchanged to every call.
    The per-call results are then transposed, so the i-th returned list
    gathers the i-th output of every call.

    Note:
        Each call to ``func`` must return an iterable (typically a tuple),
        because the results are unpacked into ``zip``. When the input
        iterables are empty there is nothing to transpose and an empty tuple
        is returned, regardless of how many outputs ``func`` would produce.

    Args:
        func (Function): Callable applied to each group of arguments.
        *args: Iterables that are walked in lockstep; each supplies one
            positional argument per call.
        **kwargs: Keyword arguments bound to ``func`` for every call.

    Returns:
        tuple(list): One list per output of ``func``, each holding that
            output for every input position.
    """
    if kwargs:
        # Bind the shared keyword arguments once instead of on every call.
        func = partial(func, **kwargs)
    per_call_outputs = map(func, *args)
    # zip(*...) transposes "one tuple per call" into "one group per output".
    return tuple(list(column) for column in zip(*per_call_outputs))
def unmap(data, count, inds, fill=0):
    """Scatter a subset of items back into a tensor sized for the full set.

    A new tensor with ``count`` entries along the first dimension is created
    from ``data`` (via ``new_full``) and pre-filled with ``fill``; the rows
    selected by ``inds`` are then overwritten with ``data``.

    Args:
        data (Tensor): Values of the selected subset. Trailing dimensions
            (everything after the first) are carried over to the result.
        count (int): Size of the first dimension of the returned tensor.
        inds (Tensor): Selector over the ``count`` original items. It is cast
            to ``torch.bool`` and used as a mask, so non-zero entries mark
            the positions that receive ``data``.
        fill (Number): Value placed at positions not selected by ``inds``.
            Defaults to 0.

    Returns:
        Tensor: Tensor of shape ``(count, *data.shape[1:])``.
    """
    is_flat = data.dim() == 1
    # For 1-D data the trailing shape is empty, leaving just (count, ).
    full_shape = (count, ) + tuple(data.size()[1:])
    ret = data.new_full(full_shape, fill)
    selected = inds.type(torch.bool)
    if is_flat:
        ret[selected] = data
    else:
        ret[selected, :] = data
    return ret
def mask2ndarray(mask):
    """Convert a mask object to a NumPy array.

    Args:
        mask (:obj:`BitmapMasks` or :obj:`PolygonMasks` or torch.Tensor or
            np.ndarray): The mask to be converted.

    Returns:
        np.ndarray: The mask as an ndarray. Mask structures are converted
            with their ``to_ndarray()`` method, tensors are detached and
            moved to CPU first, and ndarrays are returned as-is. The result
            is expected to have shape (n, h, w); this function does not
            check or enforce that shape itself.

    Raises:
        TypeError: If ``mask`` is none of the supported types.
    """
    if isinstance(mask, (BitmapMasks, PolygonMasks)):
        return mask.to_ndarray()
    if isinstance(mask, torch.Tensor):
        # detach() drops autograd tracking; cpu() makes numpy() legal for
        # tensors living on another device.
        return mask.detach().cpu().numpy()
    if isinstance(mask, np.ndarray):
        return mask
    raise TypeError(f'Unsupported {type(mask)} data type')
def flip_tensor(src_tensor, flip_direction):
    """Flip a 4-D tensor along its last two dimensions.

    Args:
        src_tensor (Tensor): Input feature map, shape (B, C, H, W).
        flip_direction (str): The flipping direction. Options are
            'horizontal' (reverse dim 3), 'vertical' (reverse dim 2) and
            'diagonal' (reverse both dims 2 and 3).

    Returns:
        out_tensor (Tensor): Flipped tensor.

    Raises:
        AssertionError: If ``src_tensor`` is not 4-dimensional or
            ``flip_direction`` is not one of the listed options. These are
            ``assert`` checks, so they are skipped when Python runs with
            ``-O``; an unrecognised direction is then treated as 'diagonal'.
    """
    assert src_tensor.ndim == 4
    assert flip_direction in ('horizontal', 'vertical', 'diagonal')
    if flip_direction == 'horizontal':
        flip_dims = [3]
    elif flip_direction == 'vertical':
        flip_dims = [2]
    else:
        flip_dims = [2, 3]
    return torch.flip(src_tensor, flip_dims)