import warnings
import mmcv
from ..builder import PIPELINES
from .compose import Compose
from .transforms import RandomFlip
@PIPELINES.register_module()
class MultiScaleFlipAug(object):
    """Test-time augmentation with multiple scales and flipping.

    For every entry of ``img_scale`` the pipeline runs one un-flipped pass
    and, when ``flip`` is True, one flipped pass per entry of
    ``flip_direction``. Each pass runs ``transforms`` on a shallow copy of
    the input ``results`` dict. The per-pass outputs are then merged into a
    single dict that maps every key to a list of values, one per pass, in
    scale-major order.

    Args:
        transforms (list[dict]): Transforms to apply in each augmentation.
        img_scale (tuple | list[tuple]): Image scales for resizing.
        flip (bool): Whether to apply flip augmentation. Default: False.
        flip_direction (str | list[str]): Flip augmentation directions,
            options are "horizontal" and "vertical". If flip_direction is a
            list, multiple flip augmentations will be applied.
            It has no effect when flip == False. Default: "horizontal".
    """

    def __init__(self,
                 transforms,
                 img_scale,
                 flip=False,
                 flip_direction='horizontal'):
        self.transforms = Compose(transforms)
        # Normalise a single scale / direction into a one-element list.
        self.img_scale = img_scale if isinstance(img_scale,
                                                 list) else [img_scale]
        assert mmcv.is_list_of(self.img_scale, tuple)
        self.flip = flip
        self.flip_direction = flip_direction if isinstance(
            flip_direction, list) else [flip_direction]
        assert mmcv.is_list_of(self.flip_direction, str)
        if not self.flip and self.flip_direction != ['horizontal']:
            warnings.warn(
                'flip_direction has no effect when flip is set to False')
        # Inspect the raw ``transforms`` argument rather than the Compose
        # wrapper: the wrapper is not guaranteed to be iterable, and entries
        # may be config dicts (matched on ``type``) or already-built
        # transform objects (matched on class).
        if self.flip and not any(
                self._is_random_flip(t) for t in transforms):
            warnings.warn(
                'flip has no effect when RandomFlip is not in transforms')

    @staticmethod
    def _is_random_flip(transform):
        """Return True if ``transform`` is, or is a config for, RandomFlip."""
        if isinstance(transform, dict):
            return transform.get('type') in ('RandomFlip', RandomFlip)
        return isinstance(transform, RandomFlip)

    def __call__(self, results):
        """Run ``transforms`` once per (scale, flip, direction) combination.

        Args:
            results (dict): Input results. Each augmentation receives a
                shallow copy with ``scale``, ``flip`` and ``flip_direction``
                set.

        Returns:
            dict[str, list]: The transformed results merged key-wise, each
                key mapping to a list with one entry per augmentation.
        """
        # One un-flipped pass per scale, plus one flipped pass per direction
        # when flipping is enabled. Pairing the un-flipped pass with every
        # direction would emit duplicate un-flipped augmentations whenever
        # several directions are configured.
        default_direction = (self.flip_direction[0]
                             if self.flip_direction else None)
        flip_args = [(False, default_direction)]
        if self.flip:
            flip_args += [(True, direction)
                          for direction in self.flip_direction]

        aug_data = []
        for scale in self.img_scale:
            for flip, direction in flip_args:
                _results = results.copy()
                _results['scale'] = scale
                _results['flip'] = flip
                _results['flip_direction'] = direction
                aug_data.append(self.transforms(_results))
        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
        repr_str += f'flip_direction={self.flip_direction})'
        return repr_str