Shortcuts

mmdet.core.bbox.samplers.instance_balanced_pos_sampler 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch

from ..builder import BBOX_SAMPLERS
from .random_sampler import RandomSampler


@BBOX_SAMPLERS.register_module()
class InstanceBalancedPosSampler(RandomSampler):
    """Instance balanced sampler that samples equal number of positive samples
    for each instance."""

    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Sample positive boxes, balanced across ground-truth instances.

        Positives are the entries of ``assign_result.gt_inds`` that are
        greater than 0. When there are more positives than requested, up to
        ``round(num_expected / num_gts) + 1`` of them are drawn for every
        distinct ground-truth index; the result is then topped up from the
        not-yet-sampled positives, or randomly trimmed, so that exactly
        ``num_expected`` indices are returned.

        Args:
            assign_result (:obj:`AssignResult`): The assigned results of boxes.
                Only its ``gt_inds`` tensor is read here.
            num_expected (int): The number of expected positive samples.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Tensor: Sampled indices into ``assign_result.gt_inds``. If there
            are no more than ``num_expected`` positives, all of them are
            returned unchanged. The random draws are delegated to
            ``self.random_choice`` (inherited; presumably returns a tensor
            when given a tensor -- confirm against ``RandomSampler``).
        """
        gt_inds = assign_result.gt_inds
        pos_inds = torch.nonzero(gt_inds > 0, as_tuple=False)
        if pos_inds.numel() != 0:
            pos_inds = pos_inds.squeeze(1)

        # Not more positives than requested: keep every one of them.
        if pos_inds.numel() <= num_expected:
            return pos_inds

        unique_gt_inds = gt_inds[pos_inds].unique()
        num_gts = len(unique_gt_inds)
        # Per-instance quota: the rounded even share plus one. Any overshoot
        # this causes is trimmed after the per-instance draws.
        num_per_gt = int(round(num_expected / float(num_gts)) + 1)

        per_gt_samples = []
        for gt_ind in unique_gt_inds:
            inds = torch.nonzero(gt_inds == gt_ind.item(), as_tuple=False)
            if inds.numel() == 0:
                continue
            inds = inds.squeeze(1)
            if len(inds) > num_per_gt:
                inds = self.random_choice(inds, num_per_gt)
            per_gt_samples.append(inds)
        sampled_inds = torch.cat(per_gt_samples)

        if len(sampled_inds) < num_expected:
            # Some instances had fewer positives than their quota; top up
            # from the positives that have not been picked yet.
            num_extra = num_expected - len(sampled_inds)
            # A boolean mask gives an exact set difference. Building Python
            # sets of 0-d tensors does not, because tensors hash by object
            # identity and equal values are never recognised as duplicates.
            already_sampled = torch.zeros_like(gt_inds, dtype=torch.bool)
            already_sampled[sampled_inds] = True
            extra_inds = pos_inds[~already_sampled[pos_inds]]
            if len(extra_inds) > num_extra:
                extra_inds = self.random_choice(extra_inds, num_extra)
            sampled_inds = torch.cat([sampled_inds, extra_inds])
        elif len(sampled_inds) > num_expected:
            # The "+ 1" in the quota overshot the budget; trim at random.
            sampled_inds = self.random_choice(sampled_inds, num_expected)
        return sampled_inds
Read the Docs v: latest
Versions
latest
stable
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.