
Source code for mmcls.models.backbones.efficientformer

# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Optional, Sequence

import torch
import torch.nn as nn
from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer,
                             build_norm_layer)
from mmcv.runner import BaseModule, ModuleList, Sequential

from ..builder import BACKBONES
from ..utils import LayerScale
from .base_backbone import BaseBackbone
from .poolformer import Pooling


class AttentionWithBias(BaseModule):
    """Multi-head Attention Module with attention_bias.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads. Defaults to 8.
        key_dim (int): The dimension of q, k. Defaults to 32.
        attn_ratio (float): The dimension of v equals
            ``key_dim * attn_ratio``. Defaults to 4.
        resolution (int): The height and width of attention_bias.
            Defaults to 7.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads=8,
                 key_dim=32,
                 attn_ratio=4.,
                 resolution=7,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.attn_ratio = attn_ratio
        self.key_dim = key_dim
        self.nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        h = self.dh + self.nh_kd * 2
        self.qkv = nn.Linear(embed_dims, h)
        self.proj = nn.Linear(self.dh, embed_dims)

        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N, N))

    @torch.no_grad()
    def train(self, mode=True):
        """change the mode of model."""
        super().train(mode)
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def forward(self, x):
        """forward function.

        Args:
            x (tensor): input features with shape of (B, N, C)
        """
        B, N, _ = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        q, k, v = qkv.split([self.key_dim, self.key_dim, self.d], dim=-1)

        attn = ((q @ k.transpose(-2, -1)) * self.scale +
                (self.attention_biases[:, self.attention_bias_idxs]
                 if self.training else self.ab))
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
        x = self.proj(x)
        return x
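

# Illustrative usage (a minimal shape-check sketch, assuming the defaults):
# with ``resolution=7`` the module expects N = 7 * 7 = 49 tokens, and the
# learnable bias table holds one entry per distinct (|dx|, |dy|) offset,
# i.e. 49 entries per head. With num_heads=8, key_dim=32 and attn_ratio=4,
# the qkv projection outputs 8 * (2 * 32 + 128) = 1536 channels.
#
# >>> attn = AttentionWithBias(embed_dims=448)
# >>> attn(torch.rand(2, 49, 448)).shape
# torch.Size([2, 49, 448])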


class Flat(nn.Module):
    """Flatten the input from (B, C, H, W) to (B, H*W, C)."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor):
        x = x.flatten(2).transpose(1, 2)
        return x


class LinearMlp(BaseModule):
    """An MLP block implemented with linear layers.

    The shape of the input and output tensors is (B, N, C).

    Args:
        in_features (int): Dimension of input features.
        hidden_features (int, optional): Dimension of hidden features.
            Defaults to ``in_features``.
        out_features (int, optional): Dimension of output features.
            Defaults to ``in_features``.
        act_cfg (dict): The config dict for the activation between the two
            linear layers. Defaults to ``dict(type='GELU')``.
        drop (float): Dropout rate. Defaults to 0.0.
        init_cfg (:obj:`mmcv.ConfigDict`, optional): The config for
            initialization. Defaults to None.
    """

    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = build_activation_layer(act_cfg)
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): input tensor with shape (B, N, C).

        Returns:
            torch.Tensor: output tensor with shape (B, N, C).
        """
        x = self.drop1(self.act(self.fc1(x)))
        x = self.drop2(self.fc2(x))
        return x


class ConvMlp(BaseModule):
    """An MLP block implemented with 1x1 convolutions.

    The shape of the input and output tensors is (B, C, H, W).

    Args:
        in_features (int): Dimension of input features.
        hidden_features (int, optional): Dimension of hidden features.
            Defaults to ``in_features``.
        out_features (int, optional): Dimension of output features.
            Defaults to ``in_features``.
        norm_cfg (dict): Config dict for the normalization layer applied
            after each convolution. Defaults to ``dict(type='BN')``.
        act_cfg (dict): The config dict for the activation between the
            pointwise convolutions. Defaults to ``dict(type='GELU')``.
        drop (float): Dropout rate. Defaults to 0.0.
        init_cfg (:obj:`mmcv.ConfigDict`, optional): The config for
            initialization. Defaults to None.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = build_activation_layer(act_cfg)
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1]
        self.norm2 = build_norm_layer(norm_cfg, out_features)[1]

        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): input tensor with shape (B, C, H, W).

        Returns:
            torch.Tensor: output tensor with shape (B, C, H, W).
        """

        x = self.act(self.norm1(self.fc1(x)))
        x = self.drop(x)
        x = self.norm2(self.fc2(x))
        x = self.drop(x)
        return x
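

# Illustrative comparison of the two MLP variants above: ``LinearMlp`` keeps
# the (B, N, C) token layout and applies no normalization, while ``ConvMlp``
# works on (B, C, H, W) feature maps and follows each 1x1 convolution with
# the configured norm layer (BatchNorm by default).
#
# >>> LinearMlp(in_features=448, hidden_features=1792)(
# ...     torch.rand(2, 49, 448)).shape
# torch.Size([2, 49, 448])
# >>> ConvMlp(in_features=48, hidden_features=192)(
# ...     torch.rand(2, 48, 56, 56)).shape
# torch.Size([2, 48, 56, 56])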


class Meta3D(BaseModule):
    """Meta Former block using 3 dimensions inputs, ``torch.Tensor`` with shape
    (B, N, C)."""

    def __init__(self,
                 dim,
                 mlp_ratio=4.,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 drop_path=0.,
                 use_layer_scale=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.norm1 = build_norm_layer(norm_cfg, dim)[1]
        self.token_mixer = AttentionWithBias(dim)
        self.norm2 = build_norm_layer(norm_cfg, dim)[1]
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = LinearMlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_cfg=act_cfg,
            drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        if use_layer_scale:
            self.ls1 = LayerScale(dim)
            self.ls2 = LayerScale(dim)
        else:
            self.ls1, self.ls2 = nn.Identity(), nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.ls1(self.token_mixer(self.norm1(x))))
        x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
        return x


class Meta4D(BaseModule):
    """Meta Former block using 4 dimensions inputs, ``torch.Tensor`` with shape
    (B, C, H, W)."""

    def __init__(self,
                 dim,
                 pool_size=3,
                 mlp_ratio=4.,
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 drop_path=0.,
                 use_layer_scale=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.token_mixer = Pooling(pool_size=pool_size)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ConvMlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_cfg=act_cfg,
            drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        if use_layer_scale:
            self.ls1 = LayerScale(dim, data_format='channels_first')
            self.ls2 = LayerScale(dim, data_format='channels_first')
        else:
            self.ls1, self.ls2 = nn.Identity(), nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.ls1(self.token_mixer(x)))
        x = x + self.drop_path(self.ls2(self.mlp(x)))
        return x
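

# Illustrative shape check: both MetaFormer blocks are residual and
# shape-preserving. ``Meta3D`` mixes tokens with biased self-attention on
# (B, N, C) input (N must match the attention resolution, 49 by default),
# while ``Meta4D`` mixes them with average pooling on (B, C, H, W) input.
#
# >>> Meta3D(dim=448)(torch.rand(2, 49, 448)).shape
# torch.Size([2, 49, 448])
# >>> Meta4D(dim=48)(torch.rand(2, 48, 56, 56)).shape
# torch.Size([2, 48, 56, 56])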


def basic_blocks(in_channels,
                 out_channels,
                 index,
                 layers,
                 pool_size=3,
                 mlp_ratio=4.,
                 act_cfg=dict(type='GELU'),
                 drop_rate=.0,
                 drop_path_rate=0.,
                 use_layer_scale=True,
                 vit_num=1,
                 has_downsamper=False):
    """generate EfficientFormer blocks for a stage."""
    blocks = []
    if has_downsamper:
        blocks.append(
            ConvModule(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=True,
                norm_cfg=dict(type='BN'),
                act_cfg=None))
    if index == 3 and vit_num == layers[index]:
        blocks.append(Flat())
    for block_idx in range(layers[index]):
        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (
            sum(layers) - 1)
        if index == 3 and layers[index] - block_idx <= vit_num:
            blocks.append(
                Meta3D(
                    out_channels,
                    mlp_ratio=mlp_ratio,
                    act_cfg=act_cfg,
                    drop=drop_rate,
                    drop_path=block_dpr,
                    use_layer_scale=use_layer_scale,
                ))
        else:
            blocks.append(
                Meta4D(
                    out_channels,
                    pool_size=pool_size,
                    act_cfg=act_cfg,
                    drop=drop_rate,
                    drop_path=block_dpr,
                    use_layer_scale=use_layer_scale))
            if index == 3 and layers[index] - block_idx - 1 == vit_num:
                blocks.append(Flat())
    blocks = nn.Sequential(*blocks)
    return blocks
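

# Illustrative numbers for the stochastic depth schedule above: ``block_dpr``
# grows linearly with the global block index across all stages. For the 'l1'
# layout layers=[3, 2, 6, 4] with drop_path_rate=0.1, the blocks of the last
# stage (index 3) receive the following rates:
#
# >>> layers, dpr = [3, 2, 6, 4], 0.1
# >>> [round(dpr * (i + sum(layers[:3])) / (sum(layers) - 1), 3)
# ...  for i in range(layers[3])]
# [0.079, 0.086, 0.093, 0.1]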


@BACKBONES.register_module()
class EfficientFormer(BaseBackbone):
    """EfficientFormer.

    A PyTorch implementation of EfficientFormer introduced by:
    `EfficientFormer: Vision Transformers at MobileNet Speed
    <https://arxiv.org/abs/2206.01191>`_

    Modified from the `official repo
    <https://github.com/snap-research/EfficientFormer>`_.

    Args:
        arch (str | dict): The model's architecture. If string, it should be
            one of the architectures in ``EfficientFormer.arch_settings``.
            If dict, it should include the following 4 keys:

            - layers (list[int]): Number of blocks at each stage.
            - embed_dims (list[int]): The number of channels at each stage.
            - downsamples (list[bool]): Has downsample or not in the four
              stages.
            - vit_num (int): The number of vit blocks in the last stage.

            Defaults to 'l1'.
        in_channels (int): The num of input channels. Defaults to 3.
        pool_size (int): The pooling size of ``Meta4D`` blocks. Defaults to 3.
        mlp_ratios (int): The expansion ratio of the hidden dimension in the
            MLPs of the MetaFormer blocks. Defaults to 4.
        reshape_last_feat (bool): Whether to reshape the feature map from
            (B, N, C) to (B, C, H, W) in the last stage, when the ``vit_num``
            in ``arch`` is not 0. Defaults to False. Usually set to True in
            downstream tasks.
        out_indices (Sequence[int]): Output from which stages.
            Defaults to -1.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        drop_rate (float): Dropout rate. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        use_layer_scale (bool): Whether to use layer scale in the MetaFormer
            blocks. Defaults to True.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.

    Example:
        >>> from mmcls.models import EfficientFormer
        >>> import torch
        >>> inputs = torch.rand((1, 3, 224, 224))
        >>> # build EfficientFormer backbone for classification task
        >>> model = EfficientFormer(arch="l1")
        >>> model.eval()
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 448, 49)
        >>> # build EfficientFormer backbone for downstream task
        >>> model = EfficientFormer(
        ...     arch="l3",
        ...     out_indices=(0, 1, 2, 3),
        ...     reshape_last_feat=True)
        >>> model.eval()
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 64, 56, 56)
        (1, 128, 28, 28)
        (1, 320, 14, 14)
        (1, 512, 7, 7)
    """  # noqa: E501

    # --layers: [x,x,x,x], numbers of layers for the four stages
    # --embed_dims: [x,x,x,x], embedding dims for the four stages
    # --downsamples: [x,x,x,x], has downsample or not in the four stages
    # --vit_num: (int), the number of vit blocks in the last stage
    arch_settings = {
        'l1': {
            'layers': [3, 2, 6, 4],
            'embed_dims': [48, 96, 224, 448],
            'downsamples': [False, True, True, True],
            'vit_num': 1,
        },
        'l3': {
            'layers': [4, 4, 12, 6],
            'embed_dims': [64, 128, 320, 512],
            'downsamples': [False, True, True, True],
            'vit_num': 4,
        },
        'l7': {
            'layers': [6, 6, 18, 8],
            'embed_dims': [96, 192, 384, 768],
            'downsamples': [False, True, True, True],
            'vit_num': 8,
        },
    }

    def __init__(self,
                 arch='l1',
                 in_channels=3,
                 pool_size=3,
                 mlp_ratios=4,
                 reshape_last_feat=False,
                 out_indices=-1,
                 frozen_stages=-1,
                 act_cfg=dict(type='GELU'),
                 drop_rate=0.,
                 drop_path_rate=0.,
                 use_layer_scale=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.num_extra_tokens = 0  # no cls_token, no dist_token

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            default_keys = set(self.arch_settings['l1'].keys())
            assert set(arch.keys()) == default_keys, \
                f'The arch dict must have {default_keys}, ' \
                f'but got {list(arch.keys())}.'

        self.layers = arch['layers']
        self.embed_dims = arch['embed_dims']
        self.downsamples = arch['downsamples']
        assert isinstance(self.layers, list) and isinstance(
            self.embed_dims, list) and isinstance(self.downsamples, list)
        assert len(self.layers) == len(self.embed_dims) == len(
            self.downsamples)

        self.vit_num = arch['vit_num']
        self.reshape_last_feat = reshape_last_feat

        assert self.vit_num >= 0, "'vit_num' must be an integer " \
            'greater than or equal to 0.'
        assert self.vit_num <= self.layers[-1], (
            "'vit_num' must not be larger than the number of layers "
            'in the last stage.')

        self._make_stem(in_channels, self.embed_dims[0])

        # set the main blocks in the network
        network = []
        for i in range(len(self.layers)):
            if i != 0:
                in_channels = self.embed_dims[i - 1]
            else:
                in_channels = self.embed_dims[i]
            out_channels = self.embed_dims[i]
            stage = basic_blocks(
                in_channels,
                out_channels,
                i,
                self.layers,
                pool_size=pool_size,
                mlp_ratio=mlp_ratios,
                act_cfg=act_cfg,
                drop_rate=drop_rate,
                drop_path_rate=drop_path_rate,
                vit_num=self.vit_num,
                use_layer_scale=use_layer_scale,
                has_downsamper=self.downsamples[i])
            network.append(stage)

        self.network = ModuleList(network)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        # copy to a list so that negative indices can be replaced in place
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = 4 + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'

        self.out_indices = out_indices
        for i_layer in self.out_indices:
            if not self.reshape_last_feat and \
                    i_layer == 3 and self.vit_num > 0:
                layer = build_norm_layer(
                    dict(type='LN'), self.embed_dims[i_layer])[1]
            else:
                # use GN with 1 group as channel-first LN2D
                layer = build_norm_layer(
                    dict(type='GN', num_groups=1),
                    self.embed_dims[i_layer])[1]

            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self.frozen_stages = frozen_stages
        self._freeze_stages()

    def _make_stem(self, in_channels: int, stem_channels: int):
        """Make the 2-ConvBNReLU stem layer."""
        self.patch_embed = Sequential(
            ConvModule(
                in_channels,
                stem_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=True,
                conv_cfg=None,
                norm_cfg=dict(type='BN'),
                inplace=True),
            ConvModule(
                stem_channels // 2,
                stem_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=True,
                conv_cfg=None,
                norm_cfg=dict(type='BN'),
                inplace=True))

    def forward_tokens(self, x):
        outs = []
        for idx, block in enumerate(self.network):
            if idx == len(self.network) - 1:
                N, _, H, W = x.shape
                if self.downsamples[idx]:
                    H, W = H // 2, W // 2
            x = block(x)
            if idx in self.out_indices:
                norm_layer = getattr(self, f'norm{idx}')

                if idx == len(self.network) - 1 and x.dim() == 3:
                    # When ``vit_num`` > 0 and in the last stage:
                    # if ``self.reshape_last_feat`` is True, reshape the
                    # features to BCHW format before the final normalization;
                    # if ``self.reshape_last_feat`` is False, do the
                    # normalization directly and permute the features to BCN.
                    if self.reshape_last_feat:
                        x = x.permute((0, 2, 1)).reshape(N, -1, H, W)
                        x_out = norm_layer(x)
                    else:
                        x_out = norm_layer(x).permute((0, 2, 1))
                else:
                    x_out = norm_layer(x)

                outs.append(x_out.contiguous())
        return tuple(outs)

    def forward(self, x):
        # input embedding
        x = self.patch_embed(x)
        # through stages
        x = self.forward_tokens(x)
        return x

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

            for i in range(self.frozen_stages):
                # Include both the blocks and the downsample layer.
                module = self.network[i]
                module.eval()
                for param in module.parameters():
                    param.requires_grad = False
                if i in self.out_indices:
                    norm_layer = getattr(self, f'norm{i}')
                    norm_layer.eval()
                    for param in norm_layer.parameters():
                        param.requires_grad = False

    def train(self, mode=True):
        super(EfficientFormer, self).train(mode)
        self._freeze_stages()
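

# Illustrative usage with a custom architecture: besides the preset
# 'l1'/'l3'/'l7' strings, ``arch`` accepts a dict with exactly the four keys
# of ``arch_settings``. The hypothetical narrower variant below is only a
# sketch; at a 224x224 input the last stage runs on 7x7 = 49 tokens, and
# with the default ``out_indices=-1`` the output is a single (B, C, N)
# tensor.
#
# >>> model = EfficientFormer(
# ...     arch=dict(
# ...         layers=[2, 2, 4, 2],
# ...         embed_dims=[32, 64, 128, 256],
# ...         downsamples=[False, True, True, True],
# ...         vit_num=1))
# >>> model.eval()
# >>> model(torch.rand(1, 3, 224, 224))[0].shape
# torch.Size([1, 256, 49])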