Shortcuts

mmdet.core.bbox.samplers.iou_balanced_neg_sampler 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch

from ..builder import BBOX_SAMPLERS
from .random_sampler import RandomSampler


@BBOX_SAMPLERS.register_module()
class IoUBalancedNegSampler(RandomSampler):
    """IoU Balanced Sampling.

    arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)

    Sampling proposals according to their IoU. `floor_fraction` of needed RoIs
    are sampled from proposals whose IoU are lower than `floor_thr` randomly.
    The others are sampled from proposals whose IoU are higher than
    `floor_thr`. These proposals are sampled from some bins evenly, which are
    split by `num_bins` via IoU evenly.

    Args:
        num (int): number of proposals.
        pos_fraction (float): fraction of positive proposals.
        floor_thr (float): threshold (minimum) IoU for IoU balanced sampling,
            set to -1 if all using IoU balanced sampling.
        floor_fraction (float): sampling fraction of proposals under floor_thr.
        num_bins (int): number of bins in IoU balanced sampling.
        **kwargs: forwarded unchanged to ``RandomSampler.__init__``.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 floor_thr=-1,
                 floor_fraction=0,
                 num_bins=3,
                 **kwargs):
        super(IoUBalancedNegSampler, self).__init__(num, pos_fraction,
                                                    **kwargs)
        assert floor_thr >= 0 or floor_thr == -1
        assert 0 <= floor_fraction <= 1
        assert num_bins >= 1

        self.floor_thr = floor_thr
        self.floor_fraction = floor_fraction
        self.num_bins = num_bins

    def sample_via_interval(self,
                            max_overlaps,
                            full_set,
                            num_expected,
                            floor_thr=None):
        """Sample according to the iou interval.

        The IoU range ``[floor_thr, max_overlaps.max())`` is split into
        ``self.num_bins`` equally wide bins and up to
        ``int(num_expected / self.num_bins)`` indices are drawn from each bin.
        If the bins together yield fewer than ``num_expected`` indices, the
        shortfall is filled randomly from the not-yet-selected members of
        ``full_set``.

        Args:
            max_overlaps (np.ndarray): IoU between bounding boxes and ground
                truth boxes (indexed with ``np.where``, so a 1-D array is
                expected).
            full_set (set(int)): A full set of indices of boxes.
            num_expected (int): Number of expected samples.
            floor_thr (float, optional): Lower IoU bound of the first bin.
                Defaults to ``self.floor_thr`` when ``None``.

        Returns:
            np.ndarray: Indices of samples
        """
        if floor_thr is None:
            floor_thr = self.floor_thr

        max_iou = max_overlaps.max()
        iou_interval = (max_iou - floor_thr) / self.num_bins
        per_num_expected = int(num_expected / self.num_bins)

        sampled_inds = []
        for i in range(self.num_bins):
            start_iou = floor_thr + i * iou_interval
            end_iou = floor_thr + (i + 1) * iou_interval
            # The upper bound is exclusive, so a box sitting exactly on the
            # last bin's upper edge is only reachable through the extra fill
            # at the end of this method.
            tmp_set = set(
                np.where(
                    np.logical_and(max_overlaps >= start_iou,
                                   max_overlaps < end_iou))[0])
            tmp_inds = list(tmp_set & full_set)
            if len(tmp_inds) > per_num_expected:
                tmp_sampled_set = self.random_choice(tmp_inds,
                                                     per_num_expected)
            else:
                # `np.int` was removed in NumPy 1.24; use a concrete dtype.
                tmp_sampled_set = np.array(tmp_inds, dtype=np.int64)
            sampled_inds.append(tmp_sampled_set)

        sampled_inds = np.concatenate(sampled_inds)
        if len(sampled_inds) < num_expected:
            num_extra = num_expected - len(sampled_inds)
            # Explicit integer dtype keeps an empty remainder from turning
            # the concatenated result into a float array.
            extra_inds = np.array(
                list(full_set - set(sampled_inds)), dtype=np.int64)
            if len(extra_inds) > num_extra:
                extra_inds = self.random_choice(extra_inds, num_extra)
            sampled_inds = np.concatenate([sampled_inds, extra_inds])

        return sampled_inds

    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Sample negative boxes.

        Negatives are the entries where ``assign_result.gt_inds == 0``. When
        there are more of them than ``num_expected``, they are split into a
        "floor" group (IoU below the threshold) and an "IoU sampling" group
        (IoU at or above it). The IoU sampling group is sampled first, per
        bin when ``self.num_bins >= 2`` and uniformly at random otherwise;
        the floor group then fills the remainder, and any leftover shortfall
        is filled randomly from the remaining negatives.

        Args:
            assign_result (:obj:`AssignResult`): The assigned results of boxes.
            num_expected (int): The number of expected negative samples

        Returns:
            Tensor: sampled indices.
        """
        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        if neg_inds.numel() != 0:
            neg_inds = neg_inds.squeeze(1)
        if len(neg_inds) <= num_expected:
            return neg_inds

        max_overlaps = assign_result.max_overlaps.cpu().numpy()
        # balance sampling for negative samples
        neg_set = set(neg_inds.cpu().numpy())

        # Work on a local copy of the threshold. The configured value must
        # not be overwritten: doing so would make the second and later calls
        # take a different branch than the first one when floor_thr == -1.
        floor_thr = self.floor_thr
        if floor_thr > 0:
            floor_set = set(
                np.where(
                    np.logical_and(max_overlaps >= 0,
                                   max_overlaps < floor_thr))[0])
            iou_sampling_set = set(np.where(max_overlaps >= floor_thr)[0])
        elif floor_thr == 0:
            floor_set = set(np.where(max_overlaps == 0)[0])
            iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
        else:
            floor_set = set()
            iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
            # for sampling interval calculation
            floor_thr = 0

        floor_neg_inds = list(floor_set & neg_set)
        iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
        num_expected_iou_sampling = int(num_expected *
                                        (1 - self.floor_fraction))
        if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
            if self.num_bins >= 2:
                iou_sampled_inds = self.sample_via_interval(
                    max_overlaps,
                    set(iou_sampling_neg_inds),
                    num_expected_iou_sampling,
                    floor_thr=floor_thr)
            else:
                iou_sampled_inds = self.random_choice(
                    iou_sampling_neg_inds, num_expected_iou_sampling)
        else:
            iou_sampled_inds = np.array(
                iou_sampling_neg_inds, dtype=np.int64)

        # Whatever the IoU-sampling group did not provide is requested from
        # the floor group.
        num_expected_floor = num_expected - len(iou_sampled_inds)
        if len(floor_neg_inds) > num_expected_floor:
            sampled_floor_inds = self.random_choice(floor_neg_inds,
                                                    num_expected_floor)
        else:
            sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int64)

        sampled_inds = np.concatenate((sampled_floor_inds, iou_sampled_inds))
        if len(sampled_inds) < num_expected:
            num_extra = num_expected - len(sampled_inds)
            extra_inds = np.array(
                list(neg_set - set(sampled_inds)), dtype=np.int64)
            if len(extra_inds) > num_extra:
                extra_inds = self.random_choice(extra_inds, num_extra)
            sampled_inds = np.concatenate((sampled_inds, extra_inds))

        sampled_inds = torch.from_numpy(sampled_inds).long().to(
            assign_result.gt_inds.device)
        return sampled_inds
Read the Docs v: latest
Versions
latest
stable
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.