Shortcuts

mmdet.core.post_processing.merge_augs 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings

import numpy as np
import torch
from mmcv import ConfigDict
from mmcv.ops import nms

from ..bbox import bbox_mapping_back


def merge_aug_proposals(aug_proposals, img_metas, cfg):
    """Merge augmented proposals (multiscale, flip, etc.)

    Each set of proposals is mapped back with ``bbox_mapping_back`` using the
    meta information of its augmentation, all sets are concatenated, NMS is
    applied, and the highest scoring ``cfg.max_per_img`` results are kept.

    Args:
        aug_proposals (list[Tensor]): proposals from different testing
            schemes, shape (n, 5). The first four columns are passed to NMS
            as boxes and the last column as scores. Note that they are not
            rescaled to the original image size.
        img_metas (list[dict]): list of image info dict where each dict has:
            'img_shape', 'scale_factor', 'flip', 'flip_direction', and may
            also contain 'filename', 'ori_shape', 'pad_shape', and
            'img_norm_cfg'. For details on the values of these keys see
            `mmdet/datasets/pipelines/formatting.py:Collect`.
        cfg (dict): rpn test config. It is accessed by attribute
            (``cfg.nms``, ``cfg.max_per_img``), so it must support attribute
            access in addition to ``in`` checks. The input is deep-copied and
            never modified.

    Returns:
        Tensor: shape (n, 5), proposals corresponding to original image
        scale, with the score in the last column, sorted by descending score.
    """
    # Work on a copy: the deprecated-key handling below writes into cfg.
    cfg = copy.deepcopy(cfg)

    # deprecate arguments warning
    if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
        warnings.warn(
            'In rpn_proposal or test_cfg, '
            'nms_thr has been moved to a dict named nms as '
            'iou_threshold, max_num has been renamed as max_per_img, '
            'name of original arguments and the way to specify '
            'iou_threshold of NMS will be deprecated.')
    if 'nms' not in cfg:
        # Legacy config: build the new-style nms dict from nms_thr.
        cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
    if 'max_num' in cfg:
        if 'max_per_img' in cfg:
            assert cfg.max_num == cfg.max_per_img, f'You set max_num and ' \
                f'max_per_img at the same time, but get {cfg.max_num} ' \
                f'and {cfg.max_per_img} respectively. ' \
                f'Please delete max_num which will be deprecated.'
        else:
            cfg.max_per_img = cfg.max_num
    if 'nms_thr' in cfg:
        assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set ' \
            f'iou_threshold in nms and ' \
            f'nms_thr at the same time, but get ' \
            f'{cfg.nms.iou_threshold} and {cfg.nms_thr}' \
            f' respectively. Please delete the nms_thr ' \
            f'which will be deprecated.'

    recovered_proposals = []
    for proposals, img_info in zip(aug_proposals, img_metas):
        img_shape = img_info['img_shape']
        scale_factor = img_info['scale_factor']
        flip = img_info['flip']
        flip_direction = img_info['flip_direction']
        # Clone so the caller's tensor is not overwritten in place.
        _proposals = proposals.clone()
        _proposals[:, :4] = bbox_mapping_back(_proposals[:, :4], img_shape,
                                              scale_factor, flip,
                                              flip_direction)
        recovered_proposals.append(_proposals)
    aug_proposals = torch.cat(recovered_proposals, dim=0)
    merged_proposals, _ = nms(aug_proposals[:, :4].contiguous(),
                              aug_proposals[:, -1].contiguous(),
                              cfg.nms.iou_threshold)
    # Column 4 of the NMS output holds the scores; keep the top-scoring ones.
    scores = merged_proposals[:, 4]
    _, order = scores.sort(0, descending=True)
    num = min(cfg.max_per_img, merged_proposals.shape[0])
    order = order[:num]
    merged_proposals = merged_proposals[order, :]
    return merged_proposals
def merge_aug_bboxes(aug_bboxes, aug_scores, img_metas, rcnn_test_cfg):
    """Merge augmented detection bboxes and scores.

    Each set of boxes is passed through ``bbox_mapping_back`` with the meta
    information of its augmentation; the mapped sets are then stacked and
    averaged. Scores, when given, are stacked and averaged as they are.

    Args:
        aug_bboxes (list[Tensor]): shape (n, 4*#class)
        aug_scores (list[Tensor] or None): shape (n, #class)
        img_metas (list): one entry per augmentation. Only element ``[0]`` of
            each entry is read, for the keys 'img_shape', 'scale_factor',
            'flip' and 'flip_direction'.
        rcnn_test_cfg (dict): rcnn test config. Not read by this function.

    Returns:
        Tensor or tuple: the averaged bboxes alone when ``aug_scores`` is
        None, otherwise ``(bboxes, scores)``.
    """
    restored = [
        bbox_mapping_back(boxes, meta[0]['img_shape'],
                          meta[0]['scale_factor'], meta[0]['flip'],
                          meta[0]['flip_direction'])
        for boxes, meta in zip(aug_bboxes, img_metas)
    ]
    merged_bboxes = torch.stack(restored).mean(dim=0)
    if aug_scores is None:
        return merged_bboxes
    merged_scores = torch.stack(aug_scores).mean(dim=0)
    return merged_bboxes, merged_scores
def merge_aug_scores(aug_scores):
    """Merge augmented bbox scores by averaging them element-wise.

    Args:
        aug_scores (list): scores from the different augmentations. The type
            of the first element selects the backend: torch when it is a
            ``torch.Tensor``, numpy otherwise.

    Returns:
        Tensor or ndarray: mean of the scores over the augmentation axis.
    """
    if not isinstance(aug_scores[0], torch.Tensor):
        return np.mean(aug_scores, axis=0)
    stacked = torch.stack(aug_scores)
    return stacked.mean(dim=0)
def merge_aug_masks(aug_masks, img_metas, rcnn_test_cfg, weights=None):
    """Merge augmented mask prediction.

    Masks predicted on flipped inputs are flipped back along the matching
    axes, then all masks are averaged over the augmentation axis.

    Args:
        aug_masks (list[ndarray]): shape (n, #class, h, w)
        img_metas (list): one entry per augmentation. Only element ``[0]`` of
            each entry is read, for the key 'flip' and, when it is truthy,
            'flip_direction' ('horizontal', 'vertical' or 'diagonal').
        rcnn_test_cfg (dict): rcnn test config. Not read by this function.
        weights (list or None): per-augmentation weights. A plain mean is
            used when None, otherwise a weighted average.

    Returns:
        ndarray: the merged masks.

    Raises:
        ValueError: if a flipped entry has an unknown 'flip_direction'.
    """
    restored = []
    for aug_mask, meta in zip(aug_masks, img_metas):
        info = meta[0]
        if info['flip']:
            flip_direction = info['flip_direction']
            if flip_direction == 'horizontal':
                aug_mask = aug_mask[:, :, :, ::-1]
            elif flip_direction == 'vertical':
                aug_mask = aug_mask[:, :, ::-1, :]
            elif flip_direction == 'diagonal':
                # Reverse both spatial axes at once.
                aug_mask = aug_mask[:, :, ::-1, ::-1]
            else:
                raise ValueError(
                    f"Invalid flipping direction '{flip_direction}'")
        restored.append(aug_mask)

    if weights is not None:
        return np.average(
            np.array(restored), axis=0, weights=np.array(weights))
    return np.mean(restored, axis=0)
Read the Docs v: latest
Versions
latest
stable
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.