Source code for mmdet.core.evaluation.eval_hooks

import os.path as osp

from mmcv.runner import Hook
from torch.utils.data import DataLoader


class EvalHook(Hook):
    """Evaluation hook.

    Attributes:
        dataloader (DataLoader): A PyTorch dataloader.
        interval (int): Evaluation interval (by epochs). Default: 1.
    """

    def __init__(self, dataloader, interval=1, **eval_kwargs):
        if not isinstance(dataloader, DataLoader):
            raise TypeError('dataloader must be a pytorch DataLoader, but got'
                            f' {type(dataloader)}')
        self.dataloader = dataloader
        self.interval = interval
        self.eval_kwargs = eval_kwargs

    def after_train_epoch(self, runner):
        if not self.every_n_epochs(runner, self.interval):
            return
        from mmdet.apis import single_gpu_test
        results = single_gpu_test(runner.model, self.dataloader, show=False)
        self.evaluate(runner, results)

    def evaluate(self, runner, results):
        eval_res = self.dataloader.dataset.evaluate(
            results, logger=runner.logger, **self.eval_kwargs)
        for name, val in eval_res.items():
            runner.log_buffer.output[name] = val
        runner.log_buffer.ready = True
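
# Illustrative usage (not part of the original module): a minimal sketch of
# registering EvalHook with an mmcv Runner so evaluation runs after each
# training epoch. `model`, `batch_processor`, `optimizer`, `val_dataset` and
# `train_loader` are assumed to exist; extra keyword arguments such as
# `metric='bbox'` are forwarded to `dataset.evaluate()`.
#
#     from mmcv.runner import Runner
#     from torch.utils.data import DataLoader
#
#     val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
#     runner = Runner(model, batch_processor, optimizer,
#                     work_dir='./work_dir')
#     runner.register_hook(EvalHook(val_loader, interval=1, metric='bbox'))
#     runner.run([train_loader], [('train', 1)], max_epochs=12)
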
class DistEvalHook(EvalHook):
    """Distributed evaluation hook.

    Attributes:
        dataloader (DataLoader): A PyTorch dataloader.
        interval (int): Evaluation interval (by epochs). Default: 1.
        gpu_collect (bool): Whether to use gpu or cpu to collect results.
            Default: False.

    Note:
        When ``gpu_collect`` is False, results from all processes are
        collected via a temporary directory fixed to
        ``osp.join(runner.work_dir, '.eval_hook')``; it is not configurable
        through this hook.
    """

    def __init__(self,
                 dataloader,
                 interval=1,
                 gpu_collect=False,
                 **eval_kwargs):
        if not isinstance(dataloader, DataLoader):
            raise TypeError('dataloader must be a pytorch DataLoader, but got '
                            f'{type(dataloader)}')
        self.dataloader = dataloader
        self.interval = interval
        self.gpu_collect = gpu_collect
        self.eval_kwargs = eval_kwargs

    def after_train_epoch(self, runner):
        if not self.every_n_epochs(runner, self.interval):
            return
        from mmdet.apis import multi_gpu_test
        results = multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
            gpu_collect=self.gpu_collect)
        if runner.rank == 0:
            print('\n')
            self.evaluate(runner, results)
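
# Illustrative usage (not part of the original module): a sketch of the
# distributed case. The model is assumed to be wrapped in
# MMDistributedDataParallel after the distributed process group has been
# initialized. With `gpu_collect=True` results are gathered across ranks on
# GPU; otherwise they are collected through the temporary directory under
# `runner.work_dir`, and only rank 0 runs the final evaluation.
#
#     from mmcv.parallel import MMDistributedDataParallel
#
#     model = MMDistributedDataParallel(model.cuda())
#     runner = Runner(model, batch_processor, optimizer,
#                     work_dir='./work_dir')
#     runner.register_hook(
#         DistEvalHook(val_loader, interval=1, gpu_collect=True))
#     runner.run([train_loader], [('train', 1)], max_epochs=12)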