# Source code for mmdet.datasets.pipelines.auto_augment

import copy

import numpy as np

from ..builder import PIPELINES
from .compose import Compose


@PIPELINES.register_module()
class AutoAugment(object):
    """Auto augmentation.

    This data augmentation is proposed in `Learning Data Augmentation
    Strategies for Object Detection <https://arxiv.org/pdf/1906.11172>`_  # noqa: E501

    Args:
        policies (list[list[dict]]): The policies of auto augmentation. Each
            policy in ``policies`` is a specific augmentation policy, and is
            composed of several augmentations (dict). When AutoAugment is
            called, a random policy in ``policies`` will be selected to
            augment images.

    Examples:
        >>> replace = (104, 116, 124)
        >>> policies = [
        >>>     [
        >>>         dict(type='Sharpness', prob=0.0, level=8),
        >>>         dict(
        >>>             type='Shear',
        >>>             prob=0.4,
        >>>             level=0,
        >>>             replace=replace,
        >>>             axis='x')
        >>>     ],
        >>>     [
        >>>         dict(
        >>>             type='Rotate',
        >>>             prob=0.6,
        >>>             level=10,
        >>>             replace=replace),
        >>>         dict(type='Color', prob=1.0, level=6)
        >>>     ]
        >>> ]
        >>> augmentation = AutoAugment(policies)
        >>> img = np.ones((100, 100, 3))
        >>> gt_bboxes = np.ones((10, 4))
        >>> results = dict(img=img, gt_bboxes=gt_bboxes)
        >>> results = augmentation(results)
    """

    def __init__(self, policies):
        # NOTE(review): validation uses ``assert`` and is therefore skipped
        # under ``python -O``; kept as-is so the raised exception type
        # (AssertionError) stays the same for existing callers.
        assert isinstance(policies, list) and len(policies) > 0, \
            'Policies must be a non-empty list.'
        for policy in policies:
            assert isinstance(policy, list) and len(policy) > 0, \
                'Each policy in policies must be a non-empty list.'
            for augment in policy:
                assert isinstance(augment, dict) and 'type' in augment, \
                    'Each specific augmentation must be a dict with key' \
                    ' "type".'

        # Deep-copy so later mutation of the caller's list does not alter the
        # policies stored (and reported by __repr__) on this instance.
        self.policies = copy.deepcopy(policies)
        # One Compose pipeline per policy; __call__ picks one of them.
        self.transforms = [Compose(policy) for policy in self.policies]

    def __call__(self, results):
        """Apply one randomly selected policy to ``results``.

        Args:
            results (dict): The data dict passed through the pipeline.

        Returns:
            Whatever the selected ``Compose`` pipeline returns for
            ``results``.
        """
        # Draw an index instead of handing the Compose objects to
        # np.random.choice, which would first coerce them into a numpy object
        # array. Legacy np.random.choice(a) without ``p`` draws
        # randint(0, len(a)) internally, so the selected policy for a given
        # global seed is unchanged.
        idx = np.random.randint(len(self.transforms))
        return self.transforms[idx](results)

    def __repr__(self):
        """Return ``ClassName(policies=...)`` for debugging/logging."""
        return f'{self.__class__.__name__}(policies={self.policies})'