Shortcuts

Note

You are reading the documentation for MMClassification 0.x, which will be deprecated at the end of 2022. We recommend you upgrade to MMClassification 1.0 to enjoy fruitful new features and better performance brought by OpenMMLab 2.0. Check the installation tutorial, migration tutorial and changelog for more details.

Source code for mmcls.models.backbones.van

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.bricks.transformer import PatchEmbed
from mmcv.runner import BaseModule, ModuleList
from mmcv.utils.parrots_wrapper import _BatchNorm

from ..builder import BACKBONES
from .base_backbone import BaseBackbone


class MixFFN(BaseModule):
    """Convolutional feed-forward network used inside each VAN block.

    Adapted from the MixFFN in ``mmdetection/mmdet/models/backbones/pvt.py``.
    Compared with a plain FFN it:

        1. replaces the Linear layers with 1x1 convolutions, and
        2. inserts a 3x3 depth-wise convolution between them to encode
           positional information.

    Args:
        embed_dims (int): Number of input and output channels.
        feedforward_channels (int): Number of hidden channels.
        act_cfg (dict, optional): Config of the activation applied after the
            depth-wise convolution. Default: dict(type='GELU').
        ffn_drop (float, optional): Dropout probability, applied both after
            the activation and after the output projection. Default 0.0.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg

        hidden = feedforward_channels
        # Creation order matters: it fixes parameter registration order and
        # the order in which random initialisation consumes the RNG.
        self.fc1 = Conv2d(embed_dims, hidden, kernel_size=1)
        self.dwconv = Conv2d(
            hidden,
            hidden,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=hidden)  # groups == channels -> depth-wise
        self.act = build_activation_layer(act_cfg)
        self.fc2 = Conv2d(hidden, embed_dims, kernel_size=1)
        self.drop = nn.Dropout(ffn_drop)

    def forward(self, x):
        # expand -> depth-wise conv -> activation -> dropout
        hidden = self.drop(self.act(self.dwconv(self.fc1(x))))
        # project back to embed_dims, then dropout again
        return self.drop(self.fc2(hidden))


class LKA(BaseModule):
    """Large Kernel Attention (LKA) of VAN.

    .. code:: text

        DW_conv (depth-wise convolution)
                        |
                        |
        DW_D_conv (depth-wise dilation convolution)
                        |
                        |
        Transition Convolution (1x1 convolution)

    The output of this chain is used as an attention map and multiplied
    element-wise with the input. Every convolution preserves the number of
    channels and (with stride 1 and the chosen padding) the spatial size, so
    the output has the same shape as the input.

    Args:
        embed_dims (int): Number of input channels.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, embed_dims, init_cfg=None):
        super(LKA, self).__init__(init_cfg=init_cfg)

        # a spatial local convolution (depth-wise convolution);
        # a 5x5 kernel with padding 2 keeps the spatial size unchanged.
        self.DW_conv = Conv2d(
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=5,
            padding=2,
            groups=embed_dims)

        # a spatial long-range convolution (depth-wise dilation convolution);
        # a 7x7 kernel with dilation 3 covers 3 * (7 - 1) + 1 = 19 pixels per
        # side, so padding 9 keeps the spatial size unchanged.
        self.DW_D_conv = Conv2d(
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=7,
            stride=1,
            padding=9,
            groups=embed_dims,
            dilation=3)

        # channel mixing (1x1 convolution)
        self.conv1 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

    def forward(self, x):
        # None of the convolutions modify ``x`` in place, so the input is
        # still intact for the gating product; no defensive copy is needed.
        attn = self.DW_conv(x)
        attn = self.DW_D_conv(attn)
        attn = self.conv1(attn)

        return x * attn


class SpatialAttention(BaseModule):
    """Basic attention module in VANBlock.

    Computes ``proj_2(LKA(act(proj_1(x)))) + x``, i.e. a residual branch
    built around :class:`LKA`. Channel count and spatial size are preserved.

    Args:
        embed_dims (int): Number of input channels.
        act_cfg (dict, optional): The activation config applied after the
            first 1x1 projection. Default: dict(type='GELU').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, embed_dims, act_cfg=dict(type='GELU'), init_cfg=None):
        super(SpatialAttention, self).__init__(init_cfg=init_cfg)

        self.proj_1 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.activation = build_activation_layer(act_cfg)
        self.spatial_gating_unit = LKA(embed_dims)
        self.proj_2 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

    def forward(self, x):
        # ``x`` itself is never modified in place below (every layer writes a
        # new tensor), so it can be kept as the residual without cloning.
        shortcut = x
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        return x + shortcut


class VANBlock(BaseModule):
    """A block of VAN.

    Two pre-norm residual branches, each optionally scaled per channel by a
    learnable "layer scale" and regularised with stochastic depth::

        x = x + drop_path(layer_scale_1 * attn(norm1(x)))
        x = x + drop_path(layer_scale_2 * mlp(norm2(x)))

    Args:
        embed_dims (int): Number of input channels.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        drop_rate (float): Dropout rate inside the feedforward network.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Stochastic depth is
            disabled when it is 0. Defaults to 0.
        act_cfg (dict, optional): The activation config for both the
            attention module and the FFN. Default: dict(type='GELU').
        norm_cfg (dict): Config of the two normalization layers.
            Defaults to ``dict(type='BN', eps=1e-5)``.
        layer_scale_init_value (float): Init value for Layer Scale. Layer
            Scale is disabled when it is not positive. Defaults to 1e-2.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 ffn_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='BN', eps=1e-5),
                 layer_scale_init_value=1e-2,
                 init_cfg=None):
        super(VANBlock, self).__init__(init_cfg=init_cfg)
        self.out_channels = embed_dims

        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = SpatialAttention(embed_dims, act_cfg=act_cfg)
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        mlp_hidden_dim = int(embed_dims * ffn_ratio)
        self.mlp = MixFFN(
            embed_dims=embed_dims,
            feedforward_channels=mlp_hidden_dim,
            act_cfg=act_cfg,
            ffn_drop=drop_rate)

        # Per-channel scales of shape (embed_dims, ); None disables scaling.
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones(embed_dims),
            requires_grad=True) if layer_scale_init_value > 0 else None
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones(embed_dims),
            requires_grad=True) if layer_scale_init_value > 0 else None

    def forward(self, x):
        # The (C, ) scales are viewed as (C, 1, 1) so they broadcast over the
        # trailing two (spatial) dimensions of a channel-first input.
        identity = x
        x = self.norm1(x)
        x = self.attn(x)
        if self.layer_scale_1 is not None:
            x = self.layer_scale_1.view(-1, 1, 1) * x
        x = identity + self.drop_path(x)

        identity = x
        x = self.norm2(x)
        x = self.mlp(x)
        if self.layer_scale_2 is not None:
            x = self.layer_scale_2.view(-1, 1, 1) * x
        x = identity + self.drop_path(x)

        return x


class VANPatchEmbed(PatchEmbed):
    """Image to Patch Embedding of VAN.

    The differences between VANPatchEmbed & PatchEmbed:
        1. Use BN (the default ``norm_cfg``).
        2. Do not use 'flatten' and 'transpose': the output keeps the
           channel-first (B, C, H, W) layout.
    """

    def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs):
        super(VANPatchEmbed, self).__init__(*args, norm_cfg=norm_cfg, **kwargs)

    def forward(self, x):
        """
        Args:
            x (Tensor): Has shape (B, C, H, W). In most case, C is 3.

        Returns:
            tuple: Contains the embedded features and their spatial shape.

            - x (Tensor): Has shape (B, embed_dims, out_h, out_w); unlike
              ``PatchEmbed`` it is not flattened.
            - out_size (tuple[int]): Spatial shape of x, arranged as
              (out_h, out_w).
        """

        if self.adaptive_padding is not None:
            x = self.adaptive_padding(x)

        x = self.projection(x)
        # record the spatial size before normalisation
        out_size = (x.shape[2], x.shape[3])
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size


@BACKBONES.register_module()
class VAN(BaseBackbone):
    """Visual Attention Network.

    A PyTorch implement of : `Visual Attention Network
    <https://arxiv.org/pdf/2202.09741v2.pdf>`_

    Inspiration from
    https://github.com/Visual-Attention-Network/VAN-Classification

    Args:
        arch (str | dict): Visual Attention Network architecture.
            If use string, choose from 'b0', 'b1', 'b2', 'b3' and etc.,
            if use dict, it should have below keys:

            - **embed_dims** (List[int]): The dimensions of embedding.
            - **depths** (List[int]): The number of blocks in each stage.
            - **ffn_ratios** (List[int]): The number of expansion ratio of
              feedforward network hidden layer channels.

            Defaults to 'tiny'.
        patch_sizes (List[int]): The patch (kernel) size of the patch
            embedding of each stage. Defaults to [7, 3, 3, 3].
        in_channels (int): The num of input channels. Defaults to 3.
        drop_rate (float): Dropout rate inside the feedforward networks.
            Defaults to 0.
        drop_path_rate (float): Maximum stochastic depth rate; it increases
            linearly from 0 over all blocks. Defaults to 0.
        out_indices (Sequence[int] | int): Output from which stages.
            Default: ``(3, )``.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        norm_cfg (dict): Config dict for normalization layer for all output
            features. Defaults to ``dict(type='LN')``.
        block_cfgs (dict): Extra keyword arguments passed to every
            :class:`VANBlock`. Defaults to an empty dict.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmcls.models import VAN
        >>> import torch
        >>> model = VAN(arch='b0')
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = model(inputs)
        >>> for out in outputs:
        >>>     print(out.size())
        (1, 256, 7, 7)
    """
    arch_zoo = {
        **dict.fromkeys(['b0', 't', 'tiny'],
                        {'embed_dims': [32, 64, 160, 256],
                         'depths': [3, 3, 5, 2],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['b1', 's', 'small'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [2, 2, 4, 2],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['b2', 'b', 'base'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 3, 12, 3],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['b3', 'l', 'large'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 5, 27, 3],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['b4'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 6, 40, 3],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['b5'],
                        {'embed_dims': [96, 192, 480, 768],
                         'depths': [3, 3, 24, 3],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['b6'],
                        {'embed_dims': [96, 192, 384, 768],
                         'depths': [6, 6, 90, 6],
                         'ffn_ratios': [8, 8, 4, 4]}),
    }  # yapf: disable

    def __init__(self,
                 arch='tiny',
                 patch_sizes=[7, 3, 3, 3],
                 in_channels=3,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 out_indices=(3, ),
                 frozen_stages=-1,
                 norm_eval=False,
                 norm_cfg=dict(type='LN'),
                 block_cfgs=dict(),
                 init_cfg=None):
        super(VAN, self).__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {'embed_dims', 'depths', 'ffn_ratios'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.ffn_ratios = self.arch_settings['ffn_ratios']
        self.num_stages = len(self.depths)
        # ``forward`` tests membership with ``in``, so a bare int must be
        # wrapped in a sequence.
        if isinstance(out_indices, int):
            out_indices = (out_indices, )
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval

        total_depth = sum(self.depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        cur_block_idx = 0
        for i, depth in enumerate(self.depths):
            # stride = k // 2 + 1 (4 for k=7, 2 for k=3); padding = k // 2.
            patch_embed = VANPatchEmbed(
                in_channels=in_channels if i == 0 else self.embed_dims[i - 1],
                input_size=None,
                embed_dims=self.embed_dims[i],
                kernel_size=patch_sizes[i],
                stride=patch_sizes[i] // 2 + 1,
                padding=(patch_sizes[i] // 2, patch_sizes[i] // 2),
                norm_cfg=dict(type='BN'))

            blocks = ModuleList([
                VANBlock(
                    embed_dims=self.embed_dims[i],
                    ffn_ratio=self.ffn_ratios[i],
                    drop_rate=drop_rate,
                    drop_path_rate=dpr[cur_block_idx + j],
                    **block_cfgs) for j in range(depth)
            ])
            cur_block_idx += depth
            norm = build_norm_layer(norm_cfg, self.embed_dims[i])[1]

            self.add_module(f'patch_embed{i + 1}', patch_embed)
            self.add_module(f'blocks{i + 1}', blocks)
            self.add_module(f'norm{i + 1}', norm)

    def train(self, mode=True):
        """Set train/eval mode, then re-apply stage freezing and norm_eval.

        Returns ``self`` like :meth:`torch.nn.Module.train`, so that
        ``model.eval()`` (implemented as ``self.train(False)``) also returns
        the model.
        """
        super(VAN, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self

    def _freeze_stages(self):
        """Put stages ``0..frozen_stages`` in eval mode and stop their grads.

        For each frozen stage, its patch embedding, blocks and norm layer
        are frozen, in that order.
        """
        for i in range(0, self.frozen_stages + 1):
            for prefix in ('patch_embed', 'blocks', 'norm'):
                m = getattr(self, f'{prefix}{i + 1}')
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def forward(self, x):
        """Run all stages and collect the outputs of ``out_indices``.

        Returns:
            tuple[Tensor]: One channel-first feature map per selected stage,
            with ``embed_dims[i]`` channels.
        """
        outs = []
        for i in range(self.num_stages):
            patch_embed = getattr(self, f'patch_embed{i + 1}')
            blocks = getattr(self, f'blocks{i + 1}')
            norm = getattr(self, f'norm{i + 1}')

            x, hw_shape = patch_embed(x)
            for block in blocks:
                x = block(x)
            # (B, C, H, W) -> (B, H*W, C) so the norm acts on the last dim.
            x = x.flatten(2).transpose(1, 2)
            x = norm(x)
            # Use the stage's channel count rather than the last loop
            # variable, which is undefined or stale when a stage has no
            # blocks.
            x = x.reshape(-1, *hw_shape,
                          self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)
Read the Docs v: latest
Versions
master
latest
1.x
dev-1.x
Downloads
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.