Shortcuts

Source code for mmdet.engine.hooks.mean_teacher_hook

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner

from mmdet.registry import HOOKS


@HOOKS.register_module()
class MeanTeacherHook(Hook):
    """Mean Teacher Hook.

    Mean Teacher is an efficient semi-supervised learning method in
    `Mean Teacher <https://arxiv.org/abs/1703.01780>`_.
    This method requires two models with exactly the same structure,
    as the student model and the teacher model, respectively.
    The student model updates its parameters through gradient descent,
    and the teacher model updates its parameters through an
    exponential moving average (EMA) of the student model.
    Compared with the student model, the teacher model
    is smoother and accumulates more knowledge.

    Args:
        momentum (float): The momentum used for updating the teacher's
            parameters. Teacher's parameters are updated with the formula:
            `teacher = (1-momentum) * teacher + momentum * student`.
            Must lie in the open interval (0, 1). Defaults to 0.001.
        interval (int): Update the teacher's parameters every ``interval``
            training iterations. Must be positive. Defaults to 1.
        skip_buffer (bool): Whether to skip the model buffers, such as
            batchnorm running stats (running_mean, running_var). If True,
            only ``student.parameters()`` are averaged into the teacher.
            If False, every floating-point entry of the state dict
            (parameters and buffers) is averaged. Defaults to True.
    """

    def __init__(self,
                 momentum: float = 0.001,
                 interval: int = 1,
                 skip_buffer: bool = True) -> None:
        assert 0 < momentum < 1, (
            f'momentum must be in the open interval (0, 1), '
            f'but got {momentum}')
        # interval == 0 would otherwise only surface later as a
        # ZeroDivisionError in ``after_train_iter``; fail at construction.
        assert interval > 0, (
            f'interval must be a positive integer, but got {interval}')
        self.momentum = momentum
        self.interval = interval
        # The public keyword stays ``skip_buffer`` so existing callers keep
        # working; the attribute read by ``momentum_update`` is plural.
        self.skip_buffers = skip_buffer

    @staticmethod
    def _unwrap_model(runner: Runner) -> nn.Module:
        """Return ``runner.model``, taking ``.module`` if it is a wrapper."""
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        return model

    def before_train(self, runner: Runner) -> None:
        """Check that the teacher and student models exist.

        When ``runner.iter == 0`` the teacher is initialized from the
        student by an update with momentum 1, which replaces the teacher's
        values with the student's (only parameters when ``skip_buffers`` is
        True). When ``runner.iter`` is non-zero (presumably a resumed run)
        the teacher is left as-is.

        Args:
            runner (Runner): The runner of the training process.
        """
        model = self._unwrap_model(runner)
        assert hasattr(model, 'teacher'), (
            'MeanTeacherHook requires the model to have a `teacher` attribute')
        assert hasattr(model, 'student'), (
            'MeanTeacherHook requires the model to have a `student` attribute')
        # only do it at initial stage
        if runner.iter == 0:
            self.momentum_update(model, 1)

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[dict] = None) -> None:
        """Update the teacher's parameters every ``self.interval`` iterations.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch (unused).
            data_batch (dict, optional): Data from the dataloader (unused).
            outputs (dict, optional): Outputs from the model (unused).
        """
        # ``runner.iter`` is 0-based here, hence the ``+ 1``.
        if (runner.iter + 1) % self.interval != 0:
            return
        self.momentum_update(self._unwrap_model(runner), self.momentum)

    def momentum_update(self, model: nn.Module, momentum: float) -> None:
        """Update the teacher in place as an EMA of the student.

        Each updated teacher tensor becomes
        ``(1 - momentum) * teacher + momentum * student``.

        Args:
            model (nn.Module): Unwrapped model exposing ``student`` and
                ``teacher`` sub-modules. Tensors are paired positionally,
                so both must have the same structure.
            momentum (float): Weight given to the student's values. A value
                of 1 overwrites the teacher with the student.
        """
        keep = 1 - momentum  # weight of the teacher's current value
        if self.skip_buffers:
            # Learnable parameters only; buffers are left untouched.
            for src_parm, dst_parm in zip(model.student.parameters(),
                                          model.teacher.parameters()):
                dst_parm.data.mul_(keep).add_(src_parm.data, alpha=momentum)
        else:
            for src_parm, dst_parm in zip(model.student.state_dict().values(),
                                          model.teacher.state_dict().values()):
                # Skip non-floating tensors such as BatchNorm's integer
                # ``num_batches_tracked`` counter, which cannot be averaged.
                if dst_parm.dtype.is_floating_point:
                    dst_parm.data.mul_(keep).add_(
                        src_parm.data, alpha=momentum)
Read the Docs v: 3.x
Versions
latest
stable
3.x
v3.0.0rc0
v2.28.2
v2.28.1
v2.28.0
v2.27.0
v2.26.0
v2.25.3
v2.25.2
v2.25.1
v2.25.0
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
v2.12.0
v2.11.0
v2.10.0
v2.9.0
v2.8.0
v2.7.0
v2.6.0
v2.5.0
v2.4.0
v2.3.0
v2.2.1
v2.2.0
v2.1.0
v2.0.0
v1.2.0
test-3.0.0rc0
main
dev-3.x
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.