Note
You are reading the documentation for MMClassification 0.x, which will be deprecated at the end of 2022. We recommend upgrading to MMClassification 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. Check the installation tutorial, migration tutorial, and changelog for more details.
Source code for mmcls.models.losses.accuracy
# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
import numpy as np
import torch
import torch.nn as nn
def accuracy_numpy(pred, target, topk=(1, ), thrs=0.):
    """Calculate top-k accuracy (in percent) from NumPy predictions.

    Args:
        pred (np.ndarray): Prediction scores. The top-k search runs along
            axis 1 and ``pred.shape[0]`` is used as the number of samples.
        target (np.ndarray): Ground-truth label of each sample. It is
            reshaped to a column (``reshape(-1, 1)``) before comparison.
        topk (tuple[int]): The ``k`` values to evaluate. Defaults to (1, ).
        thrs (Number | tuple[Number] | list[Number]): Score thresholds. A
            prediction only counts as correct when its score is strictly
            greater than the threshold. Defaults to 0.

    Returns:
        list: One entry per ``k`` in ``topk``, in the same order. Each entry
        is a single accuracy value if ``thrs`` is a number, otherwise a list
        with one accuracy value per threshold.

    Raises:
        TypeError: If ``thrs`` is neither a number nor a tuple/list.
    """
    if isinstance(thrs, Number):
        thrs = (thrs, )
        res_single = True
    elif isinstance(thrs, (tuple, list)):
        res_single = False
    else:
        raise TypeError(
            f'thrs should be a number or tuple, but got {type(thrs)}.')

    maxk = max(topk)
    num = pred.shape[0]

    # Row-index grid used to gather per-sample entries via fancy indexing.
    static_inds = np.indices((num, maxk))[0]
    # argpartition selects the maxk highest-scoring classes of every row,
    # but leaves them in arbitrary order ...
    pred_label = pred.argpartition(-maxk, axis=1)[:, -maxk:]
    pred_score = pred[static_inds, pred_label]
    # ... so order those maxk candidates by descending score.
    sort_inds = np.argsort(pred_score, axis=1)[:, ::-1]
    pred_label = pred_label[static_inds, sort_inds]
    pred_score = pred_score[static_inds, sort_inds]

    # Neither the label match nor the threshold masks depend on k, so they
    # are computed once here and only sliced inside the loop below.
    correct = pred_label == target.reshape(-1, 1)
    # Only prediction values larger than thr are counted as correct.
    correct_per_thr = [correct & (pred_score > thr) for thr in thrs]

    res = []
    for k in topk:
        # A sample is a hit if any of its first k predictions is correct.
        res_thr = [
            masked[:, :k].any(axis=1).sum() * 100. / num
            for masked in correct_per_thr
        ]
        res.append(res_thr[0] if res_single else res_thr)
    return res
def accuracy_torch(pred, target, topk=(1, ), thrs=0.):
    """Calculate top-k accuracy (in percent) from torch predictions.

    Args:
        pred (torch.Tensor): Prediction scores. It is cast to float, the
            top-k search runs along dim 1 and ``pred.size(0)`` is used as
            the number of samples.
        target (torch.Tensor): Ground-truth label of each sample. It is
            viewed as a single row (``view(1, -1)``) before comparison.
        topk (tuple[int]): The ``k`` values to evaluate. Defaults to (1, ).
        thrs (Number | tuple[Number] | list[Number]): Score thresholds. A
            prediction only counts as correct when its score is strictly
            greater than the threshold. Defaults to 0.

    Returns:
        list: One entry per ``k`` in ``topk``, in the same order. Each entry
        is a one-element tensor if ``thrs`` is a number, otherwise a list
        with one such tensor per threshold.

    Raises:
        TypeError: If ``thrs`` is neither a number nor a tuple/list.
    """
    if isinstance(thrs, Number):
        thrs = (thrs, )
        res_single = True
    elif isinstance(thrs, (tuple, list)):
        res_single = False
    else:
        raise TypeError(
            f'thrs should be a number or tuple, but got {type(thrs)}.')

    maxk = max(topk)
    num = pred.size(0)
    pred = pred.float()
    pred_score, pred_label = pred.topk(maxk, dim=1)
    # Transpose so that the first k rows hold the k best predictions of
    # every sample; slicing ``[:k]`` below then selects the top-k.
    pred_label = pred_label.t()
    pred_score = pred_score.t()
    correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))

    # The threshold masks do not depend on k, so build one per threshold
    # here instead of recomputing them for every k.
    # Only prediction values larger than thr are counted as correct.
    correct_per_thr = [correct & (pred_score > thr) for thr in thrs]

    res = []
    for k in topk:
        res_thr = [
            masked[:k].reshape(-1).float().sum(0, keepdim=True).mul_(
                100. / num) for masked in correct_per_thr
        ]
        res.append(res_thr[0] if res_single else res_thr)
    return res
def accuracy(pred, target, topk=1, thrs=0.):
    """Compute the top-k accuracy of ``pred`` with respect to ``target``.

    NumPy inputs are converted to tensors and the computation is delegated
    to ``accuracy_torch``.

    Args:
        pred (torch.Tensor | np.ndarray): The model prediction.
        target (torch.Tensor | np.ndarray): The target of each prediction.
        topk (int | tuple[int]): A prediction is regarded as correct if the
            target is among its ``topk`` entries. Defaults to 1.
        thrs (Number | tuple[Number], optional): Predictions with scores
            under the thresholds are considered negative. Defaults to 0.

    Returns:
        torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]]:
        Accuracy.

        - torch.Tensor: If both ``topk`` and ``thrs`` are single values.
        - list[torch.Tensor]: If exactly one of ``topk`` and ``thrs`` is a
          tuple.
        - list[list[torch.Tensor]]: If both ``topk`` and ``thrs`` are
          tuples. The first dim is ``topk``, the second dim is ``thrs``.
    """
    assert isinstance(topk, (int, tuple))
    # A bare int means the caller wants the un-nested result for that k.
    return_single = isinstance(topk, int)
    if return_single:
        topk = (topk, )

    assert isinstance(pred, (torch.Tensor, np.ndarray)), (
        f'The pred should be torch.Tensor or np.ndarray '
        f'instead of {type(pred)}.')
    assert isinstance(target, (torch.Tensor, np.ndarray)), (
        f'The target should be torch.Tensor or np.ndarray '
        f'instead of {type(target)}.')

    # torch version is faster in most situations.
    if isinstance(pred, np.ndarray):
        pred = torch.from_numpy(pred)
    if isinstance(target, np.ndarray):
        target = torch.from_numpy(target)

    results = accuracy_torch(pred, target, topk, thrs)
    if return_single:
        return results[0]
    return results
class Accuracy(nn.Module):
    """``nn.Module`` wrapper around :func:`accuracy`.

    The ``topk`` criterion is stored at construction time and passed to
    ``accuracy`` on every forward call.
    """

    def __init__(self, topk=(1, )):
        """Module to calculate the accuracy.

        Args:
            topk (tuple): The criterion used to calculate the
                accuracy. Defaults to (1,).
        """
        super().__init__()
        self.topk = topk

    def forward(self, pred, target):
        """Forward function to calculate accuracy.

        Args:
            pred (torch.Tensor): Prediction of models.
            target (torch.Tensor): Target for each prediction.

        Returns:
            list[torch.Tensor]: The accuracies under different topk
            criterions.
        """
        return accuracy(pred, target, self.topk)