# Source code for mmdet.core.bbox.samplers.iou_balanced_neg_sampler
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from ..builder import BBOX_SAMPLERS
from .random_sampler import RandomSampler
@BBOX_SAMPLERS.register_module()
class IoUBalancedNegSampler(RandomSampler):
    """IoU Balanced Sampling.

    arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)

    Sampling proposals according to their IoU. `floor_fraction` of needed RoIs
    are sampled from proposals whose IoU are lower than `floor_thr` randomly.
    The others are sampled from proposals whose IoU are higher than
    `floor_thr`. These proposals are sampled from some bins evenly, which are
    split by `num_bins` via IoU evenly.

    Args:
        num (int): number of proposals.
        pos_fraction (float): fraction of positive proposals.
        floor_thr (float): threshold (minimum) IoU for IoU balanced sampling,
            set to -1 if all using IoU balanced sampling.
        floor_fraction (float): sampling fraction of proposals under floor_thr.
        num_bins (int): number of bins in IoU balanced sampling.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 floor_thr=-1,
                 floor_fraction=0,
                 num_bins=3,
                 **kwargs):
        super(IoUBalancedNegSampler, self).__init__(num, pos_fraction,
                                                    **kwargs)
        # -1 is the only accepted negative value; it is a sentinel meaning
        # "no floor set, every negative is IoU-balanced sampled".
        assert floor_thr >= 0 or floor_thr == -1
        assert 0 <= floor_fraction <= 1
        assert num_bins >= 1

        self.floor_thr = floor_thr
        self.floor_fraction = floor_fraction
        self.num_bins = num_bins

    def sample_via_interval(self, max_overlaps, full_set, num_expected):
        """Sample according to the IoU interval.

        The range from the lower bound to ``max_overlaps.max()`` is split
        into ``self.num_bins`` equal-width bins, and up to
        ``num_expected // self.num_bins`` indices are drawn from each bin.
        Any shortfall is topped up at random from the unused part of
        ``full_set``.

        Args:
            max_overlaps (np.ndarray): IoU between bounding boxes and ground
                truth boxes (``_sample_neg`` passes a NumPy array).
            full_set (set(int)): A full set of indices of boxes.
            num_expected (int): Number of expected samples.

        Returns:
            np.ndarray: Indices of samples.
        """
        # The -1 sentinel means "bin from IoU 0"; resolve it locally rather
        # than overwriting self.floor_thr, so the configured value survives
        # across calls.
        lower_bound = max(self.floor_thr, 0)
        max_iou = max_overlaps.max()
        iou_interval = (max_iou - lower_bound) / self.num_bins
        per_num_expected = int(num_expected / self.num_bins)

        sampled_inds = []
        for i in range(self.num_bins):
            start_iou = lower_bound + i * iou_interval
            end_iou = lower_bound + (i + 1) * iou_interval
            # Half-open bin [start_iou, end_iou), restricted to full_set.
            in_bin = np.logical_and(max_overlaps >= start_iou,
                                    max_overlaps < end_iou)
            tmp_inds = list(set(np.where(in_bin)[0]) & full_set)
            if len(tmp_inds) > per_num_expected:
                tmp_sampled_set = self.random_choice(tmp_inds,
                                                     per_num_expected)
            else:
                # np.int was removed in NumPy 1.24; use an explicit dtype.
                tmp_sampled_set = np.array(tmp_inds, dtype=np.int64)
            sampled_inds.append(tmp_sampled_set)

        sampled_inds = np.concatenate(sampled_inds)
        if len(sampled_inds) < num_expected:
            # Bins were too sparse: fill up from the not-yet-chosen indices.
            num_extra = num_expected - len(sampled_inds)
            # Explicit dtype keeps an empty remainder from upcasting the
            # concatenated result to float.
            extra_inds = np.array(
                list(full_set - set(sampled_inds)), dtype=np.int64)
            if len(extra_inds) > num_extra:
                extra_inds = self.random_choice(extra_inds, num_extra)
            sampled_inds = np.concatenate([sampled_inds, extra_inds])

        return sampled_inds

    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Sample negative boxes.

        Negatives are the entries with ``assign_result.gt_inds == 0``. If
        there are no more of them than ``num_expected`` they are all
        returned. Otherwise they are split into a "floor" pool and an
        "IoU sampling" pool according to ``self.floor_thr``, sampled
        separately, and topped up at random if still short.

        Args:
            assign_result (:obj:`AssignResult`): The assigned results of boxes.
            num_expected (int): The number of expected negative samples.

        Returns:
            torch.Tensor: sampled indices.
        """
        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        if neg_inds.numel() != 0:
            neg_inds = neg_inds.squeeze(1)
        if len(neg_inds) <= num_expected:
            return neg_inds

        max_overlaps = assign_result.max_overlaps.cpu().numpy()
        # balance sampling for negative samples
        neg_set = set(neg_inds.cpu().numpy())

        if self.floor_thr > 0:
            floor_set = set(
                np.where(
                    np.logical_and(max_overlaps >= 0,
                                   max_overlaps < self.floor_thr))[0])
            iou_sampling_set = set(
                np.where(max_overlaps >= self.floor_thr)[0])
        elif self.floor_thr == 0:
            floor_set = set(np.where(max_overlaps == 0)[0])
            iou_sampling_set = set(
                np.where(max_overlaps > self.floor_thr)[0])
        else:
            # floor_thr == -1: no floor pool, everything is IoU-balanced.
            # self.floor_thr is intentionally NOT reset to 0 here; doing so
            # would make every later call take the `== 0` branch above.
            # sample_via_interval derives the 0 lower bound on its own.
            floor_set = set()
            iou_sampling_set = set(
                np.where(max_overlaps > self.floor_thr)[0])

        floor_neg_inds = list(floor_set & neg_set)
        iou_sampling_neg_inds = list(iou_sampling_set & neg_set)

        num_expected_iou_sampling = int(num_expected *
                                        (1 - self.floor_fraction))
        if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
            if self.num_bins >= 2:
                iou_sampled_inds = self.sample_via_interval(
                    max_overlaps, set(iou_sampling_neg_inds),
                    num_expected_iou_sampling)
            else:
                # A single bin is just uniform random sampling.
                iou_sampled_inds = self.random_choice(
                    iou_sampling_neg_inds, num_expected_iou_sampling)
        else:
            iou_sampled_inds = np.array(
                iou_sampling_neg_inds, dtype=np.int64)

        # Whatever the IoU pool did not supply is requested from the floor
        # pool.
        num_expected_floor = num_expected - len(iou_sampled_inds)
        if len(floor_neg_inds) > num_expected_floor:
            sampled_floor_inds = self.random_choice(floor_neg_inds,
                                                    num_expected_floor)
        else:
            sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int64)

        sampled_inds = np.concatenate((sampled_floor_inds, iou_sampled_inds))
        if len(sampled_inds) < num_expected:
            # Both pools together fell short: draw the rest from the
            # remaining negatives.
            num_extra = num_expected - len(sampled_inds)
            extra_inds = np.array(
                list(neg_set - set(sampled_inds)), dtype=np.int64)
            if len(extra_inds) > num_extra:
                extra_inds = self.random_choice(extra_inds, num_extra)
            sampled_inds = np.concatenate((sampled_inds, extra_inds))

        sampled_inds = torch.from_numpy(sampled_inds).long().to(
            assign_result.gt_inds.device)
        return sampled_inds