Shortcuts

Source code for mmdet.datasets.pipelines.test_time_aug

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv

from ..builder import PIPELINES
from .compose import Compose


@PIPELINES.register_module()
class MultiScaleFlipAug:
    """Test-time augmentation with multiple scales and flipping.

    An example configuration is as follows:

    .. code-block::

        img_scale=[(1333, 400), (1333, 800)],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ]

    After MultiScaleFlipAug with the above configuration, the results are
    wrapped into lists of the same length as follows:

    .. code-block::

        dict(
            img=[...],
            img_shape=[...],
            scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)],
            flip=[False, True, False, True],
            ...
        )

    Args:
        transforms (list[dict | callable]): Transforms to apply in each
            augmentation. They are passed as-is to ``Compose``.
        img_scale (tuple | list[tuple] | None): Image scales for resizing.
        scale_factor (float | list[float] | None): Scale factors for
            resizing. Exactly one of ``img_scale`` and ``scale_factor``
            must be set, and it must contain at least one value.
        flip (bool): Whether to apply flip augmentation. Default: False.
        flip_direction (str | list[str]): Flip augmentation directions,
            options are "horizontal", "vertical" and "diagonal". If
            flip_direction is a list, multiple flip augmentations will be
            applied. It has no effect when flip == False.
            Default: "horizontal".
    """

    def __init__(self,
                 transforms,
                 img_scale=None,
                 scale_factor=None,
                 flip=False,
                 flip_direction='horizontal'):
        self.transforms = Compose(transforms)
        # The two scaling modes are mutually exclusive: absolute image
        # scales or relative scale factors, never both and never neither.
        assert (img_scale is None) ^ (scale_factor is None), (
            'Exactly one of img_scale and scale_factor must be set')
        if img_scale is not None:
            self.img_scale = img_scale if isinstance(img_scale,
                                                     list) else [img_scale]
            self.scale_key = 'scale'
            assert mmcv.is_list_of(self.img_scale, tuple)
        else:
            # Scale factors are stored in ``self.img_scale`` as well; only
            # the key written into ``results`` (``self.scale_key``) differs.
            self.img_scale = scale_factor if isinstance(
                scale_factor, list) else [scale_factor]
            self.scale_key = 'scale_factor'
        # An empty list would otherwise only fail later in ``__call__``
        # with an opaque IndexError on ``aug_data[0]``.
        assert self.img_scale, (
            'img_scale/scale_factor must contain at least one value')

        self.flip = flip
        self.flip_direction = flip_direction if isinstance(
            flip_direction, list) else [flip_direction]
        assert mmcv.is_list_of(self.flip_direction, str)
        if not self.flip and self.flip_direction != ['horizontal']:
            warnings.warn(
                'flip_direction has no effect when flip is set to False')
        if self.flip and not any(
                self._is_random_flip(t) for t in transforms):
            warnings.warn(
                'flip has no effect when RandomFlip is not in transforms')

    @staticmethod
    def _is_random_flip(transform):
        """Return True if ``transform`` is a RandomFlip config or object.

        ``transforms`` may mix config dicts and already-built transform
        objects, so indexing ``transform['type']`` unconditionally would
        raise for the latter.
        """
        if isinstance(transform, dict):
            transform_type = transform.get('type')
            # A config may reference the class itself instead of its name.
            if isinstance(transform_type, type):
                transform_type = transform_type.__name__
            return transform_type == 'RandomFlip'
        return type(transform).__name__ == 'RandomFlip'

    def __call__(self, results):
        """Call function to apply test time augment transforms on results.

        Every (scale, flip) combination is run through ``self.transforms``
        in order: for each scale, first the non-flipped variant, then one
        flipped variant per entry in ``self.flip_direction`` (only when
        ``self.flip`` is True).

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict[str, list]: The augmented data, where each value is wrapped
                into a list. The keys are taken from the first augmentation's
                output.
        """
        aug_data = []
        flip_args = [(False, None)]
        if self.flip:
            flip_args += [(True, direction)
                          for direction in self.flip_direction]
        for scale in self.img_scale:
            for flip, direction in flip_args:
                # Shallow copy: the keys set below stay local to this
                # augmentation, but nested values are shared with ``results``.
                _results = results.copy()
                _results[self.scale_key] = scale
                _results['flip'] = flip
                _results['flip_direction'] = direction
                data = self.transforms(_results)
                aug_data.append(data)
        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
        repr_str += f'flip_direction={self.flip_direction})'
        return repr_str
Read the Docs v: v2.21.0
Versions
latest
stable
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.