Shortcuts

Note

You are reading the documentation for MMClassification 0.x, which will be deprecated at the end of 2022. We recommend upgrading to MMClassification 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. See the installation tutorial, migration tutorial and changelog for more details.

Source code for mmcls.models.backbones.t2t_vit

# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList

from ..builder import BACKBONES
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
from .base_backbone import BaseBackbone


class T2TTransformerLayer(BaseModule):
    """Transformer Layer for T2T_ViT.

    Comparing with :obj:`TransformerEncoderLayer` in ViT, it supports
    different ``input_dims`` and ``embed_dims``.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        input_dims (int, optional): The input token dimension. If None,
            ``embed_dims`` is used and the layer applies an ordinary
            residual connection around the attention. Defaults to None.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        qk_scale (float, optional): Override default qk scale of
            ``(input_dims // num_heads) ** -0.5`` if set. Defaults to None.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.

    Notes:
        In general, ``qk_scale`` should be ``head_dims ** -0.5``, i.e.
        ``(embed_dims // num_heads) ** -0.5``. However, in the official
        code, it uses ``(input_dims // num_heads) ** -0.5``, so here we
        keep the same with the official implementation.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 input_dims=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=False,
                 qk_scale=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(T2TTransformerLayer, self).__init__(init_cfg=init_cfg)

        # An explicit ``input_dims`` switches the layer to "v shortcut" mode:
        # the flag is handed to the attention module and ``forward`` then
        # skips its own residual add around the attention.
        self.v_shortcut = input_dims is not None
        input_dims = input_dims or embed_dims

        # The first norm acts on the raw input tokens, hence ``input_dims``.
        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, input_dims, postfix=1)
        self.add_module(self.norm1_name, norm1)

        self.attn = MultiheadAttention(
            input_dims=input_dims,
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            # Scale derived from ``input_dims`` (see Notes in the docstring).
            qk_scale=qk_scale or (input_dims // num_heads)**-0.5,
            v_shortcut=self.v_shortcut)

        # The second norm acts on the attention output, hence ``embed_dims``.
        self.norm2_name, norm2 = build_norm_layer(
            norm_cfg, embed_dims, postfix=2)
        self.add_module(self.norm2_name, norm2)

        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

    @property
    def norm1(self):
        """The normalization layer applied before the attention."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """The normalization layer applied before the FFN."""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Apply pre-norm attention followed by the pre-norm FFN.

        Args:
            x (torch.Tensor): Input tokens.

        Returns:
            torch.Tensor: The output of the FFN, called with the attention
            result as ``identity``.
        """
        if self.v_shortcut:
            # No residual here: the attention was built with
            # ``v_shortcut=True``.
            # NOTE(review): presumably the attention module adds its own
            # shortcut in this mode - confirm against MultiheadAttention.
            x = self.attn(self.norm1(x))
        else:
            x = x + self.attn(self.norm1(x))
        x = self.ffn(self.norm2(x), identity=x)
        return x


class T2TModule(BaseModule):
    """Tokens-to-Token module.

    "Tokens-to-Token module" (T2T Module) can model the local structure
    information of images and reduce the length of tokens progressively.

    Args:
        img_size (int | tuple): Input image size. An int is treated as a
            square ``(img_size, img_size)``; a 2-element sequence is read
            as ``(height, width)``. Defaults to 224.
        in_channels (int): Number of input channels. Defaults to 3.
        embed_dims (int): Embedding dimension. Defaults to 384.
        token_dims (int): Tokens dimension in T2TModuleAttention.
            Defaults to 64.
        use_performer (bool): If True, use the Performer version of
            self-attention instead of regular self-attention. Not
            implemented yet: passing True raises ``NotImplementedError``.
            Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.

    Notes:
        Usually, ``token_dims`` is set as a small value (32 or 64) to reduce
        MACs
    """

    def __init__(
        self,
        img_size=224,
        in_channels=3,
        embed_dims=384,
        token_dims=64,
        use_performer=False,
        init_cfg=None,
    ):
        super(T2TModule, self).__init__(init_cfg)

        self.embed_dims = embed_dims

        self.soft_split0 = nn.Unfold(
            kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
        self.soft_split1 = nn.Unfold(
            kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.soft_split2 = nn.Unfold(
            kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

        if not use_performer:
            # Each unfolded token concatenates a k x k patch, so the input
            # width of each attention layer is channels * k * k.
            self.attention1 = T2TTransformerLayer(
                input_dims=in_channels * 7 * 7,
                embed_dims=token_dims,
                num_heads=1,
                feedforward_channels=token_dims)

            self.attention2 = T2TTransformerLayer(
                input_dims=token_dims * 3 * 3,
                embed_dims=token_dims,
                num_heads=1,
                feedforward_channels=token_dims)

            self.project = nn.Linear(token_dims * 3 * 3, embed_dims)
        else:
            raise NotImplementedError("Performer hasn't been implemented.")

        # there are 3 soft split, stride are 4,2,2 separately
        # Normalise to (height, width) so that both an int and a tuple work;
        # the previous ``img_size // 16`` raised TypeError on a tuple.
        # NOTE: for sides that are not multiples of 16 this floor division
        # can differ from the size ``forward`` actually returns.
        img_h, img_w = to_2tuple(img_size)
        reduction = 4 * 2 * 2
        self.init_out_size = [img_h // reduction, img_w // reduction]
        self.num_patches = self.init_out_size[0] * self.init_out_size[1]

    @staticmethod
    def _get_unfold_size(unfold: nn.Unfold, input_size):
        """Compute the spatial grid of sliding blocks produced by ``unfold``.

        Args:
            unfold (nn.Unfold): The unfold layer whose kernel size, stride,
                padding and dilation are used.
            input_size (Sequence[int]): Input ``(height, width)``.

        Returns:
            tuple[int, int]: Number of blocks along height and width.
        """
        h, w = input_size
        kernel_size = to_2tuple(unfold.kernel_size)
        stride = to_2tuple(unfold.stride)
        padding = to_2tuple(unfold.padding)
        dilation = to_2tuple(unfold.dilation)

        h_out = (h + 2 * padding[0] - dilation[0] *
                 (kernel_size[0] - 1) - 1) // stride[0] + 1
        w_out = (w + 2 * padding[1] - dilation[1] *
                 (kernel_size[1] - 1) - 1) // stride[1] + 1
        return (h_out, w_out)

    def forward(self, x):
        """Turn an image batch into projected tokens.

        Args:
            x (torch.Tensor): Input whose last two dimensions are the
                spatial size (``x.shape[2:]`` is used as ``(H, W)``).

        Returns:
            tuple: ``(tokens, hw_shape)`` where ``tokens`` is the output of
            the final linear projection and ``hw_shape`` is the token grid
            ``(h, w)`` after the last soft split.
        """
        # step0: soft split
        hw_shape = self._get_unfold_size(self.soft_split0, x.shape[2:])
        # Unfold yields (B, C*k*k, L); move the token axis to dim 1.
        x = self.soft_split0(x).transpose(1, 2)

        for step in [1, 2]:
            # re-structurization/reconstruction
            attn = getattr(self, f'attention{step}')
            x = attn(x).transpose(1, 2)
            B, C, _ = x.shape
            # Fold the token sequence back into a 2-D map for the next split.
            x = x.reshape(B, C, hw_shape[0], hw_shape[1])

            # soft split
            soft_split = getattr(self, f'soft_split{step}')
            hw_shape = self._get_unfold_size(soft_split, hw_shape)
            x = soft_split(x).transpose(1, 2)

        # final tokens
        x = self.project(x)
        return x, hw_shape


def get_sinusoid_encoding(n_position, embed_dims):
    """Build a fixed sinusoidal position encoding table.

    The scheme follows `Attention Is All You Need
    <https://arxiv.org/abs/1706.03762>`_: channel ``c`` of position ``p``
    holds ``p * 10000 ** (-(c - c % 2) / embed_dims)``, passed through
    ``sin`` for even channels and ``cos`` for odd channels.

    Args:
        n_position (int): The length of the input token.
        embed_dims (int): The position embedding dimension.

    Returns:
        :obj:`torch.FloatTensor`: The sinusoid encoding table with shape
        ``(1, n_position, embed_dims)``.
    """
    # Frequencies are computed in float64 and only cast down at the end.
    channel = torch.arange(embed_dims, dtype=torch.float64)
    # Channels 2i and 2i+1 share the same exponent 2i / embed_dims.
    exponent = (channel - channel % 2) / embed_dims
    inv_freq = torch.pow(10000, -exponent)

    positions = torch.arange(n_position)
    # Outer product: (n_position, 1) * (1, embed_dims).
    angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)

    angles[:, 0::2] = torch.sin(angles[:, 0::2])  # dim 2i
    angles[:, 1::2] = torch.cos(angles[:, 1::2])  # dim 2i+1

    return angles.float().unsqueeze(0)


@BACKBONES.register_module()
class T2T_ViT(BaseBackbone):
    """Tokens-to-Token Vision Transformer (T2T-ViT)

    A PyTorch implementation of `Tokens-to-Token ViT: Training Vision
    Transformers from Scratch on ImageNet <https://arxiv.org/abs/2101.11986>`_

    Args:
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. The value is forwarded unchanged to
            :class:`T2TModule`. Defaults to 224.
        in_channels (int): Number of input channels. Defaults to 3.
        embed_dims (int): Embedding dimension. Defaults to 384.
        num_layers (int): Num of transformer layers in encoder.
            Defaults to 14.
        out_indices (Sequence | int): Output from which stages. Negative
            values count from the end. The argument is copied, so the
            caller's sequence is never modified. Defaults to -1, means the
            last stage.
        drop_rate (float): Dropout rate after position embedding.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
            final feature map. Defaults to True.
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Defaults to True.
        output_cls_token (bool): Whether output the cls_token. If set True,
            ``with_cls_token`` must be True. Defaults to True.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        t2t_cfg (dict): Extra config of Tokens-to-Token module.
            Defaults to an empty dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """
    num_extra_tokens = 1  # cls_token

    def __init__(self,
                 img_size=224,
                 in_channels=3,
                 embed_dims=384,
                 num_layers=14,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 final_norm=True,
                 with_cls_token=True,
                 output_cls_token=True,
                 interpolate_mode='bicubic',
                 t2t_cfg=dict(),
                 layer_cfgs=dict(),
                 init_cfg=None):
        super(T2T_ViT, self).__init__(init_cfg)

        # Token-to-Token Module
        self.tokens_to_token = T2TModule(
            img_size=img_size,
            in_channels=in_channels,
            embed_dims=embed_dims,
            **t2t_cfg)
        self.patch_resolution = self.tokens_to_token.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set cls token
        if output_cls_token:
            assert with_cls_token is True, \
                'with_cls_token must be True if ' \
                f'set output_cls_token to True, but got {with_cls_token}'
        self.with_cls_token = with_cls_token
        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))

        # Set position embedding
        self.interpolate_mode = interpolate_mode
        sinusoid_table = get_sinusoid_encoding(
            num_patches + self.num_extra_tokens, embed_dims)
        # Registered as a buffer rather than a parameter.
        self.register_buffer('pos_embed', sinusoid_table)
        # Resize a checkpoint's pos_embed before it is loaded if its shape
        # differs from this model's.
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Normalise on a private copy: writing into the argument would fail
        # for tuples and would silently alter a list owned by the caller.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = num_layers + index
            # NOTE(review): the upper bound is inclusive, so ``num_layers``
            # passes although no encoder layer has that index - kept as is
            # for backward compatibility.
            assert 0 <= out_indices[i] <= num_layers, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule
        dpr = list(np.linspace(0, drop_path_rate, num_layers))

        self.encoder = ModuleList()
        for i in range(num_layers):
            if isinstance(layer_cfgs, Sequence):
                layer_cfg = layer_cfgs[i]
            else:
                layer_cfg = deepcopy(layer_cfgs)
            # User-supplied keys are unpacked last, so they override the
            # defaults listed here.
            layer_cfg = {
                'embed_dims': embed_dims,
                'num_heads': 6,
                'feedforward_channels': 3 * embed_dims,
                'drop_path_rate': dpr[i],
                'qkv_bias': False,
                'norm_cfg': norm_cfg,
                **layer_cfg
            }

            layer = T2TTransformerLayer(**layer_cfg)
            self.encoder.append(layer)

        self.final_norm = final_norm
        if final_norm:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = nn.Identity()

    def init_weights(self):
        """Initialize weights; skip the custom init for pretrained models."""
        super().init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress custom init if use pretrained model.
            return

        trunc_normal_(self.cls_token, std=.02)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        """Resize a checkpoint's ``pos_embed`` to this model's shape.

        Registered as a load-state-dict pre-hook. It does nothing when the
        checkpoint has no ``pos_embed`` entry or when the shapes already
        match; otherwise the entry in ``state_dict`` is replaced in place.
        """
        name = prefix + 'pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmcls.utils import get_root_logger
            logger = get_root_logger()
            logger.info(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.')

            # The checkpoint grid is recovered as a square from its token
            # count, after removing the extra (cls) tokens.
            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.tokens_to_token.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    def forward(self, x):
        """Run the backbone.

        Args:
            x (torch.Tensor): Input batch; ``x.shape[0]`` is the batch size.

        Returns:
            tuple: One entry per index in ``out_indices``. Each entry is
            ``[patch_token, cls_token]`` when ``output_cls_token`` is True,
            otherwise just ``patch_token``. ``patch_token`` is reshaped to
            ``(B, C, h, w)`` with ``(h, w)`` the token grid returned by the
            Tokens-to-Token module.
        """
        B = x.shape[0]
        x, patch_resolution = self.tokens_to_token(x)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Resize the stored embedding from the configured grid to the grid
        # of the current input, which supports dynamic input sizes.
        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        if not self.with_cls_token:
            # Remove class token for transformer encoder input
            x = x[:, 1:]

        outs = []
        for i, layer in enumerate(self.encoder):
            x = layer(x)

            if i == len(self.encoder) - 1 and self.final_norm:
                x = self.norm(x)

            if i in self.out_indices:
                B, _, C = x.shape
                if self.with_cls_token:
                    patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = x[:, 0]
                else:
                    patch_token = x.reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = None
                if self.output_cls_token:
                    out = [patch_token, cls_token]
                else:
                    out = patch_token
                outs.append(out)

        return tuple(outs)
Read the Docs v: master
Versions
master
latest
1.x
dev-1.x
Downloads
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.