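Note
You are reading the documentation for MMClassification 0.x. MMClassification 0.x will be switched to a secondary branch at the end of 2022. We recommend upgrading to MMClassification 1.0 to get more new features and functionality. See the MMClassification 1.0 installation tutorial, migration tutorial, and changelog.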
mmcls.models.RegNet
- class mmcls.models.RegNet(arch, in_channels=3, stem_channels=32, base_channels=32, strides=(2, 2, 2, 2), dilations=(1, 1, 1, 1), out_indices=(3,), style='pytorch', deep_stem=False, avg_down=False, frozen_stages=-1, conv_cfg=None, norm_cfg={'requires_grad': True, 'type': 'BN'}, norm_eval=False, with_cp=False, zero_init_residual=True, init_cfg=None)[source]
RegNet backbone.
More details can be found in the paper "Designing Network Design Spaces".
- Parameters
arch (dict | str) – The parameters of the RegNet. Either the name of a predefined architecture (for example 'regnetx_4.0gf') or a dict with the following keys:
- w0 (int): initial width
- wa (float): slope of width
- wm (float): quantization parameter to quantize the width
- depth (int): depth of the backbone
- group_w (int): width of group
- bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck.
strides (Sequence[int]) – Strides of the first block of each stage. Default: (2, 2, 2, 2).
base_channels (int) – Base channels after stem layer. Default: 32.
in_channels (int) – Number of input image channels. Default: 3.
dilations (Sequence[int]) – Dilation of each stage. Default: (1, 1, 1, 1).
out_indices (Sequence[int]) – Indices of the stages whose outputs are returned. Default: (3,).
style (str) – pytorch or caffe. If set to “pytorch”, the stride-two layer is the 3x3 conv layer, otherwise the stride-two layer is the first 1x1 conv layer. Default: “pytorch”.
frozen_stages (int) – Stages to be frozen (all param fixed). -1 means not freezing any parameters. Default: -1.
norm_cfg (dict) – Dictionary to construct and configure the norm layer. Default: dict(type='BN', requires_grad=True).
norm_eval (bool) – Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. Default: False.
with_cp (bool) – Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False.
zero_init_residual (bool) – Whether to zero-initialize the last norm layer in each residual block so that the block initially behaves as an identity mapping. Default: True.
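To illustrate how the fields of arch interact, the sketch below derives per-stage widths and depths from w0, wa, wm, depth, group_w and bot_mul, following the quantized linear width rule described in the RegNet paper. The helper name regnet_stage_widths and the rounding divisor of 8 are assumptions made for this sketch; it is not presented as the library's own implementation.

```python
import math


def regnet_stage_widths(w0, wa, wm, depth, group_w, bot_mul=1.0, divisor=8):
    """Hypothetical helper: return (stage_widths, stage_depths).

    divisor=8 is an assumption of this sketch, not a documented parameter.
    """
    block_widths = []
    for j in range(depth):
        u = w0 + wa * j  # continuous (linear) width of block j
        k = round(math.log(u / w0) / math.log(wm))  # quantized exponent
        w = w0 * wm ** k  # width snapped to w0 * wm^k
        block_widths.append(int(round(w / divisor) * divisor))

    # Consecutive blocks that share a width form one stage.
    stage_widths, stage_depths = [], []
    for w in block_widths:
        if stage_widths and stage_widths[-1] == w:
            stage_depths[-1] += 1
        else:
            stage_widths.append(w)
            stage_depths.append(1)

    # Make each bottleneck width a multiple of the group width.
    adjusted = []
    for w in stage_widths:
        w_bot = int(round(w * bot_mul))
        g = min(group_w, w_bot)
        w_bot = int(round(w_bot / g) * g)
        adjusted.append(int(w_bot / bot_mul))
    return adjusted, stage_depths


widths, depths = regnet_stage_widths(
    w0=88, wa=26.31, wm=2.25, depth=25, group_w=48, bot_mul=1.0)
print(widths, depths)
```

With the dict-type arch used in the example on this page (w0=88, wa=26.31, wm=2.25, depth=25, group_w=48, bot_mul=1.0), the sketch yields widths of 96, 192, 432 and 1008, which agree with the channel counts printed in that example.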
Example
>>> from mmcls.models import RegNet
>>> import torch
>>> inputs = torch.rand(1, 3, 32, 32)
>>> # use str type 'arch'
>>> # Note that default out_indices is (3,)
>>> regnet_cfg = dict(arch='regnetx_4.0gf')
>>> model = RegNet(**regnet_cfg)
>>> model.eval()
>>> level_outputs = model(inputs)
>>> for level_out in level_outputs:
...     print(tuple(level_out.shape))
...
(1, 1360, 1, 1)
>>> # use dict type 'arch'
>>> arch_cfg = dict(w0=88, wa=26.31, wm=2.25,
...                 group_w=48, depth=25, bot_mul=1.0)
>>> regnet_cfg = dict(arch=arch_cfg, out_indices=(0, 1, 2, 3))
>>> model = RegNet(**regnet_cfg)
>>> model.eval()
>>> level_outputs = model(inputs)
>>> for level_out in level_outputs:
...     print(tuple(level_out.shape))
...
(1, 96, 8, 8)
(1, 192, 4, 4)
(1, 432, 2, 2)
(1, 1008, 1, 1)
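The spatial sizes printed in the example can be checked by hand. The sketch below assumes a stride-2 stem and that each strided layer behaves like a 3x3 convolution with padding 1, so a side length n becomes (n - 1) // s + 1. Both assumptions are inferred from the example output rather than taken from the source code, and stage_output_sizes is a hypothetical helper.

```python
def stage_output_sizes(input_size, strides=(2, 2, 2, 2), stem_stride=2):
    """Hypothetical helper: side length of each stage's output feature map."""

    def strided(n, s):
        # Output size of a 3x3 conv with padding 1 and stride s (assumed).
        return (n - 1) // s + 1

    size = strided(input_size, stem_stride)  # after the stem
    sizes = []
    for s in strides:
        size = strided(size, s)  # the first block of each stage downsamples
        sizes.append(size)
    return sizes


print(stage_output_sizes(32))
```

For the 32x32 input used in the example this gives 8, 4, 2 and 1, matching the printed shapes; a 224x224 input would give 56, 28, 14 and 7 under the same assumptions.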