Shortcuts

模型部署

MMDeploy 是 OpenMMLab 的部署仓库,负责包括 MMPretrain、MMDetection 等在内的各算法库的部署工作。 你可以从这里获取 MMDeploy 对 MMDetection 部署支持的最新文档。

本文的结构如下:

安装

请参考此处安装 mmdet。然后,按照说明安装 mmdeploy。

注解

如果安装的是 mmdeploy 预编译包,那么也请通过 ‘git clone https://github.com/open-mmlab/mmdeploy.git –depth=1’ 下载 mmdeploy 源码。因为它包含了部署时要用到的配置文件

模型转换

假设在安装步骤中,mmdetection 和 mmdeploy 代码库在同级目录下,并且当前的工作目录为 mmdetection 的根目录,那么以 Faster R-CNN 模型为例,你可以从此处下载对应的 checkpoint,并使用以下代码将之转换为 onnx 模型:

from mmdeploy.apis import torch2onnx
from mmdeploy.backend.sdk.export_info import export2SDK

img = 'demo/demo.jpg'
work_dir = 'mmdeploy_models/mmdet/onnx'
save_file = 'end2end.onnx'
deploy_cfg = '../mmdeploy/configs/mmdet/detection/detection_onnxruntime_dynamic.py'
model_cfg = 'configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py'
model_checkpoint = 'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
device = 'cpu'

# 1. convert model to onnx
torch2onnx(img, work_dir, save_file, deploy_cfg, model_cfg,
           model_checkpoint, device)

# 2. extract pipeline info for inference by MMDeploy SDK
export2SDK(deploy_cfg, model_cfg, work_dir, pth=model_checkpoint,
           device=device)

转换的关键之一是使用正确的配置文件。项目中已内置了各后端部署配置文件。 文件的命名模式是:

{task}/{task}_{backend}-{precision}_{static | dynamic}_{shape}.py

其中:

  • {task}: mmdet 中的任务

    mmdet 任务有2种:物体检测(detection)、实例分割(instance-seg)。例如,RetinaNetFaster R-CNNDETR等属于前者。Mask R-CNNSOLO等属于后者。更多模型-任务的划分,请参考章节模型支持列表

    请务必使用 detection/detection_*.py 转换检测模型,使用 instance-seg/instance-seg_*.py 转换实例分割模型。

  • {backend}: 推理后端名称。比如,onnxruntime、tensorrt、pplnn、ncnn、openvino、coreml 等等

  • {precision}: 推理精度。比如,fp16、int8。不填表示 fp32

  • {static | dynamic}: 动态、静态 shape

  • {shape}: 模型输入的 shape 或者 shape 范围

在上例中,你也可以把 Faster R-CNN 转为其他后端模型。比如使用detection_tensorrt-fp16_dynamic-320x320-1344x1344.py,把模型转为 tensorrt-fp16 模型。

小技巧

当转 tensorrt 模型时, –device 需要被设置为 “cuda”

模型规范

在使用转换后的模型进行推理之前,有必要了解转换结果的结构。 它存放在 --work-dir 指定的路路径下。

上例中的mmdeploy_models/mmdet/onnx,结构如下:

mmdeploy_models/mmdet/onnx
├── deploy.json
├── detail.json
├── end2end.onnx
└── pipeline.json

重要的是:

  • end2end.onnx: 推理引擎文件。可用 ONNX Runtime 推理

  • xxx.json: mmdeploy SDK 推理所需的 meta 信息

整个文件夹被定义为mmdeploy SDK model。换言之,mmdeploy SDK model既包括推理引擎,也包括推理 meta 信息。

模型推理

后端模型推理

以上述模型转换后的 end2end.onnx 为例,你可以使用如下代码进行推理:

from mmdeploy.apis.utils import build_task_processor
from mmdeploy.utils import get_input_shape, load_config
import torch

deploy_cfg = '../mmdeploy/configs/mmdet/detection/detection_onnxruntime_dynamic.py'
model_cfg = 'configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py'
device = 'cpu'
backend_model = ['mmdeploy_models/mmdet/onnx/end2end.onnx']
image = 'demo/demo.jpg'

# read deploy_cfg and model_cfg
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)

# build task and backend model
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
model = task_processor.build_backend_model(backend_model)

# process input image
input_shape = get_input_shape(deploy_cfg)
model_inputs, _ = task_processor.create_input(image, input_shape)

# do model inference
with torch.no_grad():
    result = model.test_step(model_inputs)

# visualize results
task_processor.visualize(
    image=image,
    model=model,
    result=result[0],
    window_name='visualize',
    output_file='output_detection.png')

SDK 模型推理

你也可以参考如下代码,对 SDK model 进行推理:

from mmdeploy_python import Detector
import cv2

img = cv2.imread('demo/demo.jpg')

# create a detector
detector = Detector(model_path='mmdeploy_models/mmdet/onnx',
                    device_name='cpu', device_id=0)
# perform inference
bboxes, labels, masks = detector(img)

# visualize inference result
indices = [i for i in range(len(bboxes))]
for index, bbox, label_id in zip(indices, bboxes, labels):
    [left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]
    if score < 0.3:
        continue

    cv2.rectangle(img, (left, top), (right, bottom), (0, 255, 0))

cv2.imwrite('output_detection.png', img)

除了python API,mmdeploy SDK 还提供了诸如 C、C++、C#、Java等多语言接口。 你可以参考样例学习其他语言接口的使用方法。

模型支持列表

请参考这里

Read the Docs v: stable
Versions
latest
stable
3.x
v3.3.0
v3.2.0
v3.1.0
v3.0.0
v2.28.2
v2.28.1
v2.28.0
v2.27.0
v2.26.0
v2.25.3
v2.25.2
v2.25.1
v2.25.0
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
dev-3.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.