Shortcuts

自定义模型

我们简单地把模型的各个组件分为五类:

  • 主干网络 (backbone):通常是一个用来提取特征图 (feature map) 的全卷积网络 (FCN network),例如:ResNet, MobileNet。

  • Neck:主干网络和 Head 之间的连接部分,例如:FPN, PAFPN。

  • Head:用于具体任务的组件,例如:边界框预测和掩码预测。

  • 区域提取器 (roi extractor):从特征图中提取 RoI 特征,例如:RoI Align。

  • 损失 (loss):在 Head 组件中用于计算损失的部分,例如:FocalLoss, L1Loss, GHMLoss.

开发新的组件

添加一个新的主干网络

这里,我们以 MobileNet 为例来展示如何开发新组件。

1. 定义一个新的主干网络(以 MobileNet 为例)

新建一个文件 mmdet/models/backbones/mobilenet.py

import torch.nn as nn

from mmdet.registry import MODELS


@MODELS.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tuple
        pass

2. 导入该模块

你可以添加下述代码到 mmdet/models/backbones/__init__.py

from .mobilenet import MobileNet

或添加:

custom_imports = dict(
    imports=['mmdet.models.backbones.mobilenet'],
    allow_failed_imports=False)

到配置文件以避免原始代码被修改。

3. 在你的配置文件中使用该主干网络

model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...

添加新的 Neck

1. 定义一个 Neck(以 PAFPN 为例)

新建一个文件 mmdet/models/necks/pafpn.py

import torch.nn as nn

from mmdet.registry import MODELS


@MODELS.register_module()
class PAFPN(nn.Module):

    def __init__(self,
                in_channels,
                out_channels,
                num_outs,
                start_level=0,
                end_level=-1,
                add_extra_convs=False):
        pass

    def forward(self, inputs):
        # implementation is ignored
        pass

2. 导入该模块

你可以添加下述代码到 mmdet/models/necks/__init__.py

from .pafpn import PAFPN

或添加:

custom_imports = dict(
    imports=['mmdet.models.necks.pafpn'],
    allow_failed_imports=False)

到配置文件以避免原始代码被修改。

3. 修改配置文件

neck=dict(
    type='PAFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)

添加新的损失

假设你想添加一个新的损失 MyLoss 用于边界框回归。 为了添加一个新的损失函数,用户需要在 mmdet/models/losses/my_loss.py 中实现。 装饰器 weighted_loss 可以使损失每个部分加权。

import torch
import torch.nn as nn

from mmdet.registry import LOSSES
from .utils import weighted_loss


@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss

@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox

然后,用户需要把它加到 mmdet/models/losses/__init__.py

from .my_loss import MyLoss, my_loss

或者,你可以添加:

custom_imports=dict(
    imports=['mmdet.models.losses.my_loss'])

到配置文件来实现相同的目的。

如使用,请修改 loss_xxx 字段。 因为 MyLoss 是用于回归的,你需要在 Head 中修改 loss_xxx 字段。

loss_bbox=dict(type='MyLoss', loss_weight=1.0))
Read the Docs v: stable
Versions
latest
stable
3.x
v3.3.0
v3.2.0
v3.1.0
v3.0.0
v2.28.2
v2.28.1
v2.28.0
v2.27.0
v2.26.0
v2.25.3
v2.25.2
v2.25.1
v2.25.0
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
dev-3.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.