ENH more stable gradient of CrossEntropy (#6327)

* ENH more stable gradient of CrossEntropy

* FIX missing }

* FIX index score

* FIX missing parenthesis in hessian

---------

Co-authored-by: shiyu1994 <shiyu_k1994@qq.com>
Christian Lorentzen 2024-02-22 04:26:09 +01:00 committed by GitHub
Parent 1b792e7166
Commit 894066dbcd
No key found matching this signature
GPG key ID: B5690EEEBB952194
1 changed file with 39 additions and 6 deletions


@@ -75,21 +75,54 @@ class CrossEntropy: public ObjectiveFunction {
   }
 
   void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
+    // z = expit(score) = 1 / (1 + exp(-score))
+    // gradient = z - label = expit(score) - label
+    // Numerically more stable, see http://fa.bianp.net/blog/2019/evaluate_logistic/
+    //   if score < 0:
+    //       exp_tmp = exp(score)
+    //       return ((1 - label) * exp_tmp - label) / (1 + exp_tmp)
+    //   else:
+    //       exp_tmp = exp(-score)
+    //       return ((1 - label) - label * exp_tmp) / (1 + exp_tmp)
+    // Note that optimal speed would be achieved, at the cost of precision, by
+    //     return expit(score) - y_true
+    // i.e. no "if else" and a custom inline implementation of expit.
+    // The case distinction score < 0 in the stable implementation does not
+    // provide significantly better precision, apart from protecting against overflow of exp(..).
+    // The branch (if else), however, can incur runtime costs of up to 30%.
+    // Instead, we help branch prediction by almost always ending in the first if clause
+    // and making the second branch (else) a bit simpler. This has the exact same
+    // precision but is faster than the stable implementation.
+    // As branching criterion, we use the same cutoff as in log1pexp, see link above.
+    // Note that the maximal value to get gradient = -1 with label = 1 is -37.439198610162731
+    // (based on mpmath), and scipy.special.logit(np.finfo(float).eps) ~ -36.04365.
     if (weights_ == nullptr) {
       // compute pointwise gradients and Hessians with implied unit weights
       #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
       for (data_size_t i = 0; i < num_data_; ++i) {
-        const double z = 1.0f / (1.0f + std::exp(-score[i]));
-        gradients[i] = static_cast<score_t>(z - label_[i]);
-        hessians[i] = static_cast<score_t>(z * (1.0f - z));
+        if (score[i] > -37.0) {
+          const double exp_tmp = std::exp(-score[i]);
+          gradients[i] = static_cast<score_t>(((1.0f - label_[i]) - label_[i] * exp_tmp) / (1.0f + exp_tmp));
+          hessians[i] = static_cast<score_t>(exp_tmp / ((1 + exp_tmp) * (1 + exp_tmp)));
+        } else {
+          const double exp_tmp = std::exp(score[i]);
+          gradients[i] = static_cast<score_t>(exp_tmp - label_[i]);
+          hessians[i] = static_cast<score_t>(exp_tmp);
+        }
       }
     } else {
       // compute pointwise gradients and Hessians with given weights
       #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
       for (data_size_t i = 0; i < num_data_; ++i) {
-        const double z = 1.0f / (1.0f + std::exp(-score[i]));
-        gradients[i] = static_cast<score_t>((z - label_[i]) * weights_[i]);
-        hessians[i] = static_cast<score_t>(z * (1.0f - z) * weights_[i]);
+        if (score[i] > -37.0) {
+          const double exp_tmp = std::exp(-score[i]);
+          gradients[i] = static_cast<score_t>(((1.0f - label_[i]) - label_[i] * exp_tmp) / (1.0f + exp_tmp) * weights_[i]);
+          hessians[i] = static_cast<score_t>(exp_tmp / ((1 + exp_tmp) * (1 + exp_tmp)) * weights_[i]);
+        } else {
+          const double exp_tmp = std::exp(score[i]);
+          gradients[i] = static_cast<score_t>((exp_tmp - label_[i]) * weights_[i]);
+          hessians[i] = static_cast<score_t>(exp_tmp * weights_[i]);
+        }
       }
     }
   }
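
For reference, the identities behind the patch (standard logistic-loss algebra, not text from the commit itself): with raw score s, label y, and z = expit(s), the per-sample gradient and Hessian of the cross-entropy loss are

    \[
    z = \frac{1}{1 + e^{-s}}, \qquad
    \frac{\partial \ell}{\partial s} = z - y = \frac{(1 - y) - y\,e^{-s}}{1 + e^{-s}}, \qquad
    \frac{\partial^2 \ell}{\partial s^2} = z\,(1 - z) = \frac{e^{-s}}{(1 + e^{-s})^2}.
    \]

For s <= -37, exp(-s) would be enormous while exp(s) <= exp(-37) ~ 8.5e-17 is below double-precision epsilon relative to 1, so 1 + exp(s) rounds to 1.0 and the else branch can use the limit forms z - y ~ exp(s) - y and z(1 - z) ~ exp(s).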
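
A minimal standalone sketch (not part of the commit; the function names are illustrative) of what the rearrangement buys: with label = 1 and a large positive score, the naive expit(score) - label cancels to exactly 0.0, while the rearranged formula keeps the tiny true gradient.

    #include <cmath>
    #include <cstdio>

    // Naive gradient: expit(score) - label. Once expit(score) rounds to
    // exactly 1.0, the subtraction cancels and the gradient is lost.
    double naive_gradient(double score, double label) {
      const double z = 1.0 / (1.0 + std::exp(-score));
      return z - label;
    }

    // Stable gradient, mirroring the patched branches.
    double stable_gradient(double score, double label) {
      if (score > -37.0) {
        const double exp_tmp = std::exp(-score);
        return ((1.0 - label) - label * exp_tmp) / (1.0 + exp_tmp);
      } else {
        // exp(-score) would dwarf 1 here; use the limit form instead.
        return std::exp(score) - label;
      }
    }

    int main() {
      const double score = 40.0;
      const double label = 1.0;
      std::printf("naive  = %.17g\n", naive_gradient(score, label));   // prints 0
      std::printf("stable = %.17g\n", stable_gradient(score, label));  // ~ -4.25e-18
      return 0;
    }

The same rearrangement carries over to the weighted branch, and writing the Hessian as exp(-s) / (1 + exp(-s))^2 reuses the already-computed exp_tmp instead of forming z and 1 - z separately.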