Shortcuts

Source code for mmdet.datasets.samplers.class_aware_sampler

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
from mmcv.runner import get_dist_info
from torch.utils.data import Sampler

from mmdet.core.utils import sync_random_seed


class ClassAwareSampler(Sampler):
    r"""Sampler that restricts data loading to the label of the dataset.

    A class-aware sampling strategy to effectively tackle the
    non-uniform class distribution. The length of the training data is
    consistent with source data. Simple improvements based on `Relay
    Backpropagation for Effective Learning of Deep Convolutional
    Neural Networks <https://arxiv.org/abs/1512.05830>`_

    The implementation logic is referred to
    https://github.com/Sense-X/TSD/blob/master/mmdet/datasets/samplers/distributed_classaware_sampler.py

    Args:
        dataset: Dataset used for sampling. It must provide ``__len__`` and
            a ``get_cat2imgs()`` method returning a mapping from category
            index to the list of image indices containing that category.
        samples_per_gpu (int): When model is :obj:`DistributedDataParallel`,
            it is the number of training samples on each GPU.
            When model is :obj:`DataParallel`, it is
            `num_gpus * samples_per_gpu`.
            Default : 1.
        num_replicas (optional): Number of processes participating in
            distributed training. Taken from ``get_dist_info()`` when None.
        rank (optional): Rank of the current process within num_replicas.
            Taken from ``get_dist_info()`` when None.
        seed (int, optional): random seed used to shuffle the sampler.
            This number should be identical across all
            processes in the distributed group. Default: 0.
        num_sample_class (int): The number of samples taken from each
            per-label list. Default: 1
    """

    def __init__(self,
                 dataset,
                 samples_per_gpu=1,
                 num_replicas=None,
                 rank=None,
                 seed=0,
                 num_sample_class=1):
        _rank, _num_replicas = get_dist_info()
        if num_replicas is None:
            num_replicas = _num_replicas
        if rank is None:
            rank = _rank

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.samples_per_gpu = samples_per_gpu
        self.rank = rank
        self.epoch = 0
        # Must be the same across all workers. If None, will use a
        # random seed shared among workers
        # (require synchronization among all workers)
        self.seed = sync_random_seed(seed)

        # The number of samples taken from each per-label list
        assert num_sample_class > 0 and isinstance(num_sample_class, int)
        self.num_sample_class = num_sample_class

        # Get per-label image list from dataset
        assert hasattr(dataset, 'get_cat2imgs'), \
            'dataset must have `get_cat2imgs` function'
        self.cat_dict = dataset.get_cat2imgs()

        # Per-process sample count, rounded up to a multiple of
        # `samples_per_gpu` so every process yields whole batches.
        self.num_samples = int(
            math.ceil(
                len(self.dataset) * 1.0 / self.num_replicas /
                self.samples_per_gpu)) * self.samples_per_gpu
        self.total_size = self.num_samples * self.num_replicas

        # get number of images containing each category
        self.num_cat_imgs = [len(x) for x in self.cat_dict.values()]
        # filter labels without images; the result is later used to index
        # `self.cat_dict`, so it has to hold the mapping's keys rather than
        # positions within `self.num_cat_imgs`.
        self.valid_cat_inds = self._non_empty_categories(self.cat_dict)
        self.num_classes = len(self.valid_cat_inds)

    @staticmethod
    def _non_empty_categories(cat_dict):
        """Return the keys of ``cat_dict`` whose image list is non-empty."""
        return [
            cat for cat, img_inds in cat_dict.items() if len(img_inds) != 0
        ]

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch + self.seed)

        # initialize label list
        label_iter_list = RandomCycleIter(self.valid_cat_inds, generator=g)
        # initialize each per-label image list
        data_iter_dict = dict()
        for i in self.valid_cat_inds:
            data_iter_dict[i] = RandomCycleIter(self.cat_dict[i], generator=g)

        def gen_cat_img_inds(cls_list, data_dict, num_sample_cls):
            """Traverse the categories and extract `num_sample_cls` image
            indexes of the corresponding categories one by one."""
            img_indices = []
            for _ in range(len(cls_list)):
                cls_idx = next(cls_list)
                for _ in range(num_sample_cls):
                    img_idx = next(data_dict[cls_idx])
                    img_indices.append(img_idx)
            return img_indices

        # Each pass over the label list yields
        # `num_classes * num_sample_class` indices; compute how many passes
        # are needed to reach `total_size`.
        num_bins = int(
            math.ceil(self.total_size * 1.0 / self.num_classes /
                      self.num_sample_class))
        indices = []
        for i in range(num_bins):
            indices += gen_cat_img_inds(label_iter_list, data_iter_dict,
                                        self.num_sample_class)

        # fix extra samples to make it evenly divisible
        if len(indices) >= self.total_size:
            indices = indices[:self.total_size]
        else:
            indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample: each rank takes its own contiguous slice
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch that is added to the seed in ``__iter__``."""
        self.epoch = epoch


class RandomCycleIter:
    """Shuffle the list and do it again after the list has been traversed.

    The implementation logic is referred to
    https://github.com/wutong16/DistributionBalancedLoss/blob/master/mllt/datasets/loader/sampler.py

    Example:
        >>> label_list = [0, 1, 2, 4, 5]
        >>> g = torch.Generator()
        >>> g.manual_seed(0)
        >>> label_iter_list = RandomCycleIter(label_list, generator=g)
        >>> index = next(label_iter_list)

    Args:
        data (list or ndarray): The data that needs to be shuffled.
        generator: A torch.Generator object, which is used in setting the
            seed for generating random numbers.

    Raises:
        IndexError: From ``next()`` when ``data`` is empty.
    """  # noqa: W605

    def __init__(self, data, generator=None):
        self.data = data
        self.length = len(data)
        self.generator = generator
        # Random visiting order over the positions of `data`.
        self.index = self._new_permutation()
        # Position of the next entry of `self.index` to consume.
        self.i = 0

    def _new_permutation(self):
        """Draw a fresh random visiting order over ``range(self.length)``."""
        return torch.randperm(
            self.length, generator=self.generator).numpy()

    def __iter__(self):
        return self

    def __len__(self):
        return len(self.data)

    def __next__(self):
        if self.length == 0:
            # Without this guard the lookup below fails with an opaque
            # out-of-bounds error on the empty permutation array.
            raise IndexError(
                'cannot draw an item from an empty RandomCycleIter')
        if self.i == self.length:
            # Every element has been visited once: reshuffle and restart.
            self.index = self._new_permutation()
            self.i = 0
        item = self.data[self.index[self.i]]
        self.i += 1
        return item
Read the Docs v: v2.24.1
Versions
latest
stable
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.