You are reading the documentation for MMClassification 0.x, which will soon be deprecated at the end of 2022. We recommend you upgrade to MMClassification 1.0 to enjoy fruitful new features and better performance brought by OpenMMLab 2.0. Check the installation tutorial, migration tutorial and changelog for more details.

Source code for mmcls.models.backbones.seresnet

# Copyright (c) OpenMMLab. All rights reserved.
import torch.utils.checkpoint as cp

from ..builder import BACKBONES
from ..utils.se_layer import SELayer
from .resnet import Bottleneck, ResLayer, ResNet

class SEBottleneck(Bottleneck):
    """SEBottleneck block for SEResNet.

        in_channels (int): The input channels of the SEBottleneck block.
        out_channels (int): The output channel of the SEBottleneck block.
        se_ratio (int): Squeeze ratio in SELayer. Default: 16

    def __init__(self, in_channels, out_channels, se_ratio=16, **kwargs):
        super(SEBottleneck, self).__init__(in_channels, out_channels, **kwargs)
        self.se_layer = SELayer(out_channels, ratio=se_ratio)

    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.norm3(out)

            out = self.se_layer(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
            out = _inner_forward(x)

        out = self.relu(out)

        return out

[docs]@BACKBONES.register_module() class SEResNet(ResNet): """SEResNet backbone. Please refer to the `paper <>`__ for details. Args: depth (int): Network depth, from {50, 101, 152}. se_ratio (int): Squeeze ratio in SELayer. Default: 16. in_channels (int): Number of input image channels. Default: 3. stem_channels (int): Output channels of the stem layer. Default: 64. num_stages (int): Stages of the network. Default: 4. strides (Sequence[int]): Strides of the first block of each stage. Default: ``(1, 2, 2, 2)``. dilations (Sequence[int]): Dilation of each stage. Default: ``(1, 1, 1, 1)``. out_indices (Sequence[int]): Output from which stages. If only one stage is specified, a single tensor (feature map) is returned, otherwise multiple stages are specified, a tuple of tensors will be returned. Default: ``(3, )``. style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two layer is the 3x3 conv layer, otherwise the stride-two layer is the first 1x1 conv layer. deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. Default: False. avg_down (bool): Use AvgPool instead of stride conv when downsampling in the bottleneck. Default: False. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. Default: -1. conv_cfg (dict | None): The config dict for conv layers. Default: None. norm_cfg (dict): The config dict for norm layers. norm_eval (bool): Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. Default: False. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. zero_init_residual (bool): Whether to use zero init for last norm layer in resblocks to let them behave as identity. Default: True. Example: >>> from mmcls.models import SEResNet >>> import torch >>> self = SEResNet(depth=50) >>> self.eval() >>> inputs = torch.rand(1, 3, 224, 224) >>> level_outputs = self.forward(inputs) >>> for level_out in level_outputs: ... print(tuple(level_out.shape)) (1, 64, 56, 56) (1, 128, 28, 28) (1, 256, 14, 14) (1, 512, 7, 7) """ arch_settings = { 50: (SEBottleneck, (3, 4, 6, 3)), 101: (SEBottleneck, (3, 4, 23, 3)), 152: (SEBottleneck, (3, 8, 36, 3)) } def __init__(self, depth, se_ratio=16, **kwargs): if depth not in self.arch_settings: raise KeyError(f'invalid depth {depth} for SEResNet') self.se_ratio = se_ratio super(SEResNet, self).__init__(depth, **kwargs) def make_res_layer(self, **kwargs): return ResLayer(se_ratio=self.se_ratio, **kwargs)
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.