Shortcuts

mmdet.engine.hooks.num_class_check_hook 源代码

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import VGG
from mmengine.hooks import Hook
from mmengine.runner import Runner

from mmdet.registry import HOOKS


@HOOKS.register_module()
class NumClassCheckHook(Hook):
    """Check that every head's `num_classes` matches the dataset classes.

    Before each training and validation epoch, the hook compares the
    `num_classes` attribute of the model's sub-modules with the length of
    `classes` in the corresponding dataset's `metainfo`.
    """

    def _check_head(self, runner: Runner, mode: str) -> None:
        """Check whether the `num_classes` in head matches the length of
        `classes` in `dataset.metainfo`.

        If the dataset does not define `classes` in its `metainfo`, only a
        warning is logged and no comparison is made.

        Args:
            runner (:obj:`Runner`): The runner of the training or evaluation
                process.
            mode (str): Which dataloader to inspect, either ``'train'``
                (``runner.train_dataloader``) or ``'val'``
                (``runner.val_dataloader``).

        Raises:
            AssertionError: If ``mode`` is not ``'train'`` or ``'val'``, if
                `classes` is a plain string instead of a sequence of names,
                or if a checked module's `num_classes` differs from
                ``len(classes)``.
        """
        assert mode in ['train', 'val']
        model = runner.model
        dataset = runner.train_dataloader.dataset if mode == 'train' else \
            runner.val_dataloader.dataset

        classes = dataset.metainfo.get('classes', None)
        if classes is None:
            runner.logger.warning(
                f'Please set `classes` '
                f'in the {dataset.__class__.__name__} `metainfo` and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
            return

        # A bare string would make `len(classes)` count characters rather
        # than class names, so reject it explicitly.
        assert not isinstance(classes, str), \
            (f'`classes` in {dataset.__class__.__name__} '
             f'should be a tuple of str. '
             f'Add comma if number of classes is 1 as '
             f'classes = ({classes},)')

        # NOTE(review): imported locally as in the original code,
        # presumably to avoid a circular import -- confirm before hoisting.
        from mmdet.models.roi_heads.mask_heads import FusedSemanticHead

        for name, module in model.named_modules():
            if not hasattr(module, 'num_classes'):
                continue
            # Modules whose name ends with `rpn_head`, and VGG /
            # FusedSemanticHead instances, are excluded from the check
            # (presumably their `num_classes` is unrelated to the dataset
            # classes -- TODO confirm).
            if name.endswith('rpn_head') or isinstance(
                    module, (VGG, FusedSemanticHead)):
                continue
            assert module.num_classes == len(classes), \
                (f'The `num_classes` ({module.num_classes}) in '
                 f'{module.__class__.__name__} of '
                 f'{model.__class__.__name__} does not match '
                 f'the length of `classes` '
                 f'({len(classes)}) in '
                 f'{dataset.__class__.__name__}')

    def before_train_epoch(self, runner: Runner) -> None:
        """Check whether the training dataset is compatible with head.

        Args:
            runner (:obj:`Runner`): The runner of the training or evaluation
                process.
        """
        self._check_head(runner, 'train')

    def before_val_epoch(self, runner: Runner) -> None:
        """Check whether the dataset in val epoch is compatible with head.

        Args:
            runner (:obj:`Runner`): The runner of the training or evaluation
                process.
        """
        self._check_head(runner, 'val')
Read the Docs v: 3.x
Versions
latest
stable
3.x
v2.28.2
v2.28.1
v2.28.0
v2.27.0
v2.26.0
v2.25.3
v2.25.2
v2.25.1
v2.25.0
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
dev-3.x
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.