mmcls.apis
These are some high-level APIs for classification tasks.
Model
- mmcls.apis.list_models(pattern=None)
List all models available in MMClassification.
- Parameters
pattern (str | None) – A wildcard pattern to match model names. Defaults to None, which lists all models.
- Returns
A list of model names.
- Return type
List[str]
Examples
List all models:
>>> from mmcls import list_models
>>> print(list_models())
List ResNet-50 models on ImageNet-1k dataset:
>>> from mmcls import list_models
>>> print(list_models('resnet*in1k'))
['resnet50_8xb32_in1k', 'resnet50_8xb32-fp16_in1k', 'resnet50_8xb256-rsb-a1-600e_in1k', 'resnet50_8xb256-rsb-a2-300e_in1k', 'resnet50_8xb256-rsb-a3-100e_in1k']
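list_models narrows the registry with a wildcard pattern. The stand-alone sketch here assumes shell-style wildcards as implemented by Python's fnmatch module; filter_models is a hypothetical stand-in, not mmcls code, and the name list is copied from the examples on this page.

```python
import fnmatch
from typing import List, Optional

# Names taken from the examples on this page; the last one should not
# match a 'resnet*' pattern.
names = [
    'resnet50_8xb32_in1k',
    'resnet50_8xb32-fp16_in1k',
    'resnet50_8xb256-rsb-a1-600e_in1k',
    'resnet50_8xb256-rsb-a2-300e_in1k',
    'resnet50_8xb256-rsb-a3-100e_in1k',
    'swin-base_16xb64_in1k',
]


def filter_models(all_names: List[str], pattern: Optional[str] = None) -> List[str]:
    """Hypothetical stand-in for list_models (assumes fnmatch-style wildcards)."""
    if pattern is None:
        # No pattern: return every registered name.
        return sorted(all_names)
    # '*' matches any run of characters, '?' matches exactly one.
    return fnmatch.filter(all_names, pattern)


print(filter_models(names, 'resnet*in1k'))
```

With this assumption, 'resnet*in1k' selects the five ResNet-50 names and leaves out the Swin-Transformer entry.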
- mmcls.apis.get_model(model_name, pretrained=False, device=None, **kwargs)
Get a pre-defined model by its name.
- Parameters
model_name (str) – The name of the model.
pretrained (bool | str) – If True, load the pre-defined pretrained weights. If a string, load the weights from it. Defaults to False.
device (str | torch.device | None) – Transfer the model to the target device. Defaults to None.
**kwargs – Other keyword arguments of the model config.
- Returns
The resulting model.
- Return type
Examples
Get a ResNet-50 model and extract image features:
>>> import torch
>>> from mmcls import get_model
>>> inputs = torch.rand(16, 3, 224, 224)
>>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
>>> feats = model.extract_feat(inputs)
>>> for feat in feats:
...     print(feat.shape)
torch.Size([16, 256])
torch.Size([16, 512])
torch.Size([16, 1024])
torch.Size([16, 2048])
Get a Swin-Transformer model with pre-trained weights and run inference:
>>> from mmcls import get_model, inference_model
>>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
>>> result = inference_model(model, 'demo/demo.JPEG')
>>> print(result['pred_class'])
sea snake
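The pretrained argument of get_model accepts either a bool or a string. As a minimal sketch of those documented semantics, the hypothetical helper resolve_weights maps the argument to the weights that would be loaded. It is not how mmcls is implemented internally, and the default weights name is a placeholder.

```python
from typing import Optional, Union


def resolve_weights(pretrained: Union[bool, str],
                    default_weights: Optional[str]) -> Optional[str]:
    """Hypothetical helper: decide which weights a model should load."""
    if isinstance(pretrained, str):
        # A string is treated as an explicit checkpoint path or URL.
        return pretrained
    if pretrained:
        # True selects the pre-defined weights registered for the model name.
        if default_weights is None:
            raise ValueError('no pre-defined weights for this model')
        return default_weights
    # False: build the model without loading any weights.
    return None


# 'resnet50_8xb32_in1k.pth' is a placeholder, not a real download location.
print(resolve_weights(True, 'resnet50_8xb32_in1k.pth'))
print(resolve_weights('work_dirs/my_ckpt.pth', 'resnet50_8xb32_in1k.pth'))
```

Checking the string case first keeps the bool branch simple, since a non-empty path is also truthy.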
- mmcls.apis.init_model(config, checkpoint=None, device=None, **kwargs)
Initialize a classifier from a config file.
- Parameters
config (str | mmengine.Config) – Config file path or the config object.
checkpoint (str, optional) – Checkpoint path. If left as None, the model will not load any weights.
device (str | torch.device | None) – Transfer the model to the target device. Defaults to None.
**kwargs – Other keyword arguments of the model config.
- Returns
The constructed model.
- Return type
nn.Module
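Both get_model and init_model forward **kwargs to the model config, as in the backbone=dict(out_indices=(0, 1, 2, 3)) override in the ResNet-50 example. The sketch assumes the overrides are merged recursively into nested dicts, similar to what mmengine's Config.merge_from_dict does; merge_cfg and the simplified model_cfg are hypothetical illustrations, not mmcls code.

```python
import copy


def merge_cfg(base: dict, overrides: dict) -> dict:
    """Hypothetical helper: recursively merge keyword overrides into a config."""
    merged = copy.deepcopy(base)  # never mutate the original config
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Nested dicts are merged key by key, so unspecified keys survive.
            merged[key] = merge_cfg(merged[key], value)
        else:
            # Scalars, tuples and new keys replace the old value outright.
            merged[key] = copy.deepcopy(value)
    return merged


# Simplified, illustrative model config (assumed, not the real mmcls file).
model_cfg = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet', depth=50, out_indices=(3,)),
    neck=dict(type='GlobalAveragePooling'),
)
merged = merge_cfg(model_cfg, dict(backbone=dict(out_indices=(0, 1, 2, 3))))
print(merged['backbone'])
```

Under this assumption only out_indices changes, so the backbone would return all four stages, which matches the four feature shapes printed in the ResNet-50 example.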
Inference
- mmcls.apis.inference_model(model, img)
Run inference on image(s) with the classifier.
- Parameters
model (BaseClassifier) – The loaded classifier.
img (str | ndarray) – The image filename or a loaded image.
- Returns
The classification results, containing pred_label, pred_score and pred_class.
- Return type
dict
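inference_model reduces the classifier's per-class scores to a small result dict. A minimal sketch, assuming the scores are already normalized per-class probabilities and the prediction is their argmax; build_result and the three class names are hypothetical, while the keys pred_label, pred_score and pred_class follow this page (pred_class is the key read in the Swin-Transformer example).

```python
from typing import Dict, Optional, Sequence, Union


def build_result(scores: Sequence[float],
                 classes: Optional[Sequence[str]] = None
                 ) -> Dict[str, Union[int, float, str]]:
    """Hypothetical helper: turn per-class scores into a result dict."""
    # The index of the highest score is the predicted label (first one on ties).
    pred_label = max(range(len(scores)), key=lambda i: scores[i])
    result: Dict[str, Union[int, float, str]] = {
        'pred_label': pred_label,
        'pred_score': float(scores[pred_label]),
    }
    if classes is not None:
        # Map the label index to a human-readable class name.
        result['pred_class'] = classes[pred_label]
    return result


scores = [0.05, 0.80, 0.15]                      # made-up scores
classes = ['goldfish', 'sea snake', 'tabby cat']  # made-up class names
print(build_result(scores, classes))
```

Keeping pred_class optional reflects that a class name can only be reported when the model carries class names.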