T2T_ViT

class mmcls.models.T2T_ViT(img_size=224, in_channels=3, embed_dims=384, num_layers=14, out_indices=-1, drop_rate=0.0, drop_path_rate=0.0, norm_cfg={'type': 'LN'}, final_norm=True, with_cls_token=True, output_cls_token=True, interpolate_mode='bicubic', t2t_cfg={}, layer_cfgs={}, init_cfg=None)

Tokens-to-Token Vision Transformer (T2T-ViT)

A PyTorch implementation of "Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet".

Parameters
  • img_size (int | tuple) – The expected input image shape. Because dynamic input shapes are supported, set this argument to the most common input image shape. Defaults to 224.

  • in_channels (int) – Number of input channels. Defaults to 3.

  • embed_dims (int) – Embedding dimension. Defaults to 384.

  • num_layers (int) – Number of transformer layers in the encoder. Defaults to 14.

  • out_indices (Sequence | int) – Indices of the stages whose outputs are returned. Defaults to -1, which means the last stage.

  • drop_rate (float) – Dropout rate after position embedding. Defaults to 0.

  • drop_path_rate (float) – Stochastic depth rate. Defaults to 0.

  • norm_cfg (dict) – Config dict for normalization layer. Defaults to dict(type='LN').

  • final_norm (bool) – Whether to add an additional layer to normalize the final feature map. Defaults to True.

  • with_cls_token (bool) – Whether to concatenate the class token with the image tokens as transformer input. Defaults to True.

  • output_cls_token (bool) – Whether to output the cls_token. If set to True, with_cls_token must also be True. Defaults to True.

  • interpolate_mode (str) – The interpolation mode used to resize the position embedding vector. Defaults to "bicubic".

  • t2t_cfg (dict) – Extra config of Tokens-to-Token module. Defaults to an empty dict.

  • layer_cfgs (Sequence | dict) – Configs of each transformer layer in encoder. Defaults to an empty dict.

  • init_cfg (dict, optional) – The config for initialization. Defaults to None.
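
As a minimal sketch, the parameters above can be collected into an MMClassification-style backbone config dict. The values simply repeat the defaults from the signature. The helpers `resolve_out_indices` and `check_cls_token_flags` are hypothetical and are not taken from the library source; they only illustrate one plausible reading of two documented rules: a negative `out_indices` entry such as -1 counts back from the last of `num_layers` stages, and `output_cls_token=True` requires `with_cls_token=True`.

```python
from collections.abc import Sequence

# Backbone config that mirrors the documented defaults of T2T_ViT.
backbone = dict(
    type='T2T_ViT',
    img_size=224,
    in_channels=3,
    embed_dims=384,
    num_layers=14,
    out_indices=-1,
    drop_rate=0.0,
    drop_path_rate=0.0,
    norm_cfg=dict(type='LN'),
    final_norm=True,
    with_cls_token=True,
    output_cls_token=True,
    interpolate_mode='bicubic',
)


def resolve_out_indices(out_indices, num_layers):
    """Hypothetical helper: map negative indices to absolute layer indices."""
    if isinstance(out_indices, int):
        out_indices = [out_indices]
    if not isinstance(out_indices, Sequence):
        raise TypeError('out_indices must be an int or a sequence of ints')
    resolved = []
    for index in out_indices:
        # Negative entries count back from the last layer (assumed convention).
        absolute = num_layers + index if index < 0 else index
        if not 0 <= absolute < num_layers:
            raise ValueError(f'out_indices entry {index} is out of range')
        resolved.append(absolute)
    return resolved


def check_cls_token_flags(cfg):
    """Hypothetical helper: enforce the documented cls_token constraint."""
    if cfg.get('output_cls_token', True) and not cfg.get('with_cls_token', True):
        raise ValueError('output_cls_token=True requires with_cls_token=True')


check_cls_token_flags(backbone)
last_stage = resolve_out_indices(backbone['out_indices'], backbone['num_layers'])
```

In a real project, a dict like `backbone` would typically be placed under the model config and built through the library's registry; the sketch stops short of that so it stays runnable without MMClassification installed.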

forward(x)

Forward computation.

Parameters

x (torch.Tensor | tuple[torch.Tensor]) – The input data for the forward computation, given as either a single torch.Tensor or a tuple of torch.Tensor.

init_weights()

Initialize the weights.
