Source code for mmdet.apis.inference

import warnings

import mmcv
import numpy as np
import torch
from mmcv.ops import RoIPool
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint

from mmdet.core import get_classes
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector


def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str or :obj:`torch.device`): Device the constructed model is
            moved to. Default: 'cuda:0'.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed detector, switched to eval mode, with
        ``CLASSES`` (when a checkpoint is given) and ``cfg`` attached.

    Raises:
        TypeError: If ``config`` is neither a str nor a ``mmcv.Config``.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    # Only inference is intended here: drop the pretrained-weights entry and
    # the training config before building the model.
    config.model.pretrained = None
    config.model.train_cfg = None
    model = build_detector(config.model, test_cfg=config.get('test_cfg'))
    if checkpoint is not None:
        # Always deserialize on CPU. Comparing ``device == 'cpu'`` misses
        # ``torch.device('cpu')``, and restoring tensors onto the device they
        # were saved from fails when that device is not available on this
        # host. The weights reach ``device`` through ``model.to`` below.
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        if 'CLASSES' in checkpoint.get('meta', {}):
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
class LoadImage:
    """Deprecated image-loading pipeline step.

    Kept only for backward compatibility; every call emits a warning that
    points users to ``LoadImageFromWebcam``.
    """

    def __call__(self, results):
        """Read the image referenced by ``results['img']`` into ``results``.

        Args:
            results (dict): Result dict whose ``'img'`` entry is what gets
                handed to ``mmcv.imread``.

        Returns:
            dict: The same ``results`` object, with ``'filename'``,
            ``'ori_filename'``, ``'img'``, ``'img_fields'``, ``'img_shape'``
            and ``'ori_shape'`` filled in.
        """
        warnings.simplefilter('once')
        warnings.warn('`LoadImage` is deprecated and will be removed in '
                      'future releases. You may use `LoadImageFromWebcam` '
                      'from `mmdet.datasets.pipelines.` instead.')

        source = results['img']
        # A file name is recorded only when the input is a string.
        recorded_name = source if isinstance(source, str) else None
        results['filename'] = recorded_name
        results['ori_filename'] = recorded_name

        loaded = mmcv.imread(source)
        results['img'] = loaded
        results['img_fields'] = ['img']
        results['img_shape'] = loaded.shape
        results['ori_shape'] = loaded.shape
        return results
def inference_detector(model, imgs):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector. It must carry the ``cfg``
            attribute holding the test pipeline under
            ``cfg.data.test.pipeline``.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images.

    Returns:
        If imgs is a list or tuple, the same length list type results
        will be returned (an empty list for an empty input), otherwise
        return the detection results directly.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    import copy

    is_batch = isinstance(imgs, (list, tuple))
    if not is_batch:
        imgs = [imgs]
    elif len(imgs) == 0:
        # Nothing to run; honours the "same length list" contract instead of
        # failing on ``imgs[0]`` below.
        return []

    cfg = model.cfg
    first_param = next(model.parameters())
    device = first_param.device  # model device

    # Work on a private copy of the pipeline config. Editing the loader type
    # through ``model.cfg`` itself would leak into later calls, e.g. a call
    # with file names after a call with arrays.
    pipeline = copy.deepcopy(cfg.data.test.pipeline)
    if isinstance(imgs[0], np.ndarray):
        # set loading pipeline type
        pipeline[0].type = 'LoadImageFromWebcam'
    pipeline = replace_ImageToTensor(pipeline)
    test_pipeline = Compose(pipeline)

    datas = []
    for img in imgs:
        # prepare data
        if isinstance(img, np.ndarray):
            # directly add img
            data = dict(img=img)
        else:
            # add information into dict
            data = dict(img_info=dict(filename=img), img_prefix=None)
        # build the data pipeline
        datas.append(test_pipeline(data))

    data = collate(datas, samples_per_gpu=len(imgs))
    # just get the actual data from DataContainer
    data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
    data['img'] = [img.data[0] for img in data['img']]
    if first_param.is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    # forward the model
    with torch.no_grad():
        results = model(return_loss=False, rescale=True, **data)

    return results if is_batch else results[0]
async def async_inference_detector(model, imgs):
    """Async inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector. It must carry the ``cfg``
            attribute holding the test pipeline under
            ``cfg.data.test.pipeline``.
        imgs (str | ndarray | list[str/ndarray] | tuple[str/ndarray]):
            Either image files or loaded images. A single image is treated
            as a batch of one.

    Returns:
        Awaitable detection results. The value of
        ``model.aforward_test`` is returned as-is; an empty list or tuple
        input yields an empty list without running the model.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    import copy

    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]
    elif len(imgs) == 0:
        # Nothing to run; avoids failing on ``imgs[0]`` below.
        return []

    cfg = model.cfg
    first_param = next(model.parameters())
    device = first_param.device  # model device

    # Work on a private copy of the pipeline config. Editing the loader type
    # through ``model.cfg`` itself would leak into later (possibly
    # concurrent) calls, e.g. a call with file names after one with arrays.
    pipeline = copy.deepcopy(cfg.data.test.pipeline)
    if isinstance(imgs[0], np.ndarray):
        # set loading pipeline type
        pipeline[0].type = 'LoadImageFromWebcam'
    pipeline = replace_ImageToTensor(pipeline)
    test_pipeline = Compose(pipeline)

    datas = []
    for img in imgs:
        # prepare data
        if isinstance(img, np.ndarray):
            # directly add img
            data = dict(img=img)
        else:
            # add information into dict
            data = dict(img_info=dict(filename=img), img_prefix=None)
        # build the data pipeline
        datas.append(test_pipeline(data))

    data = collate(datas, samples_per_gpu=len(imgs))
    # just get the actual data from DataContainer
    data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
    data['img'] = [img.data[0] for img in data['img']]
    if first_param.is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    # We don't restore `torch.is_grad_enabled()` value during concurrent
    # inference since execution can overlap
    torch.set_grad_enabled(False)
    results = await model.aforward_test(rescale=True, **data)
    return results
def show_result_pyplot(model,
                       img,
                       result,
                       score_thr=0.3,
                       title='result',
                       wait_time=0):
    """Display detection results drawn on an image.

    Args:
        model (nn.Module): The loaded detector. If it exposes a ``module``
            attribute, that inner object is used instead.
        img (str or np.ndarray): Image filename or loaded image.
        result (tuple[list] or list): The detection result, can be either
            (bbox, segm) or just bbox.
        score_thr (float): The threshold to visualize the bboxes and masks.
        title (str): Title of the pyplot figure; forwarded as ``win_name``.
        wait_time (float): Value of waitKey param. Default: 0.
    """
    # Unwrap anything that carries the real detector under ``.module``.
    detector = getattr(model, 'module', model)
    # Boxes and labels share one colour.
    highlight = (72, 101, 241)
    display_options = dict(
        score_thr=score_thr,
        show=True,
        wait_time=wait_time,
        win_name=title,
        bbox_color=highlight,
        text_color=highlight)
    detector.show_result(img, result, **display_options)