NeuronBlocks/losses/FocalLoss.py

66 lines
2.3 KiB
Python

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    """Focal loss: FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t).

    reference: Lin T Y, Goyal P, Girshick R, et al. Focal loss for dense object
    detection[J]. arXiv preprint arXiv:1708.02002, 2017.

    All options are passed as keyword arguments; every keyword is stored on the
    instance as an attribute before validation.

    Args:
        gamma (float): focusing parameter, gamma >= 0. Default 0, which reduces
            the loss to (alpha-weighted) cross entropy.
        alpha (float, list/tuple of float, or None): class weighting.
            * float in [0, 1]: binary weighting; class 0 is weighted by ``alpha``
              and class 1 by ``1 - alpha``, so targets must be 0 or 1.
            * list/tuple of floats in [0, 1]: one weight per class (index i is
              the weight of class i), usable for any number of classes.
            * None: no class weighting.
            Default 0.5.
        size_average (bool, optional): if True (default) the per-sample losses
            are averaged over the minibatch; if False they are summed.
    """

    def __init__(self, **kwargs):
        super(FocalLoss, self).__init__()
        # default parameters
        self.gamma = 0
        self.alpha = 0.5
        self.size_average = True
        for key in kwargs:
            setattr(self, key, kwargs[key])
        # verification -- None means "no class weighting" and must skip the
        # range check (comparing None with a number raises TypeError).
        if self.alpha is None:
            return
        if isinstance(self.alpha, (list, tuple)):
            assert all(0 <= w <= 1 for w in self.alpha), \
                "The parameter alpha in Focal Loss must be in range [0, 1]."
            self.alpha = torch.Tensor(list(self.alpha))
        else:
            assert self.alpha <= 1 and self.alpha >= 0, \
                "The parameter alpha in Focal Loss must be in range [0, 1]."
            # binary table: class 0 -> alpha, class 1 -> 1 - alpha
            self.alpha = torch.Tensor([self.alpha, 1 - self.alpha])

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input (Tensor): unnormalized class scores (logits) of shape [N, C],
                or [N, C, d1, d2, ...]; extra dimensions are flattened so every
                position is scored as its own sample.
            target (Tensor): int64 class indices, one per sample, i.e. N elements
                (or N * d1 * d2 * ... for higher-rank input). Not one-hot.

        Returns:
            Tensor: scalar loss, averaged over samples if ``size_average`` is
            True, summed otherwise.
        """
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
        # reshape (not view) also accepts a non-contiguous target
        target = target.reshape(-1, 1)

        # log p_t: log-probability assigned to each sample's target class.
        # dim is explicit; log_softmax's implicit-dim form is deprecated.
        logpt = F.log_softmax(input, dim=-1).gather(1, target).view(-1)
        # p_t is intentionally NOT detached, so the modulating factor
        # (1 - p_t) ** gamma contributes to the gradient as the formula says.
        # For gamma == 0 the gradient is unchanged. NOTE: for 0 < gamma < 1 the
        # factor's derivative is unbounded as p_t -> 1.
        pt = logpt.exp()

        if self.alpha is not None:
            # Match the input's device/dtype locally; .to() is a no-op when
            # nothing changes and, unlike type_as, respects the device index.
            alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            # alpha_t: weight of each sample's target class
            at = alpha.gather(0, target.view(-1))
            logpt = logpt * at
        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()