From 894066dbcd7a912da13dd4243ad74b7d992f3d93 Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Thu, 22 Feb 2024 04:26:09 +0100
Subject: [PATCH] ENH more stable gradient of CrossEntropy (#6327)

* ENH more stable gradient of CrossEntropy

* FIX missing }

* FIX index score

* FIX missing parenthesis in hessian

---------

Co-authored-by: shiyu1994
---
 src/objective/xentropy_objective.hpp | 45 ++++++++++++++++++++++++----
 1 file changed, 39 insertions(+), 6 deletions(-)

diff --git a/src/objective/xentropy_objective.hpp b/src/objective/xentropy_objective.hpp
index 8edb09506..234e273c4 100644
--- a/src/objective/xentropy_objective.hpp
+++ b/src/objective/xentropy_objective.hpp
@@ -75,21 +75,54 @@ class CrossEntropy: public ObjectiveFunction {
   }
 
   void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
+    // z = expit(score) = 1 / (1 + exp(-score))
+    // gradient = z - label = expit(score) - label
+    // Numerically more stable, see http://fa.bianp.net/blog/2019/evaluate_logistic/
+    //   if score < 0:
+    //     exp_tmp = exp(score)
+    //     return ((1 - label) * exp_tmp - label) / (1 + exp_tmp)
+    //   else:
+    //     exp_tmp = exp(-score)
+    //     return ((1 - label) - label * exp_tmp) / (1 + exp_tmp)
+    // Note that optimal speed would be achieved, at the cost of precision, by
+    //   return expit(score) - y_true
+    // i.e. no "if else" and a custom inline implementation of expit.
+    // The case distinction score < 0 in the stable implementation does not
+    // provide significantly better precision apart from protecting against overflow of exp(..).
+    // The branch (if else), however, can incur runtime costs of up to 30%.
+    // Instead, we help branch prediction by almost always ending in the first if clause
+    // and making the second branch (else) a bit simpler. This has the exact same
+    // precision but is faster than the stable implementation.
+    // As branching criterion, we use the same cutoff as in log1pexp, see link above.
+    // Note that the maximal value to get gradient = -1 with label = 1 is -37.439198610162731
+    // (based on mpmath), and scipy.special.logit(np.finfo(float).eps) ~ -36.04365.
     if (weights_ == nullptr) {
       // compute pointwise gradients and Hessians with implied unit weights
       #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
       for (data_size_t i = 0; i < num_data_; ++i) {
-        const double z = 1.0f / (1.0f + std::exp(-score[i]));
-        gradients[i] = static_cast<score_t>(z - label_[i]);
-        hessians[i] = static_cast<score_t>(z * (1.0f - z));
+        if (score[i] > -37.0) {
+          const double exp_tmp = std::exp(-score[i]);
+          gradients[i] = static_cast<score_t>(((1.0f - label_[i]) - label_[i] * exp_tmp) / (1.0f + exp_tmp));
+          hessians[i] = static_cast<score_t>(exp_tmp / ((1 + exp_tmp) * (1 + exp_tmp)));
+        } else {
+          const double exp_tmp = std::exp(score[i]);
+          gradients[i] = static_cast<score_t>(exp_tmp - label_[i]);
+          hessians[i] = static_cast<score_t>(exp_tmp);
+        }
       }
     } else {
       // compute pointwise gradients and Hessians with given weights
       #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
       for (data_size_t i = 0; i < num_data_; ++i) {
-        const double z = 1.0f / (1.0f + std::exp(-score[i]));
-        gradients[i] = static_cast<score_t>((z - label_[i]) * weights_[i]);
-        hessians[i] = static_cast<score_t>(z * (1.0f - z) * weights_[i]);
+        if (score[i] > -37.0) {
+          const double exp_tmp = std::exp(-score[i]);
+          gradients[i] = static_cast<score_t>(((1.0f - label_[i]) - label_[i] * exp_tmp) / (1.0f + exp_tmp) * weights_[i]);
+          hessians[i] = static_cast<score_t>(exp_tmp / ((1 + exp_tmp) * (1 + exp_tmp)) * weights_[i]);
+        } else {
+          const double exp_tmp = std::exp(score[i]);
+          gradients[i] = static_cast<score_t>((exp_tmp - label_[i]) * weights_[i]);
+          hessians[i] = static_cast<score_t>(exp_tmp * weights_[i]);
+        }
       }
     }
   }
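
For reference, below is a small standalone sketch, not part of the patch, contrasting the naive gradient z - label with the rearranged form the patch adopts; the helper names GradNaive/GradStable and the driver are illustrative only. With label = 1 and moderately large positive scores, z rounds to a value very close to 1 and the subtraction loses relative precision, while the rearranged form keeps it; the score > -37.0 branch mirrors the cutoff chosen in the patch.

// Standalone illustration (hypothetical file, not part of the patch):
// compares the naive gradient z - label, with z = 1 / (1 + exp(-score)),
// against the rearranged form used in xentropy_objective.hpp.
#include <cmath>
#include <cstdio>

// Naive form: for label = 1 and large positive score, z is rounded to a
// value near 1 and the subtraction z - label cancels catastrophically.
double GradNaive(double score, double label) {
  const double z = 1.0 / (1.0 + std::exp(-score));
  return z - label;
}

// Rearranged form mirroring the patch, including the -37.0 cutoff.
double GradStable(double score, double label) {
  if (score > -37.0) {
    const double exp_tmp = std::exp(-score);
    return ((1.0 - label) - label * exp_tmp) / (1.0 + exp_tmp);
  } else {
    // For very negative scores, exp(-score) would overflow; exp(score)
    // is tiny and safe, so the gradient is effectively -label.
    const double exp_tmp = std::exp(score);
    return exp_tmp - label;
  }
}

int main() {
  const double label = 1.0;
  const double scores[] = {0.0, 10.0, 20.0, 35.0, -40.0};
  for (double s : scores) {
    std::printf("score=%6.1f  naive=% .17e  stable=% .17e\n",
                s, GradNaive(s, label), GradStable(s, label));
  }
  return 0;
}

Compiled with, e.g., g++ -O2, both columns should agree at score = 0 but diverge in the low-order digits as the score grows, which is the precision gap the comment block in the patch describes; the -37.0 cutoff itself only matters for very negative scores, where it keeps exp() away from huge arguments.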