Source code for mmdet.core.post_processing.bbox_nms

import torch
from mmcv.ops.nms import batched_nms


def multiclass_nms(multi_bboxes,
                   multi_scores,
                   score_thr,
                   nms_cfg,
                   max_num=-1,
                   score_factors=None):
    """NMS for multi-class bboxes.

    Boxes whose per-class score is not strictly greater than ``score_thr``
    are discarded, the survivors are passed to ``batched_nms`` together with
    their class labels, and the result is optionally truncated to
    ``max_num`` entries.

    Args:
        multi_bboxes (Tensor): shape (n, #class*4) or (n, 4). In the
            (n, 4) case the same box is shared by every foreground class.
        multi_scores (Tensor): shape (n, #class), where the last column
            contains scores of the background class, but this will be
            ignored.
        score_thr (float): bbox threshold, only bboxes with a score
            strictly greater than it are considered. The comparison uses
            the raw scores, i.e. before ``score_factors`` is applied.
        nms_cfg: NMS configuration, forwarded unchanged to ``batched_nms``
            (in mmcv this is a dict holding e.g. the IoU threshold).
        max_num (int): if there are more than max_num bboxes after NMS,
            only the first max_num entries returned by ``batched_nms`` are
            kept. Values <= 0 disable the limit.
        score_factors (Tensor): The factors multiplied to scores before
            applying NMS; indexed per box (broadcast over classes).

    Returns:
        tuple: (bboxes, labels). ``bboxes`` is the detections tensor
            returned by ``batched_nms`` (shape (0, 5) when nothing passes
            the threshold); ``labels`` is a 1-D tensor of shape (k, ) with
            0-based class indices.

    Raises:
        RuntimeError: if no box passes ``score_thr`` while an ONNX export
            is in progress, because NMS would then not be traced.
    """
    num_boxes = multi_scores.size(0)
    num_classes = multi_scores.size(1) - 1  # exclude background category

    if multi_bboxes.shape[1] > 4:
        # class-specific boxes: (n, #class*4) -> (n, #class, 4)
        bboxes = multi_bboxes.view(num_boxes, -1, 4)
    else:
        # class-agnostic boxes: share each box across all foreground classes
        bboxes = multi_bboxes[:, None].expand(num_boxes, num_classes, 4)
    scores = multi_scores[:, :-1]

    # filter out boxes with low scores
    valid_mask = scores > score_thr

    # We use masked_select for ONNX exporting purpose,
    # which is equivalent to bboxes = bboxes[valid_mask]
    # (TODO): as ONNX does not support repeat now,
    # we have to use this ugly code
    bboxes = torch.masked_select(
        bboxes,
        torch.stack((valid_mask, valid_mask, valid_mask, valid_mask),
                    -1)).view(-1, 4)
    if score_factors is not None:
        # NOTE: applied after the mask was computed, so the threshold above
        # is evaluated on the unscaled scores.
        scores = scores * score_factors[:, None]
    scores = torch.masked_select(scores, valid_mask)
    # Column 1 of the (row, col) index pairs is the class index.
    # ``as_tuple=False`` keeps the 2-D index layout explicitly and avoids the
    # deprecation warning of the argument-less ``nonzero()`` call.
    labels = valid_mask.nonzero(as_tuple=False)[:, 1]

    if bboxes.numel() == 0:
        # Nothing survived the threshold: NMS is skipped entirely, which
        # cannot be recorded when tracing for ONNX.
        if torch.onnx.is_in_onnx_export():
            raise RuntimeError('[ONNX Error] Can not record NMS '
                               'as it has not been executed this time')
        bboxes = multi_bboxes.new_zeros((0, 5))
        labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)
        return bboxes, labels

    dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)

    if max_num > 0:
        dets = dets[:max_num]
        keep = keep[:max_num]

    return dets, labels[keep]