Shortcuts

Source code for mmdet.core.post_processing.matrix_nms

# Copyright (c) OpenMMLab. All rights reserved.
import torch


def _empty_nms_result(masks, labels, scores):
    """Build the ``(scores, labels, masks, keep_inds)`` tuple for no masks."""
    return (scores.new_zeros(0), labels.new_zeros(0),
            masks.new_zeros(0, *masks.shape[-2:]),
            # Indices are always int64, matching what ``torch.sort`` returns
            # on the non-empty path.
            labels.new_zeros(0, dtype=torch.long))


def mask_matrix_nms(masks,
                    labels,
                    scores,
                    filter_thr=-1,
                    nms_pre=-1,
                    max_num=-1,
                    kernel='gaussian',
                    sigma=2.0,
                    mask_area=None):
    """Matrix NMS for multi-class masks.

    Instead of discarding overlapping masks, the score of every mask is
    decayed according to its IoU with higher-scored masks of the same label.

    Args:
        masks (Tensor): Has shape (num_instances, h, w)
        labels (Tensor): Labels of corresponding masks,
            has shape (num_instances,).
        scores (Tensor): Mask scores of corresponding masks,
            has shape (num_instances).
        filter_thr (float): Score threshold to filter the masks
            after matrix nms. Default: -1, which means do not
            use filter_thr.
        nms_pre (int): The max number of instances to do the matrix nms.
            Default: -1, which means do not use nms_pre.
        max_num (int, optional): If there are more than max_num masks after
            matrix, only top max_num will be kept. Default: -1, which means
            do not use max_num.
        kernel (str): 'linear' or 'gaussian'.
        sigma (float): Coefficient of the gaussian kernel; the decay is
            ``exp(-sigma * iou ** 2)``. Only used when ``kernel`` is
            'gaussian'.
        mask_area (Tensor): The sum of seg_masks, has shape
            (num_instances,). Computed from ``masks`` when None.

    Returns:
        tuple(Tensor): Processed mask results.

            - scores (Tensor): Updated scores, has shape (n,).
            - labels (Tensor): Remained labels, has shape (n,).
            - masks (Tensor): Remained masks, has shape (n, h, w).
            - keep_inds (Tensor): The indices number of
              the remaining mask in the input mask, has shape (n,).

    Raises:
        NotImplementedError: If ``kernel`` is neither 'gaussian' nor
            'linear'.
    """
    assert len(labels) == len(masks) == len(scores)
    if len(labels) == 0:
        return _empty_nms_result(masks, labels, scores)
    if mask_area is None:
        mask_area = masks.sum((1, 2)).float()
    else:
        assert len(masks) == len(mask_area)

    # sort and keep top nms_pre
    scores, sort_inds = torch.sort(scores, descending=True)
    if nms_pre > 0 and len(sort_inds) > nms_pre:
        sort_inds = sort_inds[:nms_pre]
        scores = scores[:nms_pre]
    # keep_inds tracks, for every surviving mask, its index in the input.
    keep_inds = sort_inds
    masks = masks[sort_inds]
    mask_area = mask_area[sort_inds]
    labels = labels[sort_inds]

    num_masks = len(labels)
    flatten_masks = masks.reshape(num_masks, -1).float()
    # Pairwise intersection: entry (i, j) is sum(mask_i * mask_j).
    inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0))
    expanded_mask_area = mask_area.expand(num_masks, num_masks)
    union_matrix = (
        expanded_mask_area + expanded_mask_area.transpose(1, 0) -
        inter_matrix)
    # Upper triangle iou matrix. An empty union (e.g. two empty masks) would
    # give 0 / 0 = NaN and poison every score below, so such pairs are
    # treated as non-overlapping.
    iou_matrix = (inter_matrix / union_matrix).masked_fill(
        union_matrix <= 0, 0).triu(diagonal=1)

    # Upper triangle label matrix: True where masks i < j share a label.
    expanded_labels = labels.expand(num_masks, num_masks)
    label_matrix = (expanded_labels == expanded_labels.transpose(
        1, 0)).triu(diagonal=1)

    # IoU decay: entry (i, j) is the IoU of mask j with the higher-scored
    # same-label mask i.
    decay_iou = iou_matrix * label_matrix
    # IoU compensation: for each mask, its largest IoU with any
    # higher-scored same-label mask, broadcast along rows.
    compensate_iou, _ = decay_iou.max(0)
    compensate_iou = compensate_iou.expand(num_masks,
                                           num_masks).transpose(1, 0)

    # Calculate the decay_coefficient
    if kernel == 'gaussian':
        decay_matrix = torch.exp(-1 * sigma * (decay_iou**2))
        compensate_matrix = torch.exp(-1 * sigma * (compensate_iou**2))
        decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
    elif kernel == 'linear':
        # A mask identical to a higher-scored one has compensate_iou == 1;
        # clamp the denominator so 0 / 0 cannot produce NaN scores.
        eps = torch.finfo(compensate_iou.dtype).eps
        decay_matrix = (1 - decay_iou) / (1 - compensate_iou).clamp(min=eps)
        decay_coefficient, _ = decay_matrix.min(0)
    else:
        raise NotImplementedError(
            f'{kernel} kernel is not supported in matrix nms!')

    # update the score.
    scores = scores * decay_coefficient

    if filter_thr > 0:
        keep = scores >= filter_thr
        if not keep.any():
            return _empty_nms_result(masks, labels, scores)
        keep_inds = keep_inds[keep]
        masks = masks[keep]
        scores = scores[keep]
        labels = labels[keep]

    # sort and keep top max_num
    scores, sort_inds = torch.sort(scores, descending=True)
    keep_inds = keep_inds[sort_inds]
    if max_num > 0 and len(sort_inds) > max_num:
        sort_inds = sort_inds[:max_num]
        keep_inds = keep_inds[:max_num]
        scores = scores[:max_num]
    masks = masks[sort_inds]
    labels = labels[sort_inds]

    return scores, labels, masks, keep_inds
Read the Docs v: v2.18.1
Versions
latest
stable
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.