
Source code for mmpretrain.models.necks.itpn_neck

# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.models.backbones.hivit import BlockWithRPE
from mmpretrain.registry import MODELS
from ..backbones.vision_transformer import TransformerEncoderLayer
from ..utils import build_2d_sincos_position_embedding


class PatchSplit(nn.Module):
    """The up-sample module used in neck (transformer pyramid network)

    Args:
        dim (int): the input dimension (channel number).
        fpn_dim (int): the fpn dimension (channel number).
        norm_cfg (dict): Config dict for normalization layer.
                Defaults to ``dict(type='LN')``.
    """

    def __init__(self, dim, fpn_dim, norm_cfg):
        super().__init__()
        _, self.norm = build_norm_layer(norm_cfg, dim)
        self.reduction = nn.Linear(dim, fpn_dim * 4, bias=False)
        self.fpn_dim = fpn_dim

    def forward(self, x):
        B, N, H, W, C = x.shape
        x = self.norm(x)
        x = self.reduction(x)
        x = x.reshape(B, N, H, W, 2, 2,
                      self.fpn_dim).permute(0, 1, 2, 4, 3, 5,
                                            6).reshape(B, N, 2 * H, 2 * W,
                                                       self.fpn_dim)
        return x

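
# A minimal shape-check sketch for ``PatchSplit`` (illustrative only, not part
# of the original module; it assumes mmcv is available so ``dict(type='LN')``
# can be resolved by ``build_norm_layer``). Each H x W token grid is split
# into a 2H x 2W grid with ``fpn_dim`` channels:
#
#     split = PatchSplit(dim=512, fpn_dim=256, norm_cfg=dict(type='LN'))
#     tokens = torch.randn(2, 49, 2, 2, 512)   # (B, N, H, W, C)
#     out = split(tokens)
#     assert out.shape == (2, 49, 4, 4, 256)   # (B, N, 2H, 2W, fpn_dim)
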

@MODELS.register_module()
class iTPNPretrainDecoder(BaseModule):
    """The neck module of iTPN (transformer pyramid network).

    Args:
        num_patches (int): The number of total patches. Defaults to 196.
        patch_size (int): Image patch size. Defaults to 16.
        in_chans (int): The channel of input image. Defaults to 3.
        embed_dim (int): Encoder's embedding dimension. Defaults to 512.
        fpn_dim (int): The fpn dimension (channel number). Defaults to 256.
        fpn_depth (int): The layer number of the feature pyramid.
            Defaults to 2.
        decoder_embed_dim (int): Decoder's embedding dimension.
            Defaults to 512.
        decoder_depth (int): The depth of decoder. Defaults to 6.
        decoder_num_heads (int): Number of attention heads of decoder.
            Defaults to 16.
        mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
            Defaults to 4.
        norm_cfg (dict): Config dict for the normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        reconstruction_type (str): The type of supervision. iTPN supports
            ``'pixel'`` and ``'clip'``. Defaults to 'pixel'.
        num_outs (int): The output number of the neck (transformer pyramid
            network). Defaults to 3.
        qkv_bias (bool): Whether to add a learnable bias to query, key and
            value in the pyramid blocks. Defaults to True.
        qk_scale (float, optional): Override the default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        drop_rate (float): The dropout rate of the pyramid blocks.
            Defaults to 0.
        attn_drop_rate (float): The dropout rate of the attention weights
            in the pyramid blocks. Defaults to 0.
        predict_feature_dim (int, optional): The output dimension for
            supervision. Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Initialization config
            dict. Defaults to None.
    """

    def __init__(self,
                 num_patches: int = 196,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim: int = 512,
                 fpn_dim: int = 256,
                 fpn_depth: int = 2,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 6,
                 decoder_num_heads: int = 16,
                 mlp_ratio: int = 4,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 reconstruction_type: str = 'pixel',
                 num_outs: int = 3,
                 qkv_bias: bool = True,
                 qk_scale: Optional[bool] = None,
                 drop_rate: float = 0.0,
                 attn_drop_rate: float = 0.0,
                 predict_feature_dim: Optional[float] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_patches = num_patches
        assert reconstruction_type in ['pixel', 'clip'], \
            'iTPN method only support `pixel` and `clip`, ' \
            f'but got `{reconstruction_type}`.'

        self.reconstruction_type = reconstruction_type
        self.num_outs = num_outs

        self.build_transformer_pyramid(
            num_outs=num_outs,
            embed_dim=embed_dim,
            fpn_dim=fpn_dim,
            fpn_depth=fpn_depth,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            rpe=False,
            norm_cfg=norm_cfg,
        )

        # merge the output
        self.decoder_embed = nn.ModuleList()
        self.decoder_embed.append(
            nn.Sequential(
                nn.LayerNorm(fpn_dim),
                nn.Linear(fpn_dim, decoder_embed_dim, bias=True),
            ))
        if self.num_outs >= 2:
            self.decoder_embed.append(
                nn.Sequential(
                    nn.LayerNorm(fpn_dim),
                    nn.Linear(fpn_dim, decoder_embed_dim // 4, bias=True),
                ))
        if self.num_outs >= 3:
            self.decoder_embed.append(
                nn.Sequential(
                    nn.LayerNorm(fpn_dim),
                    nn.Linear(fpn_dim, decoder_embed_dim // 16, bias=True),
                ))

        if reconstruction_type == 'pixel':
            self.mask_token = nn.Parameter(
                torch.zeros(1, 1, decoder_embed_dim))

            # create new position embedding, different from that in encoder
            # and is not learnable
            self.decoder_pos_embed = nn.Parameter(
                torch.zeros(1, self.num_patches, decoder_embed_dim),
                requires_grad=False)

            self.decoder_blocks = nn.ModuleList([
                TransformerEncoderLayer(
                    decoder_embed_dim,
                    decoder_num_heads,
                    int(mlp_ratio * decoder_embed_dim),
                    qkv_bias=True,
                    norm_cfg=norm_cfg) for _ in range(decoder_depth)
            ])
            self.decoder_norm_name, decoder_norm = build_norm_layer(
                norm_cfg, decoder_embed_dim, postfix=1)
            self.add_module(self.decoder_norm_name, decoder_norm)

            # Used to map features to pixels
            if predict_feature_dim is None:
                predict_feature_dim = patch_size**2 * in_chans
            self.decoder_pred = nn.Linear(
                decoder_embed_dim, predict_feature_dim, bias=True)
        else:
            _, norm = build_norm_layer(norm_cfg, embed_dim)
            self.add_module('norm', norm)

    def build_transformer_pyramid(self,
                                  num_outs=3,
                                  embed_dim=512,
                                  fpn_dim=256,
                                  fpn_depth=2,
                                  mlp_ratio=4.0,
                                  qkv_bias=True,
                                  qk_scale=None,
                                  drop_rate=0.0,
                                  attn_drop_rate=0.0,
                                  rpe=False,
                                  norm_cfg=None):
        Hp = None
        mlvl_dims = {'4': embed_dim // 4, '8': embed_dim // 2, '16': embed_dim}
        if num_outs > 1:
            if embed_dim != fpn_dim:
                self.align_dim_16tofpn = nn.Linear(embed_dim, fpn_dim)
            else:
                self.align_dim_16tofpn = None
            self.fpn_modules = nn.ModuleList()
            self.fpn_modules.append(
                BlockWithRPE(
                    Hp,
                    fpn_dim,
                    0,
                    mlp_ratio,
                    qkv_bias,
                    qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    rpe=rpe,
                    norm_cfg=norm_cfg))
            self.fpn_modules.append(
                BlockWithRPE(
                    Hp,
                    fpn_dim,
                    0,
                    mlp_ratio,
                    qkv_bias,
                    qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    rpe=False,
                    norm_cfg=norm_cfg,
                ))

            self.align_dim_16to8 = nn.Linear(
                mlvl_dims['8'], fpn_dim, bias=False)
            self.split_16to8 = PatchSplit(mlvl_dims['16'], fpn_dim, norm_cfg)
            self.block_16to8 = nn.Sequential(*[
                BlockWithRPE(
                    Hp,
                    fpn_dim,
                    0,
                    mlp_ratio,
                    qkv_bias,
                    qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    rpe=rpe,
                    norm_cfg=norm_cfg,
                ) for _ in range(fpn_depth)
            ])

        if num_outs > 2:
            self.align_dim_8to4 = nn.Linear(
                mlvl_dims['4'], fpn_dim, bias=False)
            self.split_8to4 = PatchSplit(fpn_dim, fpn_dim, norm_cfg)
            self.block_8to4 = nn.Sequential(*[
                BlockWithRPE(
                    Hp,
                    fpn_dim,
                    0,
                    mlp_ratio,
                    qkv_bias,
                    qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    rpe=rpe,
                    norm_cfg=norm_cfg,
                ) for _ in range(fpn_depth)
            ])
            self.fpn_modules.append(
                BlockWithRPE(
                    Hp,
                    fpn_dim,
                    0,
                    mlp_ratio,
                    qkv_bias,
                    qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    rpe=rpe,
                    norm_cfg=norm_cfg))
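
    # How the modules built above are wired together in ``forward``: the
    # stride-16 tokens from the encoder are optionally projected to
    # ``fpn_dim`` by ``align_dim_16tofpn``; ``split_16to8`` and ``split_8to4``
    # double the per-token grid (stride 16 -> 8 -> 4); each upsampled level is
    # summed with the aligned intermediate backbone feature
    # (``align_dim_16to8`` / ``align_dim_8to4``); and every level is refined
    # by the MLP-only ``BlockWithRPE`` stacks (built with ``num_heads=0``).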

    def init_weights(self) -> None:
        """Initialize position embedding and mask token of MAE decoder."""
        super().init_weights()

        if self.reconstruction_type == 'pixel':
            decoder_pos_embed = build_2d_sincos_position_embedding(
                int(self.num_patches**.5),
                self.decoder_pos_embed.shape[-1],
                cls_token=False)
            self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())

            torch.nn.init.normal_(self.mask_token, std=.02)
        else:
            self.rescale_init_weight()

    def rescale_init_weight(self) -> None:
        """Rescale the initialized weights."""

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.fpn_modules):
            if isinstance(layer, BlockWithRPE):
                if layer.attn is not None:
                    rescale(layer.attn.proj.weight.data, layer_id + 1)
                rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    @property
    def decoder_norm(self):
        """The normalization layer of decoder."""
        return getattr(self, self.decoder_norm_name)

    def forward(self,
                x: torch.Tensor,
                ids_restore: torch.Tensor = None) -> torch.Tensor:
        """The forward function.

        The process computes the visible patches' feature vectors and the
        mask tokens to output feature vectors, which will be used for
        reconstruction.

        Args:
            x (list[torch.Tensor]): Multi-level hidden features from the
                backbone. The last entry contains the visible patch tokens
                and is of shape B x (L * mask_ratio) x C.
            ids_restore (torch.Tensor): The ids to restore original image.

        Returns:
            torch.Tensor: The reconstructed feature vectors, which is of
            shape B x (num_patches) x C.
        """
        features = x[:2]
        x = x[-1]
        B, L, _ = x.shape
        x = x[..., None, None, :]
        Hp = Wp = int(math.sqrt(L))

        outs = [x] if self.align_dim_16tofpn is None else [
            self.align_dim_16tofpn(x)
        ]
        if self.num_outs >= 2:
            x = self.block_16to8(
                self.split_16to8(x) + self.align_dim_16to8(features[1]))
            outs.append(x)
        if self.num_outs >= 3:
            x = self.block_8to4(
                self.split_8to4(x) + self.align_dim_8to4(features[0]))
            outs.append(x)
        if self.num_outs > 3:
            outs = [
                out.reshape(B, Hp, Wp, *out.shape[-3:]).permute(
                    0, 5, 1, 3, 2, 4).reshape(B, -1, Hp * out.shape[-3],
                                              Wp * out.shape[-2]).contiguous()
                for out in outs
            ]
            if self.num_outs >= 4:
                outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2))
            if self.num_outs >= 5:
                outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2))

        for i, out in enumerate(outs):
            out = self.fpn_modules[i](out)
            outs[i] = out

        if self.reconstruction_type == 'pixel':
            feats = []
            for feat, layer in zip(outs, self.decoder_embed):
                x = layer(feat).reshape(B, L, -1)
                # append mask tokens to sequence
                mask_tokens = self.mask_token.repeat(
                    x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
                x = torch.cat([x, mask_tokens], dim=1)
                x = torch.gather(
                    x,
                    dim=1,
                    index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
                feats.append(x)
            x = feats.pop(0)
            # add pos embed
            x = x + self.decoder_pos_embed

            for i, feat in enumerate(feats):
                x = x + feats[i]
            # apply Transformer blocks
            for i, blk in enumerate(self.decoder_blocks):
                x = blk(x)
            x = self.decoder_norm(x)
            x = self.decoder_pred(x)

            return x
        else:
            feats = []
            for feat, layer in zip(outs, self.decoder_embed):
                x = layer(feat).reshape(B, L, -1)
                feats.append(x)
            x = feats.pop(0)
            for i, feat in enumerate(feats):
                x = x + feats[i]

            x = self.norm(x)

            return x
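

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only, not part of the original
# module). It assumes the default ``'pixel'`` reconstruction setting, a
# 224x224 input (14 x 14 = 196 patches of size 16) with a 75% mask ratio
# (49 visible tokens), and it fakes the three HiViT feature levels that the
# backbone would normally provide (strides 4, 8 and 16).
if __name__ == '__main__':
    B, L, num_patches, embed_dim = 2, 49, 196, 512

    neck = iTPNPretrainDecoder(
        num_patches=num_patches,
        embed_dim=embed_dim,
        reconstruction_type='pixel')
    neck.init_weights()

    # Fake backbone outputs: the stride-4 and stride-8 levels keep a small
    # spatial window per token, the stride-16 level is a flat token sequence.
    feat_s4 = torch.randn(B, L, 4, 4, embed_dim // 4)
    feat_s8 = torch.randn(B, L, 2, 2, embed_dim // 2)
    latent = torch.randn(B, L, embed_dim)
    ids_restore = torch.stack(
        [torch.randperm(num_patches) for _ in range(B)])

    pred = neck([feat_s4, feat_s8, latent], ids_restore)
    # Every one of the 196 patches is predicted as 16 * 16 * 3 = 768 values.
    assert pred.shape == (B, num_patches, 16**2 * 3)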