Note

You are reading the documentation for MMClassification 0.x, which will be deprecated at the end of 2022. We recommend upgrading to MMClassification 1.0 to benefit from the many new features and the better performance brought by OpenMMLab 2.0. Check the installation tutorial, migration tutorial, and changelog for more details.

# Source code for mmcls.models.losses.cross_entropy_loss

```
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
from .utils import weight_reduce_loss
def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None):
    """Compute the cross-entropy loss for a batch of predictions.

    The per-sample loss is obtained from ``F.cross_entropy`` with
    ``reduction='none'``; sample weighting and the final reduction are then
    delegated to ``weight_reduce_loss``.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), where C is
            the number of classes.
        label (torch.Tensor): The ground-truth label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight. It is
            cast to float before being applied. Defaults to None.
        reduction (str): The method used to reduce the loss.
            Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): The weight for each class
            with shape (C), where C is the number of classes.
            Defaults to None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # Keep the loss unreduced here; the reduction happens below.
    per_sample_loss = F.cross_entropy(
        pred, label, weight=class_weight, reduction='none')

    # Sample weights are applied as floats.
    sample_weight = weight.float() if weight is not None else None

    return weight_reduce_loss(
        per_sample_loss,
        weight=sample_weight,
        reduction=reduction,
        avg_factor=avg_factor)
def soft_cross_entropy(pred,
                       label,
                       weight=None,
                       reduction='mean',
                       class_weight=None,
                       avg_factor=None):
    """Compute the cross-entropy loss against soft (float) labels.

    The loss of each sample is ``sum_c(-label_c * log_softmax(pred)_c)``,
    optionally scaled per class by ``class_weight`` before the sum over the
    last dimension.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), where C is
            the number of classes.
        label (torch.Tensor): The ground-truth label of the prediction with
            shape (N, C). When using "mixup", the label can be float.
        weight (torch.Tensor, optional): Sample-wise loss weight. It is
            cast to float before being applied. Defaults to None.
        reduction (str): The method used to reduce the loss.
            Defaults to 'mean'.
        class_weight (torch.Tensor, optional): The weight for each class
            with shape (C), where C is the number of classes.
            Defaults to None.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # Per-class contribution of every sample.
    negated_label = -label
    log_probs = F.log_softmax(pred, dim=-1)
    per_class_loss = negated_label * log_probs

    if class_weight is not None:
        per_class_loss *= class_weight

    # Collapse the class dimension to get one loss value per sample.
    per_sample_loss = per_class_loss.sum(dim=-1)

    sample_weight = weight.float() if weight is not None else None

    return weight_reduce_loss(
        per_sample_loss,
        weight=sample_weight,
        reduction=reduction,
        avg_factor=avg_factor)
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         pos_weight=None):
    r"""Compute the binary cross-entropy loss from logits.

    Args:
        pred (torch.Tensor): The prediction with shape (N, \*).
        label (torch.Tensor): The ground-truth label with shape (N, \*).
            It must have the same number of dimensions as ``pred``.
        weight (torch.Tensor, optional): Sample-wise weight of the loss
            with shape (N, ). It must be one-dimensional.
            Defaults to None.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". If reduction is 'none',
            the loss has the same shape as pred and label.
            Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): The weight for each class
            with shape (C), where C is the number of classes.
            Defaults to None.
        pos_weight (torch.Tensor, optional): The positive weight for each
            class with shape (C), where C is the number of classes.
            Defaults to None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    assert pred.dim() == label.dim()

    # Tile class_weight to one row per sample so that its size is
    # consistent with pred and label, avoiding automatic broadcast.
    if class_weight is not None:
        batch_size = pred.size(0)
        class_weight = class_weight.repeat(batch_size, 1)

    elementwise_loss = F.binary_cross_entropy_with_logits(
        pred,
        label,
        weight=class_weight,
        pos_weight=pos_weight,
        reduction='none')

    # Apply the sample weights and do the reduction.
    if weight is not None:
        assert weight.dim() == 1
        weight = weight.float()
        if pred.dim() > 1:
            # Turn the (N, ) weight into an (N, 1) column so that it lines
            # up with the first dimension of the loss.
            weight = weight.reshape(-1, 1)

    return weight_reduce_loss(
        elementwise_loss,
        weight=weight,
        reduction=reduction,
        avg_factor=avg_factor)
@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):
    """Cross entropy loss.

    Depending on the flags, the loss is computed by ``binary_cross_entropy``
    (``use_sigmoid=True``), ``soft_cross_entropy`` (``use_soft=True``) or
    ``cross_entropy`` (neither flag set).

    Args:
        use_sigmoid (bool): Whether to use the sigmoid-based binary cross
            entropy instead of the softmax-based loss. Defaults to False.
        use_soft (bool): Whether to use the soft version of
            CrossEntropyLoss. Cannot be combined with ``use_sigmoid``.
            Defaults to False.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        class_weight (List[float], optional): The weight for each class with
            shape (C), C is the number of classes. Defaults to None.
        pos_weight (List[float], optional): The positive weight for each
            class with shape (C), C is the number of classes. Only enabled
            in BCE loss when ``use_sigmoid`` is True. Defaults to None.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_soft=False,
                 reduction='mean',
                 loss_weight=1.0,
                 class_weight=None,
                 pos_weight=None):
        super(CrossEntropyLoss, self).__init__()
        self.use_sigmoid = use_sigmoid
        self.use_soft = use_soft
        assert not (
            self.use_soft and self.use_sigmoid
        ), 'use_sigmoid and use_soft could not be set simultaneously'

        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.pos_weight = pos_weight

        # Pick the functional criterion once, at construction time.
        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_soft:
            self.cls_criterion = soft_cross_entropy
        else:
            self.cls_criterion = cross_entropy

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the weighted classification loss.

        Args:
            cls_score (torch.Tensor): The prediction passed to the selected
                criterion.
            label (torch.Tensor): The ground-truth label passed to the
                selected criterion.
            weight (torch.Tensor, optional): Sample-wise loss weight.
                Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): If given, it replaces the
                ``reduction`` configured at construction for this call.
                Options are None, "none", "mean" and "sum".
                Defaults to None.
            **kwargs: Extra keyword arguments forwarded to the selected
                criterion.

        Returns:
            torch.Tensor: The loss returned by the criterion, multiplied by
            ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        # Build the weight tensors from cls_score so that they follow its
        # dtype and device.
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None

        # Only the BCE criterion accepts pos_weight, so it is forwarded
        # solely when use_sigmoid is enabled.
        if self.pos_weight is not None and self.use_sigmoid:
            kwargs['pos_weight'] = cls_score.new_tensor(self.pos_weight)

        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_cls
```