Shortcuts

mmdet.datasets.samplers.infinite_sampler 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import itertools

import numpy as np
import torch
from mmcv.runner import get_dist_info
from torch.utils.data.sampler import Sampler


class InfiniteGroupBatchSampler(Sampler):
    """Similar to `BatchSampler` wrapping a `GroupSampler`.

    It is designed for iteration-based runners like `IterBasedRunner` and
    yields a mini-batch of indices each time. All indices in a batch belong
    to the same group.

    The implementation logic is referred to
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py

    Args:
        dataset (object): The dataset. It must expose a ``flag`` attribute
            holding a non-negative integer group id for every sample.
        batch_size (int): When model is :obj:`DistributedDataParallel`,
            it is the number of training samples on each GPU.
            When model is :obj:`DataParallel`, it is
            `num_gpus * samples_per_gpu`. Must be at least 1.
            Default: 1.
        world_size (int, optional): Number of processes participating in
            distributed training. Default: None (queried from the
            distributed environment).
        rank (int, optional): Rank of current process. Default: None
            (queried from the distributed environment).
        seed (int): Random seed. Default: 0.
        shuffle (bool): Whether to shuffle the indices of a dummy `epoch`.
            Note that `shuffle` cannot guarantee sequential indices,
            because all indices in a batch must belong to the same group.
            Default: True.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1, or (on first
            iteration) if the dataset is empty.
        AssertionError: If the dataset has no ``flag`` attribute.
    """

    def __init__(self,
                 dataset,
                 batch_size=1,
                 world_size=None,
                 rank=None,
                 seed=0,
                 shuffle=True):
        # A non-positive batch size would never fill a group buffer, so
        # iteration would hang forever; reject it up front.
        if batch_size < 1:
            raise ValueError(
                f'batch_size must be a positive integer, got {batch_size}')
        # Only query the distributed environment for values the caller did
        # not provide explicitly.
        if world_size is None or rank is None:
            _rank, _world_size = get_dist_info()
            if world_size is None:
                world_size = _world_size
            if rank is None:
                rank = _rank
        self.rank = rank
        self.world_size = world_size
        self.dataset = dataset
        self.batch_size = batch_size
        self.seed = seed if seed is not None else 0
        self.shuffle = shuffle

        assert hasattr(self.dataset, 'flag'), \
            'InfiniteGroupBatchSampler requires the dataset to have a ' \
            '`flag` attribute'
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)
        # buffer used to save indices of each group
        self.buffer_per_group = {k: [] for k in range(len(self.group_sizes))}

        self.size = len(dataset)
        # NOTE: a single generator is created here, so repeated calls to
        # ``__iter__`` continue the same infinite stream.
        self.indices = self._indices_of_rank()

    def _infinite_indices(self):
        """Infinitely yield a sequence of indices.

        Each pass over the dataset is a random permutation (seeded by
        ``self.seed``) when ``shuffle`` is True, otherwise ``0..size-1``.

        Raises:
            ValueError: If the dataset is empty. Without this check the
                ``while True`` loop would spin forever without yielding.
        """
        if self.size == 0:
            raise ValueError('Cannot sample from an empty dataset.')
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(self.size, generator=g).tolist()
            else:
                yield from torch.arange(self.size).tolist()

    def _indices_of_rank(self):
        """Slice the infinite indices by rank.

        Every ``world_size``-th index starting at ``rank`` goes to this
        process.
        """
        yield from itertools.islice(self._infinite_indices(), self.rank,
                                    None, self.world_size)

    def __iter__(self):
        # Route every index into the buffer of its group; once a group
        # buffer reaches batch size, yield a copy of it and clear it.
        for idx in self.indices:
            flag = self.flag[idx]
            group_buffer = self.buffer_per_group[flag]
            group_buffer.append(idx)
            if len(group_buffer) == self.batch_size:
                yield group_buffer[:]
                del group_buffer[:]

    def __len__(self):
        """Length of base dataset."""
        return self.size

    def set_epoch(self, epoch):
        """Not supported in `IterationBased` runner."""
        raise NotImplementedError
class InfiniteBatchSampler(Sampler):
    """Similar to `BatchSampler` wrapping a `DistributedSampler`.

    It is designed for iteration-based runners like `IterBasedRunner` and
    yields a mini-batch of indices each time.

    The implementation logic is referred to
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py

    Args:
        dataset (object): The dataset.
        batch_size (int): When model is :obj:`DistributedDataParallel`,
            it is the number of training samples on each GPU.
            When model is :obj:`DataParallel`, it is
            `num_gpus * samples_per_gpu`. Must be at least 1.
            Default: 1.
        world_size (int, optional): Number of processes participating in
            distributed training. Default: None (queried from the
            distributed environment).
        rank (int, optional): Rank of current process. Default: None
            (queried from the distributed environment).
        seed (int): Random seed. Default: 0.
        shuffle (bool): Whether to shuffle the dataset or not.
            Default: True.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1, or (on first
            iteration) if the dataset is empty.
    """

    def __init__(self,
                 dataset,
                 batch_size=1,
                 world_size=None,
                 rank=None,
                 seed=0,
                 shuffle=True):
        # A non-positive batch size would never complete a batch, so
        # iteration would hang while the buffer grows; reject it up front.
        if batch_size < 1:
            raise ValueError(
                f'batch_size must be a positive integer, got {batch_size}')
        # Only query the distributed environment for values the caller did
        # not provide explicitly.
        if world_size is None or rank is None:
            _rank, _world_size = get_dist_info()
            if world_size is None:
                world_size = _world_size
            if rank is None:
                rank = _rank
        self.rank = rank
        self.world_size = world_size
        self.dataset = dataset
        self.batch_size = batch_size
        self.seed = seed if seed is not None else 0
        self.shuffle = shuffle
        self.size = len(dataset)
        # NOTE: a single generator is created here, so repeated calls to
        # ``__iter__`` continue the same infinite stream.
        self.indices = self._indices_of_rank()

    def _infinite_indices(self):
        """Infinitely yield a sequence of indices.

        Each pass over the dataset is a random permutation (seeded by
        ``self.seed``) when ``shuffle`` is True, otherwise ``0..size-1``.

        Raises:
            ValueError: If the dataset is empty. Without this check the
                ``while True`` loop would spin forever without yielding.
        """
        if self.size == 0:
            raise ValueError('Cannot sample from an empty dataset.')
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(self.size, generator=g).tolist()
            else:
                yield from torch.arange(self.size).tolist()

    def _indices_of_rank(self):
        """Slice the infinite indices by rank.

        Every ``world_size``-th index starting at ``rank`` goes to this
        process.
        """
        yield from itertools.islice(self._infinite_indices(), self.rank,
                                    None, self.world_size)

    def __iter__(self):
        # Accumulate indices and yield them once batch size is reached.
        batch_buffer = []
        for idx in self.indices:
            batch_buffer.append(idx)
            if len(batch_buffer) == self.batch_size:
                yield batch_buffer
                batch_buffer = []

    def __len__(self):
        """Length of base dataset."""
        return self.size

    def set_epoch(self, epoch):
        """Not supported in `IterationBased` runner."""
        raise NotImplementedError
Read the Docs v: latest
Versions
latest
stable
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.