Shortcuts

mmdet.core.mask.utils 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import pycocotools.mask as mask_util
import torch


[文档]def split_combined_polys(polys, poly_lens, polys_per_mask): """Split the combined 1-D polys into masks. A mask is represented as a list of polys, and a poly is represented as a 1-D array. In dataset, all masks are concatenated into a single 1-D tensor. Here we need to split the tensor into original representations. Args: polys (list): a list (length = image num) of 1-D tensors poly_lens (list): a list (length = image num) of poly length polys_per_mask (list): a list (length = image num) of poly number of each mask Returns: list: a list (length = image num) of list (length = mask num) of \ list (length = poly num) of numpy array. """ mask_polys_list = [] for img_id in range(len(polys)): polys_single = polys[img_id] polys_lens_single = poly_lens[img_id].tolist() polys_per_mask_single = polys_per_mask[img_id].tolist() split_polys = mmcv.slice_list(polys_single, polys_lens_single) mask_polys = mmcv.slice_list(split_polys, polys_per_mask_single) mask_polys_list.append(mask_polys) return mask_polys_list
# TODO: move this function to more proper place
[文档]def encode_mask_results(mask_results): """Encode bitmap mask to RLE code. Args: mask_results (list | tuple[list]): bitmap mask results. In mask scoring rcnn, mask_results is a tuple of (segm_results, segm_cls_score). Returns: list | tuple: RLE encoded mask. """ if isinstance(mask_results, tuple): # mask scoring cls_segms, cls_mask_scores = mask_results else: cls_segms = mask_results num_classes = len(cls_segms) encoded_mask_results = [[] for _ in range(num_classes)] for i in range(len(cls_segms)): for cls_segm in cls_segms[i]: encoded_mask_results[i].append( mask_util.encode( np.array( cls_segm[:, :, np.newaxis], order='F', dtype='uint8'))[0]) # encoded with RLE if isinstance(mask_results, tuple): return encoded_mask_results, cls_mask_scores else: return encoded_mask_results
[文档]def mask2bbox(masks): """Obtain tight bounding boxes of binary masks. Args: masks (Tensor): Binary mask of shape (n, h, w). Returns: Tensor: Bboxe with shape (n, 4) of \ positive region in binary mask. """ N = masks.shape[0] bboxes = masks.new_zeros((N, 4), dtype=torch.float32) x_any = torch.any(masks, dim=1) y_any = torch.any(masks, dim=2) for i in range(N): x = torch.where(x_any[i, :])[0] y = torch.where(y_any[i, :])[0] if len(x) > 0 and len(y) > 0: bboxes[i, :] = bboxes.new_tensor( [x[0], y[0], x[-1] + 1, y[-1] + 1]) return bboxes
Read the Docs v: v2.28.2
Versions
latest
stable
3.x
v2.28.2
v2.28.1
v2.28.0
v2.27.0
v2.26.0
v2.25.3
v2.25.2
v2.25.1
v2.25.0
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
dev-3.x
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.