
Source code for mmdet.apis.inference

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import numpy as np
import torch
from mmcv.ops import RoIPool
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint

from mmdet.core import get_classes
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector


def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the
            model will not load any weights.
        device (str, optional): The device where the model is placed.
            Default: 'cuda:0'.
        cfg_options (dict): Options to override some settings in the used
            config.

    Returns:
        nn.Module: The constructed detector.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    config.model.pretrained = None
    config.model.train_cfg = None
    model = build_detector(config.model, test_cfg=config.get('test_cfg'))
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        if 'CLASSES' in checkpoint.get('meta', {}):
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
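A minimal usage sketch (not part of the module source), assuming an MMDetection checkout where the Faster R-CNN config and a downloaded checkpoint exist at the paths shown; substitute your own files.

# --- usage example (illustrative, not part of mmdet.apis.inference) ---
from mmdet.apis import init_detector

config_file = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'  # assumed path
checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco.pth'  # assumed path
# Build the detector and load weights; CLASSES is read from the checkpoint meta.
model = init_detector(config_file, checkpoint_file, device='cuda:0')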
class LoadImage:
    """Deprecated.

    A simple pipeline to load image.
    """

    def __call__(self, results):
        """Call function to load images into results.

        Args:
            results (dict): A result dict contains the file name
                of the image to be read.
        Returns:
            dict: ``results`` will be returned containing loaded image.
        """
        warnings.simplefilter('once')
        warnings.warn('`LoadImage` is deprecated and will be removed in '
                      'future releases. You may use `LoadImageFromWebcam` '
                      'from `mmdet.datasets.pipelines.` instead.')
        if isinstance(results['img'], str):
            results['filename'] = results['img']
            results['ori_filename'] = results['img']
        else:
            results['filename'] = None
            results['ori_filename'] = None
        img = mmcv.imread(results['img'])
        results['img'] = img
        results['img_fields'] = ['img']
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        return results
def inference_detector(model, imgs):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images.

    Returns:
        If imgs is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.
    """

    if isinstance(imgs, (list, tuple)):
        is_batch = True
    else:
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    if isinstance(imgs[0], np.ndarray):
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    datas = []
    for img in imgs:
        # prepare data
        if isinstance(img, np.ndarray):
            # directly add img
            data = dict(img=img)
        else:
            # add information into dict
            data = dict(img_info=dict(filename=img), img_prefix=None)
        # build the data pipeline
        data = test_pipeline(data)
        datas.append(data)

    data = collate(datas, samples_per_gpu=len(imgs))
    # just get the actual data from DataContainer
    data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
    data['img'] = [img.data[0] for img in data['img']]
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    # forward the model
    with torch.no_grad():
        results = model(return_loss=False, rescale=True, **data)

    if not is_batch:
        return results[0]
    else:
        return results
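A minimal sketch of calling ``inference_detector``, assuming ``model`` was built with ``init_detector`` above; the image path is a placeholder.

# --- usage example (illustrative) ---
from mmdet.apis import inference_detector

# A single image path (or a loaded ndarray) returns one result directly.
result = inference_detector(model, 'demo/demo.jpg')  # assumed image path
# A list or tuple of images returns a list of results of the same length.
results = inference_detector(model, ['demo/demo.jpg', 'demo/demo.jpg'])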
async def async_inference_detector(model, imgs):
    """Async inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images.

    Returns:
        Awaitable detection results.
    """
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]

    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    if isinstance(imgs[0], np.ndarray):
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    datas = []
    for img in imgs:
        # prepare data
        if isinstance(img, np.ndarray):
            # directly add img
            data = dict(img=img)
        else:
            # add information into dict
            data = dict(img_info=dict(filename=img), img_prefix=None)
        # build the data pipeline
        data = test_pipeline(data)
        datas.append(data)

    data = collate(datas, samples_per_gpu=len(imgs))
    # just get the actual data from DataContainer
    data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
    data['img'] = [img.data[0] for img in data['img']]
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    # We don't restore `torch.is_grad_enabled()` value during concurrent
    # inference since execution can overlap
    torch.set_grad_enabled(False)
    results = await model.aforward_test(rescale=True, **data)
    return results
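A minimal async sketch, assuming a CUDA model whose detector implements ``aforward_test`` (i.e. supports async test mode); the image path is a placeholder and ``asyncio`` drives the awaitable.

# --- usage example (illustrative) ---
import asyncio

from mmdet.apis import async_inference_detector


async def detect(model, img):
    # Await the detection results for a single image.
    return await async_inference_detector(model, img)

results = asyncio.run(detect(model, 'demo/demo.jpg'))  # assumed image path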
def show_result_pyplot(model,
                       img,
                       result,
                       score_thr=0.3,
                       title='result',
                       wait_time=0,
                       palette=None):
    """Visualize the detection results on the image.

    Args:
        model (nn.Module): The loaded detector.
        img (str or np.ndarray): Image filename or loaded image.
        result (tuple[list] or list): The detection result, can be either
            (bbox, segm) or just bbox.
        score_thr (float): The threshold to visualize the bboxes and masks.
        title (str): Title of the pyplot figure.
        wait_time (float): Value of waitKey param. Default: 0.
        palette (str or tuple, optional): Color used for the bboxes, texts
            and masks. Default: None.
    """
    if hasattr(model, 'module'):
        model = model.module
    model.show_result(
        img,
        result,
        score_thr=score_thr,
        show=True,
        wait_time=wait_time,
        win_name=title,
        bbox_color=palette,
        text_color=palette,
        mask_color=palette)
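A minimal visualization sketch, assuming ``model`` and ``result`` from the calls above; the score threshold and title are illustrative values.

# --- usage example (illustrative) ---
from mmdet.apis import show_result_pyplot

# Draw bboxes/masks scoring above 0.3 on the image and show the figure.
show_result_pyplot(model, 'demo/demo.jpg', result, score_thr=0.3, title='demo')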