Shortcuts

Note

You are reading the documentation for MMClassification 0.x, which will be deprecated at the end of 2022. We recommend upgrading to MMClassification 1.0 to enjoy the new features and better performance brought by OpenMMLab 2.0. See the installation tutorial, migration tutorial, and changelog for more details.

Source code for mmcls.models.losses.seesaw_loss

# Copyright (c) OpenMMLab. All rights reserved.
# migrate from mmdetection with modifications
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss


def seesaw_ce_loss(cls_score,
                   labels,
                   weight,
                   cum_samples,
                   num_classes,
                   p,
                   q,
                   eps,
                   reduction='mean',
                   avg_factor=None):
    """Calculate the Seesaw CrossEntropy loss.

    For a sample of class ``i``, the logit of every other class ``j`` is
    shifted by ``log(S_ij)`` before a standard softmax cross-entropy, where
    ``S_ij = M_ij * C_ij`` is the product of the mitigation factor ``M`` and
    the compensation factor ``C``. The ground-truth logit is left unchanged.

    Args:
        cls_score (torch.Tensor): The prediction with shape (N, C),
            C is the number of classes.
        labels (torch.Tensor): The learning label of the prediction, integer
            class indices with shape (N, ). Converted to int64 internally.
        weight (torch.Tensor, optional): Sample-wise loss weight. When not
            None it is cast to float; it is passed to ``weight_reduce_loss``
            as-is otherwise.
        cum_samples (torch.Tensor): Cumulative samples for each category,
            with shape (C, ).
        num_classes (int): The number of classes.
        p (float): The ``p`` in the mitigation factor. The mitigation factor
            is skipped when ``p <= 0``.
        q (float): The ``q`` in the compensation factor. The compensation
            factor is skipped when ``q <= 0``.
        eps (float): The minimal value of divisor to smooth
            the computation of compensation factor.
        reduction (str, optional): The method used to reduce the loss.
            Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The calculated loss
    """
    assert cls_score.size(-1) == num_classes
    assert len(cum_samples) == num_classes

    # F.one_hot and F.cross_entropy both require int64 class indices, so
    # cast once here rather than only at the indexing sites.
    labels = labels.long()
    onehot_labels = F.one_hot(labels, num_classes)
    seesaw_weights = cls_score.new_ones(onehot_labels.size())

    # mitigation factor: M_ij = (N_j / N_i) ** p when N_j < N_i, else 1,
    # where N_k is the cumulative sample count of class k clamped to >= 1.
    if p > 0:
        counts = cum_samples.clamp(min=1)
        sample_ratio_matrix = counts[None, :] / counts[:, None]
        sample_weights = torch.where(sample_ratio_matrix < 1.0,
                                     sample_ratio_matrix.pow(p),
                                     torch.ones_like(sample_ratio_matrix))
        mitigation_factor = sample_weights[labels, :]
        seesaw_weights = seesaw_weights * mitigation_factor

    # compensation factor: C_ij = (s_j / s_i) ** q when s_j > s_i, else 1,
    # where s is the softmax probability and i is the sample's label. The
    # scores are detached, so no gradient flows through this factor.
    if q > 0:
        scores = F.softmax(cls_score.detach(), dim=1)
        self_scores = scores[torch.arange(len(scores), device=scores.device),
                             labels]
        score_matrix = scores / self_scores[:, None].clamp(min=eps)
        compensation_factor = torch.where(score_matrix > 1.0,
                                          score_matrix.pow(q),
                                          torch.ones_like(score_matrix))
        seesaw_weights = seesaw_weights * compensation_factor

    # Shift only the negative-class logits by log(S_ij); the ground-truth
    # logit is multiplied by (1 - 1) = 0 and stays untouched.
    cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels))

    loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none')

    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
    return loss


@LOSSES.register_module()
class SeesawLoss(nn.Module):
    """Implementation of seesaw loss.

    Refers to `Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021)
    <https://arxiv.org/abs/2008.10032>`_

    The module keeps a running per-category sample count in the
    ``cum_samples`` buffer. The count is updated on every call to
    :meth:`forward` and feeds the mitigation factor.

    Args:
        use_sigmoid (bool): Whether the prediction uses sigmoid or softmax.
            Only False is supported. Defaults to False.
        p (float): The ``p`` in the mitigation factor.
            Defaults to 0.8.
        q (float): The ``q`` in the compensation factor.
            Defaults to 2.0.
        num_classes (int): The number of classes.
            Defaults to 1000 for the ImageNet dataset.
        eps (float): The minimal value of divisor to smooth
            the computation of compensation factor. Defaults to 1e-2.
        reduction (str): The method that reduces the loss to a scalar.
            Options are "none", "mean" and "sum". Defaults to "mean".
        loss_weight (float): The weight of the loss. Defaults to 1.0.
    """

    def __init__(self,
                 use_sigmoid=False,
                 p=0.8,
                 q=2.0,
                 num_classes=1000,
                 eps=1e-2,
                 reduction='mean',
                 loss_weight=1.0):
        super(SeesawLoss, self).__init__()
        assert not use_sigmoid, '`use_sigmoid` is not supported'
        self.use_sigmoid = False
        self.p = p
        self.q = q
        self.num_classes = num_classes
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

        self.cls_criterion = seesaw_ce_loss

        # cumulative samples for each category; registered as a buffer so it
        # is part of the state dict and follows the module across devices.
        self.register_buffer('cum_samples',
                             torch.zeros(self.num_classes, dtype=torch.float))

    def forward(self,
                cls_score,
                labels,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Also adds the per-category counts of ``labels`` to ``cum_samples``.

        Args:
            cls_score (torch.Tensor): The prediction with shape (N, C).
            labels (torch.Tensor): The learning label of the prediction.
                Any shape with N elements (e.g. (N, ) or (N, 1)) is accepted
                and flattened to (N, ).
            weight (torch.Tensor, optional): Sample-wise loss weight.
                Defaults to None, in which case every sample gets weight 1.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce the
                loss, overriding ``self.reduction`` when given. Options are
                "none", "mean" and "sum". Defaults to None.

        Returns:
            torch.Tensor: The calculated loss
        """
        assert reduction_override in (None, 'none', 'mean', 'sum'), \
            f'The `reduction_override` should be one of (None, "none", ' \
            f'"mean", "sum"), but get "{reduction_override}".'
        assert cls_score.size(0) == labels.view(-1).size(0), \
            f'Expected `labels` shape [{cls_score.size(0)}], ' \
            f'but got {list(labels.size())}'
        reduction = (
            reduction_override if reduction_override else self.reduction)
        assert cls_score.size(-1) == self.num_classes, \
            f'The channel number of output ({cls_score.size(-1)}) does ' \
            f'not match the `num_classes` of seesaw loss ({self.num_classes}).'

        # The size check above accepts any N-element labels (e.g. (N, 1));
        # flatten so one-hot encoding and the default weight are both (N, ).
        labels = labels.view(-1)

        # accumulate the samples for each category with one vectorized
        # histogram instead of a per-class loop that syncs with the host.
        counts = torch.bincount(labels.long(), minlength=self.num_classes)
        self.cum_samples += counts.to(self.cum_samples.dtype)

        if weight is not None:
            weight = weight.float()
        else:
            weight = labels.new_ones(labels.size(), dtype=torch.float)

        # calculate loss_cls_classes
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score, labels, weight, self.cum_samples, self.num_classes,
            self.p, self.q, self.eps, reduction, avg_factor)
        return loss_cls
Read the Docs v: latest
Versions
master
latest
1.x
dev-1.x
Downloads
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.