Shortcuts

Source code for mmdet.datasets.samplers.batch_sampler

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Iterator, Sequence

from torch.utils.data import BatchSampler, Sampler

from mmdet.registry import DATA_SAMPLERS


# TODO: maybe replace with a data_loader wrapper
@DATA_SAMPLERS.register_module()
class AspectRatioBatchSampler(BatchSampler):
    """A sampler wrapper for grouping images with similar aspect ratio.

    Indices drawn from ``sampler`` are split into two groups, one for
    ``width < height`` and one for ``width >= height``, and every batch is
    filled from a single group. Indices still waiting in the two groups once
    ``sampler`` is exhausted are merged and emitted as trailing batches, so
    only those trailing batches may mix both groups.

    Args:
        sampler (Sampler): Base sampler. Its ``dataset`` attribute must
            provide ``get_data_info(idx)`` returning a mapping with
            ``'width'`` and ``'height'`` entries.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch
            if its size would be less than ``batch_size``.

    Raises:
        TypeError: If ``sampler`` is not an instance of ``Sampler``.
        ValueError: If ``batch_size`` is not a positive integer.
    """

    def __init__(self,
                 sampler: Sampler,
                 batch_size: int,
                 drop_last: bool = False) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        # two groups for w < h and w >= h
        self._aspect_ratio_buckets = [[] for _ in range(2)]

    def __iter__(self) -> Iterator[Sequence[int]]:
        """Yield batches of dataset indices grouped by aspect ratio.

        Yields:
            list[int]: Up to ``batch_size`` indices. A batch shorter than
            ``batch_size`` can only be the final one, and it is emitted only
            when ``drop_last`` is ``False``.
        """
        for idx in self.sampler:
            data_info = self.sampler.dataset.get_data_info(idx)
            width, height = data_info['width'], data_info['height']
            bucket_id = 0 if width < height else 1
            bucket = self._aspect_ratio_buckets[bucket_id]
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size:
                # Hand out a copy so clearing the bucket below cannot
                # mutate the batch the consumer already received.
                yield bucket[:]
                del bucket[:]

        # yield the rest data and reset the bucket
        left_data = (
            self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[1])
        self._aspect_ratio_buckets = [[] for _ in range(2)]
        for start in range(0, len(left_data), self.batch_size):
            batch = left_data[start:start + self.batch_size]
            # Only a strictly short trailing batch is subject to
            # ``drop_last``. A leftover batch of exactly ``batch_size`` is
            # a full batch and must be emitted, which also keeps the number
            # of yielded batches equal to ``__len__``.
            if len(batch) < self.batch_size and self.drop_last:
                break
            yield batch

    def __len__(self) -> int:
        """Return the number of batches produced by one full iteration."""
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        # Ceiling division: a final partial batch counts as one batch.
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size
Read the Docs v: 3.x
Versions
latest
stable
3.x
v3.0.0rc0
v2.28.2
v2.28.1
v2.28.0
v2.27.0
v2.26.0
v2.25.3
v2.25.2
v2.25.1
v2.25.0
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
test-3.0.0rc0
main
dev-3.x
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.