
mmcls.apis

These are some high-level APIs for classification tasks.

Model

mmcls.apis.list_models(pattern=None)

List all models available in MMClassification.

Parameters

pattern (str | None) – A wildcard pattern to match model names.

Returns

A list of model names.

Return type

List[str]

Examples

List all models:

>>> from mmcls import list_models
>>> print(list_models())

List ResNet-50 models on ImageNet-1k dataset:

>>> from mmcls import list_models
>>> print(list_models('resnet*in1k'))
['resnet50_8xb32_in1k',
 'resnet50_8xb32-fp16_in1k',
 'resnet50_8xb256-rsb-a1-600e_in1k',
 'resnet50_8xb256-rsb-a2-300e_in1k',
 'resnet50_8xb256-rsb-a3-100e_in1k']
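The pattern argument behaves like a shell-style wildcard. The sketch below assumes fnmatch-style matching (Python's standard `fnmatch` module) and filters a small stand-in list of names copied from the examples on this page. `filter_models` is a hypothetical helper for illustration, not part of MMClassification.

```python
import fnmatch
from typing import List, Optional

# Stand-in registry: names copied from the examples on this page.
# A real registry would contain many more models.
MODEL_NAMES = [
    'resnet50_8xb32_in1k',
    'resnet50_8xb32-fp16_in1k',
    'resnet50_8xb256-rsb-a1-600e_in1k',
    'resnet50_8xb256-rsb-a2-300e_in1k',
    'resnet50_8xb256-rsb-a3-100e_in1k',
    'swin-base_16xb64_in1k',
]


def filter_models(names: List[str], pattern: Optional[str] = None) -> List[str]:
    """Hypothetical helper: return every name when ``pattern`` is None,
    otherwise only the names matching the shell-style wildcard ``pattern``."""
    if pattern is None:
        return sorted(names)
    # fnmatchcase is case-sensitive on every platform; '*' matches any run
    # of characters, so 'resnet*in1k' requires the prefix 'resnet' and the
    # suffix 'in1k'.
    return [name for name in names if fnmatch.fnmatchcase(name, pattern)]


print(filter_models(MODEL_NAMES, 'resnet*in1k'))  # the five ResNet-50 names
```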

mmcls.apis.get_model(model_name, pretrained=False, device=None, **kwargs)

Get a predefined model by its name.

Parameters
  • model_name (str) – The name of model.

  • pretrained (bool | str) – If True, load the pre-defined pretrained weights. If a string, load the weights from it. Defaults to False.

  • device (str | torch.device | None) – Transfer the model to the target device. Defaults to None.

  • **kwargs – Other keyword arguments of the model config.
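Keyword arguments such as `backbone=dict(out_indices=(0, 1, 2, 3))` in the ResNet-50 example below are meant to override only the keys they name. A minimal sketch of that behaviour, assuming the overrides are merged recursively into the model config (similar in spirit to mmengine's `Config.merge_from_dict`). `merge_overrides` and the simplified config dict are hypothetical.

```python
import copy
from typing import Any, Dict


def merge_overrides(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of ``base`` with ``overrides``
    merged in. Nested dicts are merged key by key; any other value
    replaces the existing one. ``base`` itself is never modified."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_overrides(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Simplified, hypothetical ResNet-50 classifier config.
model_cfg = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet', depth=50, num_stages=4, out_indices=(3, )),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(type='LinearClsHead', num_classes=1000, in_channels=2048),
)

# Equivalent of passing backbone=dict(out_indices=(0, 1, 2, 3)) as **kwargs:
# only out_indices changes, the other backbone keys are kept.
new_cfg = merge_overrides(model_cfg, dict(backbone=dict(out_indices=(0, 1, 2, 3))))
print(new_cfg['backbone'])
```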

Returns

The resulting model.

Return type

mmengine.model.BaseModel

Examples

Get a ResNet-50 model and extract image features:

>>> import torch
>>> from mmcls import get_model
>>> inputs = torch.rand(16, 3, 224, 224)
>>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
>>> feats = model.extract_feat(inputs)
>>> for feat in feats:
...     print(feat.shape)
torch.Size([16, 256])
torch.Size([16, 512])
torch.Size([16, 1024])
torch.Size([16, 2048])
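The printed shapes follow from ResNet-50's layout: its four stages use bottleneck widths 64, 128, 256 and 512 with an expansion factor of 4, giving 256, 512, 1024 and 2048 output channels. The features are 2-D here presumably because `extract_feat` returns them after the config's global-average-pooling neck, which removes the spatial dimensions. The sketch below only reproduces this shape arithmetic for square inputs whose side is divisible by 32; `resnet50_stage_shapes` is a hypothetical helper.

```python
from typing import List, Tuple


def resnet50_stage_shapes(batch: int, img_size: int,
                          pooled: bool = True) -> List[Tuple[int, ...]]:
    """Hypothetical helper: expected output shape of each ResNet-50 stage.

    Assumes a square input whose side is divisible by 32.
    """
    base_widths = [64, 128, 256, 512]  # bottleneck widths per stage
    expansion = 4                      # bottleneck output = width * 4
    stage_strides = [1, 2, 2, 2]       # stage 1 keeps the resolution
    size = img_size // 4               # stem: stride-2 conv + stride-2 max-pool
    shapes = []
    for width, stride in zip(base_widths, stage_strides):
        size //= stride
        channels = width * expansion
        if pooled:
            # Global average pooling collapses H x W to one value per channel.
            shapes.append((batch, channels))
        else:
            shapes.append((batch, channels, size, size))
    return shapes


print(resnet50_stage_shapes(16, 224))
# [(16, 256), (16, 512), (16, 1024), (16, 2048)]
print(resnet50_stage_shapes(16, 224, pooled=False))
# [(16, 256, 56, 56), (16, 512, 28, 28), (16, 1024, 14, 14), (16, 2048, 7, 7)]
```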

Get a Swin-Transformer model with pretrained weights and run inference:

>>> from mmcls import get_model, inference_model
>>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
>>> result = inference_model(model, 'demo/demo.JPEG')
>>> print(result['pred_class'])
sea snake

mmcls.apis.init_model(config, checkpoint=None, device=None, **kwargs)

Initialize a classifier from a config file.

Parameters
  • config (str | mmengine.Config) – Config file path or the config object.

  • checkpoint (str, optional) – Checkpoint path. If left as None, the model will not load any weights.

  • device (str | torch.device | None) – Transfer the model to the target device. Defaults to None.

  • **kwargs – Other keyword arguments of the model config.

Returns

The constructed model.

Return type

nn.Module
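`config` may be either a file path or an already-loaded config object. A minimal sketch of that dispatch, assuming a plain Python config file: the standard library's `runpy.run_path` stands in for `mmengine.Config.fromfile`, and a plain dict stands in for `mmengine.Config`, so features such as `_base_` inheritance are not handled. `load_config` is a hypothetical helper.

```python
import os
import runpy
import tempfile
from pathlib import Path
from typing import Any, Dict, Union


def load_config(config: Union[str, Path, Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical helper: accept a config file path or a loaded config."""
    if isinstance(config, (str, Path)):
        # Execute the Python config file and keep its public top-level names
        # (entries such as __name__ and __builtins__ are dropped).
        namespace = runpy.run_path(str(config))
        return {k: v for k, v in namespace.items() if not k.startswith('_')}
    if isinstance(config, dict):
        # Already loaded: use it as-is.
        return config
    raise TypeError(f'config must be a path or a dict, got {type(config).__name__}')


# Usage: write a tiny config file, load it, then clean up.
with tempfile.NamedTemporaryFile('w', suffix='.py', delete=False) as f:
    f.write("model = dict(type='ImageClassifier', "
            "backbone=dict(type='ResNet', depth=50))\n")
try:
    cfg = load_config(f.name)
finally:
    os.remove(f.name)
print(cfg['model']['backbone']['depth'])  # 50
```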

Inference

mmcls.apis.inference_model(model, img)

Run inference on image(s) with the classifier.

Parameters
  • model (BaseClassifier) – The loaded classifier.

  • img (str | ndarray) – The image filename or the loaded image.

Returns

The classification results, containing pred_class, pred_label and pred_score.

Return type

dict
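The result dict is derived from the classifier's per-class scores. A minimal sketch of that post-processing, assuming the model yields raw logits that are turned into probabilities with a softmax. `scores_to_result`, the logits and the three class names are hypothetical; the `pred_class` key matches the one read in the Swin-Transformer example.

```python
import math
from typing import Dict, List, Sequence, Union


def scores_to_result(logits: Sequence[float],
                     classes: List[str]) -> Dict[str, Union[int, float, str]]:
    """Hypothetical helper: turn one image's logits into a result dict."""
    # Numerically stable softmax: subtract the max logit before exponentiating.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    scores = [e / total for e in exps]
    # The predicted label is the index of the highest score.
    pred_label = max(range(len(scores)), key=scores.__getitem__)
    return {
        'pred_label': pred_label,
        'pred_score': scores[pred_label],
        'pred_class': classes[pred_label],
    }


result = scores_to_result([0.5, 3.0, -1.0], ['goldfish', 'sea snake', 'tabby cat'])
print(result['pred_label'], result['pred_class'])  # 1 sea snake
```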
