Source code for mmcls.models.losses.seesaw_loss
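Before the listing, a brief reminder of the formulation, reconstructed from the referenced paper (Wang et al., CVPR 2021) rather than from this file: for a sample with ground-truth class i and logits z, seesaw loss rescales the contribution of every negative class j by a factor S_ij before applying the standard softmax cross entropy:

    \mathcal{L}_{\text{seesaw}}(z, i) = -\log \frac{e^{z_i}}{\sum_{j \neq i} S_{ij}\, e^{z_j} + e^{z_i}},
    \qquad S_{ij} = M_{ij}\, C_{ij}

where M_ij is the mitigation factor controlled by ``p`` and C_ij is the compensation factor controlled by ``q``. The code below implements this equivalently by adding log(S_ij) to each negative logit, since e^{z_j + \log S_{ij}} = S_{ij} e^{z_j}.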
# Copyright (c) OpenMMLab. All rights reserved.
# Migrated from mmdetection with modifications.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss
def seesaw_ce_loss(cls_score,
                   labels,
                   weight,
                   cum_samples,
                   num_classes,
                   p,
                   q,
                   eps,
                   reduction='mean',
                   avg_factor=None):
"""Calculate the Seesaw CrossEntropy loss.
Args:
cls_score (torch.Tensor): The prediction with shape (N, C),
C is the number of classes.
labels (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor): Sample-wise loss weight.
cum_samples (torch.Tensor): Cumulative samples for each category.
num_classes (int): The number of classes.
p (float): The ``p`` in the mitigation factor.
q (float): The ``q`` in the compenstation factor.
eps (float): The minimal value of divisor to smooth
the computation of compensation factor
reduction (str, optional): The method used to reduce the loss.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
Returns:
torch.Tensor: The calculated loss
"""
    assert cls_score.size(-1) == num_classes
    assert len(cum_samples) == num_classes

    onehot_labels = F.one_hot(labels, num_classes)
    seesaw_weights = cls_score.new_ones(onehot_labels.size())

    # mitigation factor: down-weight negative classes that have fewer
    # accumulated samples than the ground-truth class
    if p > 0:
        sample_ratio_matrix = cum_samples[None, :].clamp(
            min=1) / cum_samples[:, None].clamp(min=1)
        index = (sample_ratio_matrix < 1.0).float()
        sample_weights = sample_ratio_matrix.pow(p) * index + (
            1 - index)  # M_{ij}
        mitigation_factor = sample_weights[labels.long(), :]
        seesaw_weights = seesaw_weights * mitigation_factor

    # compensation factor: re-amplify negative classes whose predicted
    # score exceeds that of the ground-truth class
    if q > 0:
        scores = F.softmax(cls_score.detach(), dim=1)
        self_scores = scores[
            torch.arange(0, len(scores)).to(scores.device).long(),
            labels.long()]
        score_matrix = scores / self_scores[:, None].clamp(min=eps)
        index = (score_matrix > 1.0).float()
        compensation_factor = score_matrix.pow(q) * index + (1 - index)
        seesaw_weights = seesaw_weights * compensation_factor

    # adding log(S_{ij}) to a negative logit rescales its exponent by S_{ij}
    cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels))

    loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none')

    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
    return loss
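A minimal sketch of calling ``seesaw_ce_loss`` directly, assuming this module is importable (e.g. ``from mmcls.models.losses.seesaw_loss import seesaw_ce_loss``); all shapes and hyper-parameter values below are illustrative only:

import torch

# 4 samples, 3 classes; a long-tailed count vector for ``cum_samples``
cls_score = torch.randn(4, 3)
labels = torch.tensor([0, 1, 2, 1])
weight = torch.ones(4)
cum_samples = torch.tensor([100., 10., 1.])

loss = seesaw_ce_loss(cls_score, labels, weight, cum_samples,
                      num_classes=3, p=0.8, q=2.0, eps=1e-2)
print(loss)  # 0-dim tensor, since reduction defaults to 'mean'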
@LOSSES.register_module()
class SeesawLoss(nn.Module):
    """Implementation of seesaw loss.

    Refers to `Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021)
    <https://arxiv.org/abs/2008.10032>`_

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid or softmax.
            Only False is supported. Defaults to False.
        p (float): The ``p`` in the mitigation factor.
            Defaults to 0.8.
        q (float): The ``q`` in the compensation factor.
            Defaults to 2.0.
        num_classes (int): The number of classes.
            Defaults to 1000 for the ImageNet dataset.
        eps (float): The minimal value of divisor to smooth
            the computation of the compensation factor. Defaults to 1e-2.
        reduction (str): The method that reduces the loss to a scalar.
            Options are "none", "mean" and "sum". Defaults to "mean".
        loss_weight (float): The weight of the loss. Defaults to 1.0.
    """
    def __init__(self,
                 use_sigmoid=False,
                 p=0.8,
                 q=2.0,
                 num_classes=1000,
                 eps=1e-2,
                 reduction='mean',
                 loss_weight=1.0):
        super(SeesawLoss, self).__init__()
        assert not use_sigmoid, '`use_sigmoid` is not supported'
        self.use_sigmoid = False
        self.p = p
        self.q = q
        self.num_classes = num_classes
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.cls_criterion = seesaw_ce_loss

        # cumulative samples for each category
        self.register_buffer(
            'cum_samples', torch.zeros(self.num_classes, dtype=torch.float))
    def forward(self,
                cls_score,
                labels,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The prediction with shape (N, C).
            labels (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight.
                Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce the
                loss. Options are "none", "mean" and "sum". Defaults to None.

        Returns:
            torch.Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum'), \
            f'The `reduction_override` should be one of (None, "none", ' \
            f'"mean", "sum"), but got "{reduction_override}".'
        assert cls_score.size(0) == labels.view(-1).size(0), \
            f'Expected `labels` shape [{cls_score.size(0)}], ' \
            f'but got {list(labels.size())}'
        reduction = (
            reduction_override if reduction_override else self.reduction)
        assert cls_score.size(-1) == self.num_classes, \
            f'The channel number of output ({cls_score.size(-1)}) does ' \
            f'not match the `num_classes` of seesaw loss ({self.num_classes}).'

        # accumulate the samples for each category
        unique_labels = labels.unique()
        for u_l in unique_labels:
            inds_ = labels == u_l.item()
            self.cum_samples[u_l] += inds_.sum()

        if weight is not None:
            weight = weight.float()
        else:
            weight = labels.new_ones(labels.size(), dtype=torch.float)

        # calculate the seesaw cross-entropy loss
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score, labels, weight, self.cum_samples, self.num_classes,
            self.p, self.q, self.eps, reduction, avg_factor)
        return loss_cls
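A minimal usage sketch of the registered class, assuming MMClassification 0.x is installed so that ``SeesawLoss`` can be imported; the head that would normally consume this loss is outside this file:

import torch
from mmcls.models.losses import SeesawLoss

criterion = SeesawLoss(p=0.8, q=2.0, num_classes=5)
cls_score = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))

loss = criterion(cls_score, labels)
print(loss)                   # scalar loss (reduction='mean')
print(criterion.cum_samples)  # per-class counts, updated every forward pass

Note that ``cum_samples`` is a registered buffer: it keeps growing across forward calls and is saved with the model checkpoint. In a config file the loss would typically be declared as a dict, e.g. ``loss=dict(type='SeesawLoss', p=0.8, q=2.0, num_classes=5)``, and built through the ``LOSSES`` registry.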