Shortcuts

教程 10: 权重初始化

在训练过程中,适当的初始化策略有利于加快训练速度或获得更⾼的性能。 MMCV 提供了一些常⽤的初始化模块的⽅法,如 nn.Conv2d。 MMdetection 中的模型初始化主要使⽤ init_cfg。⽤⼾可以通过以下两个步骤来初始化模型:

  1. model_cfg 中为模型或其组件定义 init_cfg,但⼦组件的 init_cfg 优先级更⾼,会覆盖⽗模块的 init_cfg

  2. 像往常一样构建模型,然后显式调⽤ model.init_weights() ⽅法,此时模型参数将会被按照配置文件写法进行初始化。

MMdetection 初始化工作流的高层 API 调用流程是:

model_cfg(init_cfg) -> build_from_cfg -> model -> init_weight() -> initialize(self, self.init_cfg) -> children’s init_weight()

描述

它的数据类型是 dict 或者 list[dict],包含了下列键值:

  • type (str),包含 INTIALIZERS 中的初始化器名称,后面跟着初始化器的参数。

  • layer(str 或 list[str]),包含 Pytorch 或 MMCV 中基本层的名称,以及将被初始化的可学习参数,例如 'Conv2d''DeformConv2d'

  • override (dict 或 list[dict]),包含不继承⾃ BaseModule 且其初始化配置与 layer 键中的其他层不同的⼦模块。 type 中定义的初始化器将适⽤于 layer 中定义的所有层,因此如果⼦模块不是 BaseModule 的派⽣类但可以与 layer 中的层相同的⽅式初始化,则不需要使⽤ overrideoverride 包含了:

    • type 后跟初始化器的参数;

    • name 用以指⽰将被初始化的⼦模块。

初始化参数

mmcv.runner.BaseModulemmdet.models 继承一个新模型。这里我们用 FooModel 来举个例子。

import torch.nn as nn
from mmcv.runner import BaseModule

class FooModel(BaseModule)
	def __init__(self,
                 arg1,
                 arg2,
                 init_cfg=None):
    	super(FooModel, self).__init__(init_cfg)
		...
  • 直接在代码中使⽤ init_cfg 初始化模型

    import torch.nn as nn
    from mmcv.runner import BaseModule
    # or directly inherit mmdet models
    
    class FooModel(BaseModule)
    	def __init__(self,
                    arg1,
                    arg2,
                    init_cfg=XXX):
      		super(FooModel, self).__init__(init_cfg)
      	    ...
    
  • mmcv.Sequentialmmcv.ModuleList 代码中直接使⽤ init_cfg 初始化模型

    from mmcv.runner import BaseModule, ModuleList
    
    class FooModel(BaseModule)
    	def __init__(self,
                  	arg1,
                  	arg2,
                  	init_cfg=None):
      		super(FooModel, self).__init__(init_cfg)
          	...
          	self.conv1 = ModuleList(init_cfg=XXX)
    
  • 使⽤配置⽂件中的 init_cfg 初始化模型

    model = dict(
    	...
      	model = dict(
          	type='FooModel',
          	arg1=XXX,
          	arg2=XXX,
          	init_cfg=XXX),
              ...
    

init_cfg 的使用

  1. layer 键初始化模型

    如果我们只定义了 layer, 它只会在 layer 键中初始化网络层。

    注意: layer 键对应的值是 Pytorch 的带有 weights 和 bias 属性的类名(因此不⽀持 MultiheadAttention 层)。

  • 定义⽤于初始化具有相同配置的模块的 layer 键。

    init_cfg = dict(type='Constant', layer=['Conv1d', 'Conv2d', 'Linear'], val=1)
    # ⽤相同的配置初始化整个模块
    
  • 定义⽤于初始化具有不同配置的层的 layer 键。

    init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
                dict(type='Constant', layer='Conv2d', val=2),
                dict(type='Constant', layer='Linear', val=3)]
    # nn.Conv1d 将被初始化为 dict(type='Constant', val=1)
    # nn.Conv2d 将被初始化为 dict(type='Constant', val=2)
    # nn.Linear 将被初始化为 dict(type='Constant', val=3)
    
  1. 使⽤ override 键初始化模型

  • 当使⽤属性名初始化某些特定部分时,我们可以使⽤ override 键, override 中的值将忽略 init_cfg 中的值。

    # layers:
    # self.feat = nn.Conv1d(3, 1, 3)
    # self.reg = nn.Conv2d(3, 3, 3)
    # self.cls = nn.Linear(1,2)
    
    init_cfg = dict(type='Constant',
                    layer=['Conv1d','Conv2d'], val=1, bias=2,
                    override=dict(type='Constant', name='reg', val=3, bias=4))
    # self.feat and self.cls 将被初始化为 dict(type='Constant', val=1, bias=2)
    # 叫 'reg' 的模块将被初始化为 dict(type='Constant', val=3, bias=4)
    
  • 如果 init_cfg 中的 layer 为 None,则只会初始化 override 中有 name 的⼦模块,⽽ override 中的 type 和其他参数可以省略。

    # layers:
    # self.feat = nn.Conv1d(3, 1, 3)
    # self.reg = nn.Conv2d(3, 3, 3)
    # self.cls = nn.Linear(1,2)
    
    init_cfg = dict(type='Constant', val=1, bias=2, 	override=dict(name='reg'))
    
    # self.feat and self.cls 将被 Pytorch 初始化
    # 叫 'reg' 的模块将被 dict(type='Constant', val=1, bias=2) 初始化
    
  • 如果我们不定义 layeroverride 键,它不会初始化任何东西。

  • 无效的使用

    # override 没有 name 键的话是无效的
    init_cfg = dict(type='Constant', layer=['Conv1d','Conv2d'], val=1, bias=2,
                	override=dict(type='Constant', val=3, bias=4))
    
    # override 有 name 键和其他参数但是没有 type 键也是无效的
    init_cfg = dict(type='Constant', layer=['Conv1d','Conv2d'], val=1, bias=2,
                    override=dict(name='reg', val=3, bias=4))
    
  1. 使⽤预训练模型初始化模型

    init_cfg = dict(type='Pretrained',
                 checkpoint='torchvision://resnet50')
    

更多细节可以参考 MMCV 的文档和 MMCV PR #780

Read the Docs v: v2.24.1
Versions
latest
stable
v2.24.1
v2.24.0
v2.23.0
v2.22.0
v2.21.0
v2.20.0
v2.19.1
v2.19.0
v2.18.1
v2.18.0
v2.17.0
v2.16.0
v2.15.1
v2.15.0
v2.14.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.