Shortcuts

教程 4: 自定义模型

我们简单地把模型的各个组件分为五类:

  • 主干网络 (backbone):通常是一个用来提取特征图 (feature map) 的全卷积网络 (FCN network),例如:ResNet, MobileNet。

  • Neck:主干网络和 Head 之间的连接部分,例如:FPN, PAFPN。

  • Head:用于具体任务的组件,例如:边界框预测和掩码预测。

  • 区域提取器 (roi extractor):从特征图中提取 RoI 特征,例如:RoI Align。

  • 损失 (loss):在 Head 组件中用于计算损失的部分,例如:FocalLoss, L1Loss, GHMLoss.

开发新的组件

添加一个新的主干网络

这里,我们以 MobileNet 为例来展示如何开发新组件。

1. 定义一个新的主干网络(以 MobileNet 为例)

新建一个文件 mmdet/models/backbones/mobilenet.py

import torch.nn as nn

from ..builder import BACKBONES


@BACKBONES.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tuple
        pass

2. 导入该模块

你可以添加下述代码到 mmdet/models/backbones/__init__.py

from .mobilenet import MobileNet

或添加:

custom_imports = dict(
    imports=['mmdet.models.backbones.mobilenet'],
    allow_failed_imports=False)

到配置文件以避免原始代码被修改。

3. 在你的配置文件中使用该主干网络

model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...

添加新的 Neck

1. 定义一个 Neck(以 PAFPN 为例)

新建一个文件 mmdet/models/necks/pafpn.py

from ..builder import NECKS

@NECKS.register_module()
class PAFPN(nn.Module):

    def __init__(self,
                in_channels,
                out_channels,
                num_outs,
                start_level=0,
                end_level=-1,
                add_extra_convs=False):
        pass

    def forward(self, inputs):
        # implementation is ignored
        pass

2. 导入该模块

你可以添加下述代码到 mmdet/models/necks/__init__.py

from .pafpn import PAFPN

或添加:

custom_imports = dict(
    imports=['mmdet.models.necks.pafpn.py'],
    allow_failed_imports=False)

到配置文件以避免原始代码被修改。

3. 修改配置文件

neck=dict(
    type='PAFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)

添加新的损失

假设你想添加一个新的损失 MyLoss 用于边界框回归。 为了添加一个新的损失函数,用户需要在 mmdet/models/losses/my_loss.py 中实现。 装饰器 weighted_loss 可以使损失每个部分加权。

import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss

@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss

@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox

然后,用户需要把它加到 mmdet/models/losses/__init__.py

from .my_loss import MyLoss, my_loss

或者,你可以添加:

custom_imports=dict(
    imports=['mmdet.models.losses.my_loss'])

到配置文件来实现相同的目的。

如使用,请修改 loss_xxx 字段。 因为 MyLoss 是用于回归的,你需要在 Head 中修改 loss_xxx 字段。

loss_bbox=dict(type='MyLoss', loss_weight=1.0))
Read the Docs v: v2.28.2
Versions
latest
stable
3.x
v2.28.2
v2.28.1
v2.28.0
v2.27.0
v2.26.0
v2.25.3
v2.25.2
v2.25.1
v2.25.0
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
v2.13.0
dev-3.x
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.