Shortcuts

Source code for mmdet.datasets.utils

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings

from mmcv.cnn import VGG
from mmcv.runner.hooks import HOOKS, Hook

from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile
from mmdet.models.dense_heads import GARPNHead, RPNHead
from mmdet.models.roi_heads.mask_heads import FusedSemanticHead


def replace_ImageToTensor(pipelines):
    """Swap every ``ImageToTensor`` step for ``DefaultFormatBundle``.

    Batch inference needs ``DefaultFormatBundle`` instead of
    ``ImageToTensor``, so this walks a data-pipeline config and performs the
    substitution. ``MultiScaleFlipAug`` steps are descended into through
    their ``transforms`` list. The caller's config is never mutated: all work
    happens on a deep copy. A ``UserWarning`` is emitted for each
    replacement made.

    Args:
        pipelines (list[dict]): Data pipeline configs.

    Returns:
        list: A deep copy of ``pipelines`` in which each ``ImageToTensor``
            step has been replaced by ``{'type': 'DefaultFormatBundle'}``.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='ImageToTensor', keys=['img']),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='DefaultFormatBundle'),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> assert expected_pipelines == replace_ImageToTensor(pipelines)
    """
    result = copy.deepcopy(pipelines)
    for position, step in enumerate(result):
        step_type = step['type']
        if step_type == 'ImageToTensor':
            warnings.warn(
                '"ImageToTensor" pipeline is replaced by '
                '"DefaultFormatBundle" for batch inference. It is '
                'recommended to manually replace it in the test '
                'data pipeline in your config file.', UserWarning)
            # Overwrite the slot in the copied container (length unchanged).
            result[position] = {'type': 'DefaultFormatBundle'}
            continue
        if step_type == 'MultiScaleFlipAug':
            assert 'transforms' in step
            # Nested transforms get the same treatment recursively.
            step['transforms'] = replace_ImageToTensor(step['transforms'])
    return result
def get_loading_pipeline(pipeline):
    """Only keep loading image and annotations related configuration.

    A step is kept when its ``type`` resolves, via ``PIPELINES.get``, to
    exactly ``LoadImageFromFile`` or ``LoadAnnotations``. Steps for which
    ``PIPELINES.get`` returns ``None`` are skipped.

    Args:
        pipeline (list[dict]): Data pipeline configs.

    Returns:
        list[dict]: The new pipeline list with only the loading-image and
            loading-annotations configuration kept. The original config
            dicts are returned as-is (they are not copied).

    Raises:
        AssertionError: If the number of retained loading steps is not 2.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True),
        ...    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
        ...    dict(type='RandomFlip', flip_ratio=0.5),
        ...    dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...    dict(type='Pad', size_divisor=32),
        ...    dict(type='DefaultFormatBundle'),
        ...    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True)
        ...    ]
        >>> assert expected_pipelines == get_loading_pipeline(pipelines)
    """
    loading_pipeline_cfg = []
    for cfg in pipeline:
        obj_cls = PIPELINES.get(cfg['type'])
        # TODO:use more elegant way to distinguish loading modules
        # NOTE: tuple membership compares the classes by equality, so only
        # these exact classes match; subclasses of either loader are not
        # kept.
        if obj_cls is not None and obj_cls in (LoadImageFromFile,
                                               LoadAnnotations):
            loading_pipeline_cfg.append(cfg)
    num_loading = len(loading_pipeline_cfg)
    # Only the count is checked, not that one step of each kind is present.
    assert num_loading == 2, \
        'The data pipeline in your config file must include ' \
        'loading image and annotations related pipeline, but ' \
        f'{num_loading} loading step(s) were found.'
    return loading_pipeline_cfg
@HOOKS.register_module()
class NumClassCheckHook(Hook):
    """Check that heads' ``num_classes`` agree with the dataset's classes.

    Before every training and validation epoch, this hook compares the
    ``num_classes`` attribute of each sub-module of ``runner.model`` against
    ``len(runner.data_loader.dataset.CLASSES)``.
    """

    def _check_head(self, runner):
        """Check whether the `num_classes` in head matches the length of
        `CLASSES` in `dataset`.

        If ``CLASSES`` is ``None``, only a warning is logged and no
        comparison is made.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.

        Raises:
            AssertionError: If ``CLASSES`` is a plain string. Also raised if
                a sub-module that has ``num_classes`` (other than the
                excluded head types) disagrees with ``len(CLASSES)``.
        """
        model = runner.model
        dataset = runner.data_loader.dataset
        dataset_name = dataset.__class__.__name__

        if dataset.CLASSES is None:
            runner.logger.warning(
                f'Please set `CLASSES` '
                f'in the {dataset_name} and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
            return

        # A bare string such as CLASSES = ('person') is a common typo for a
        # 1-tuple; len() on it would count characters, so reject it.
        assert not isinstance(dataset.CLASSES, str), \
            (f'`CLASSES` in {dataset_name} '
             f'should be a tuple of str. '
             f'Add comma if number of classes is 1 as '
             f'CLASSES = ({dataset.CLASSES},)')

        num_dataset_classes = len(dataset.CLASSES)
        # These module types are exempt from the comparison, presumably
        # because their `num_classes` does not denote the dataset's category
        # count (e.g. class-agnostic proposal heads) -- confirm before
        # extending this list.
        excluded_types = (RPNHead, VGG, FusedSemanticHead, GARPNHead)
        for _, module in model.named_modules():
            if not hasattr(module, 'num_classes') or isinstance(
                    module, excluded_types):
                continue
            assert module.num_classes == num_dataset_classes, \
                (f'The `num_classes` ({module.num_classes}) in '
                 f'{module.__class__.__name__} of '
                 f'{model.__class__.__name__} does not match '
                 f'the length of `CLASSES` '
                 f'({num_dataset_classes}) in '
                 f'{dataset_name}')

    def before_train_epoch(self, runner):
        """Check whether the training dataset is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)

    def before_val_epoch(self, runner):
        """Check whether the dataset in val epoch is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)
Read the Docs v: v2.19.0
Versions
latest
stable
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.