Shortcuts

Source code for mmdet.engine.hooks.sync_norm_hook

# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict

from mmengine.dist import get_dist_info
from mmengine.hooks import Hook
from torch import nn

from mmdet.registry import HOOKS
from mmdet.utils import all_reduce_dict


def get_norm_states(module: nn.Module) -> OrderedDict:
    """Get the state_dict entries of all norm layers in the module.

    Every submodule that is an instance of
    ``nn.modules.batchnorm._NormBase`` contributes all of its own
    ``state_dict`` entries. Each entry is keyed by its fully-qualified name,
    i.e. the same key it has in ``module.state_dict()``. The result can
    therefore be loaded back with ``module.load_state_dict(...,
    strict=False)``.

    Args:
        module (nn.Module): Model whose norm-layer states are collected.

    Returns:
        OrderedDict: Mapping from qualified parameter/buffer name to tensor,
        in ``named_modules()`` traversal order. Empty when ``module``
        contains no norm layers.
    """
    norm_states = OrderedDict()
    for name, child in module.named_modules():
        if not isinstance(child, nn.modules.batchnorm._NormBase):
            continue
        # ``named_modules()`` reports the root module under the name ''.
        # Its keys must stay unprefixed ('running_mean', not '.running_mean'),
        # otherwise they never match ``module.state_dict()`` keys and
        # ``load_state_dict(..., strict=False)`` silently skips them.
        prefix = f'{name}.' if name else ''
        for key, value in child.state_dict().items():
            norm_states[prefix + key] = value
    return norm_states


@HOOKS.register_module()
class SyncNormHook(Hook):
    """Synchronize Norm states before validation, currently used in YOLOX.

    Before each validation epoch, the state-dict entries of every norm layer
    in ``runner.model`` (as collected by ``get_norm_states``) are averaged
    across ranks with ``all_reduce_dict(..., op='mean')``. The averages are
    then loaded back into the model. The hook does nothing in single-process
    runs or when the model contains no norm layers.
    """

    def before_val_epoch(self, runner) -> None:
        """Synchronizing norm.

        Args:
            runner: Runner whose ``model`` has its norm states averaged
                across ranks in place.
        """
        module = runner.model
        _, world_size = get_dist_info()
        # Single process: there are no other ranks to synchronize with.
        if world_size == 1:
            return
        norm_states = get_norm_states(module)
        if not norm_states:
            return
        # TODO: use `all_reduce_dict` in mmengine
        norm_states = all_reduce_dict(norm_states, op='mean')
        # strict=False because ``norm_states`` deliberately covers only the
        # norm-layer entries; every other parameter/buffer is absent.
        module.load_state_dict(norm_states, strict=False)
Read the Docs v: v3.0.0
Versions
latest
stable
v3.0.0
3.x
v3.0.0rc0
v2.28.2
v2.28.1
v2.28.0
v2.27.0
v2.26.0
v2.25.3
v2.25.2
v2.25.1
v2.25.0
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
test-3.0.0rc0
main
dev-3.x
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.