[python][R-package] warn users about untransformed values in case of custom obj (#2611)

Nikita Titov 2019-12-05 16:53:13 +03:00 committed by GitHub
Parent 6129208006
Commit 69c1c33093
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 42 additions and 9 deletions

View file

@@ -47,9 +47,16 @@ logregobj <- function(preds, dtrain) {
     hess <- preds * (1.0 - preds)
     return(list(grad = grad, hess = hess))
 }

+# User-defined evaluation function returns a pair (metric_name, result, higher_better)
+# NOTE: when you do customized loss function, the default prediction value is margin
+# This may make built-in evalution metric calculate wrong results
+# For example, we are doing logistic loss, the prediction is score before logistic transformation
+# Keep this in mind when you use the customization, and maybe you need write customized evaluation function
 evalerror <- function(preds, dtrain) {
     labels <- getinfo(dtrain, "label")
-    err <- as.numeric(sum(labels != (preds > 0.0))) / length(labels)
+    preds <- 1.0 / (1.0 + exp(-preds))
+    err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels)
     return(list(name = "error", value = err, higher_better = FALSE))
 }
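The R change above is the heart of this commit: with a custom objective, predictions arrive as raw margin scores, so a custom evaluation function has to apply the logistic transformation itself before thresholding. A minimal Python sketch of the same idea (the function names here are illustrative, not part of this commit):

import numpy as np

def sigmoid(raw_score):
    # map raw margin scores to probabilities in (0, 1)
    return 1.0 / (1.0 + np.exp(-raw_score))

def binary_error_from_raw(raw_preds, labels):
    # transform the raw scores first, then threshold at 0.5
    probs = sigmoid(np.asarray(raw_preds, dtype=float))
    return float(np.mean(np.asarray(labels) != (probs > 0.5)))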

View file

@@ -28,12 +28,12 @@ logregobj <- function(preds, dtrain) {
     return(list(grad = grad, hess = hess))
 }

-# User defined evaluation function, return a pair metric_name, result, higher_better
+# User-defined evaluation function returns a pair (metric_name, result, higher_better)
 # NOTE: when you do customized loss function, the default prediction value is margin
-# This may make buildin evalution metric not function properly
+# This may make built-in evalution metric calculate wrong results
 # For example, we are doing logistic loss, the prediction is score before logistic transformation
-# The buildin evaluation error assumes input is after logistic transformation
-# Take this in mind when you use the customization, and maybe you need write customized evaluation function
+# The built-in evaluation error assumes input is after logistic transformation
+# Keep this in mind when you use the customization, and maybe you need write customized evaluation function
 evalerror <- function(preds, dtrain) {
     labels <- getinfo(dtrain, "label")
     err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels)
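The comment about the built-in evaluation error is easy to verify numerically: thresholding raw margins at 0.5 as if they were probabilities misclassifies any positive margin below 0.5. A tiny illustration with toy numbers (not taken from the repository):

import numpy as np

raw = np.array([-2.0, -0.3, 0.4, 1.5])   # raw margins from a model trained with a custom objective
labels = np.array([0, 0, 1, 1])

naive = np.mean(labels != (raw > 0.5))                            # treats margins as probabilities -> 0.25
correct = np.mean(labels != (1.0 / (1.0 + np.exp(-raw)) > 0.5))   # transform first, then threshold -> 0.0
print(naive, correct)

Equivalently, raw margins could be thresholded at 0.0 (which is what the old R line with preds > 0.0 did); the danger is mixing the two scales.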

View file

@@ -14,9 +14,14 @@ logregobj <- function(preds, dtrain) {
     return(list(grad = grad, hess = hess))
 }

+# User-defined evaluation function returns a pair (metric_name, result, higher_better)
+# NOTE: when you do customized loss function, the default prediction value is margin
+# This may make built-in evalution metric calculate wrong results
+# Keep this in mind when you use the customization, and maybe you need write customized evaluation function
 evalerror <- function(preds, dtrain) {
     labels <- getinfo(dtrain, "label")
-    err <- as.numeric(sum(labels != (preds > 0.0))) / length(labels)
+    preds <- 1.0 / (1.0 + exp(-preds))
+    err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels)
     return(list(
         name = "error"
         , value = err
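The logregobj function whose tail appears as context in these R hunks is the custom binary logistic objective: it applies the sigmoid internally and returns the gradient and Hessian of the log loss, which is exactly why the booster's own predictions stay on the raw-margin scale. A rough Python equivalent, written against LightGBM's custom-objective convention of returning grad and hess arrays and named after the loglikelihood function in the Python example below (a sketch, not the repository's code):

import numpy as np

def loglikelihood(preds, train_data):
    # custom binary logistic objective: `preds` are raw margin scores
    labels = train_data.get_label()
    probs = 1.0 / (1.0 + np.exp(-preds))   # sigmoid
    grad = probs - labels                   # d(logloss) / d(margin)
    hess = probs * (1.0 - probs)            # d^2(logloss) / d(margin)^2
    return grad, hess

Because the booster only ever sees this objective through its gradients, it never applies the final sigmoid to predictions, so every evaluation function (and every downstream consumer) must do it explicitly.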

View file

@@ -147,8 +147,13 @@ def loglikelihood(preds, train_data):
 # self-defined eval metric
 # f(preds: array, train_data: Dataset) -> name: string, eval_result: float, is_higher_better: bool
 # binary error
+# NOTE: when you do customized loss function, the default prediction value is margin
+# This may make built-in evalution metric calculate wrong results
+# For example, we are doing log likelihood loss, the prediction is score before logistic transformation
+# Keep this in mind when you use the customization
 def binary_error(preds, train_data):
     labels = train_data.get_label()
+    preds = 1. / (1. + np.exp(-preds))
     return 'error', np.mean(labels != (preds > 0.5)), False
@@ -166,8 +171,13 @@ print('Finished 40 - 50 rounds with self-defined objective function and eval met
 # another self-defined eval metric
 # f(preds: array, train_data: Dataset) -> name: string, eval_result: float, is_higher_better: bool
 # accuracy
+# NOTE: when you do customized loss function, the default prediction value is margin
+# This may make built-in evalution metric calculate wrong results
+# For example, we are doing log likelihood loss, the prediction is score before logistic transformation
+# Keep this in mind when you use the customization
 def accuracy(preds, train_data):
     labels = train_data.get_label()
+    preds = 1. / (1. + np.exp(-preds))
     return 'accuracy', np.mean(labels == (preds > 0.5)), True
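Both metrics above follow the f(preds, train_data) signature that lgb.train accepted for feval at the time of this commit, alongside fobj for the custom objective. A self-contained sketch of the wiring; the dataset and the compact objective/metric below are stand-ins mirroring this example's own functions, not part of the commit:

import lightgbm as lgb
import numpy as np
from sklearn.datasets import load_breast_cancer

def loglikelihood(preds, train_data):
    labels = train_data.get_label()
    probs = 1. / (1. + np.exp(-preds))
    return probs - labels, probs * (1. - probs)

def binary_error(preds, train_data):
    labels = train_data.get_label()
    probs = 1. / (1. + np.exp(-preds))      # custom metric must transform the raw scores itself
    return 'error', np.mean(labels != (probs > 0.5)), False

X, y = load_breast_cancer(return_X_y=True)
lgb_train = lgb.Dataset(X, y)

gbm = lgb.train({'verbose': -1}, lgb_train, num_boost_round=10,
                fobj=loglikelihood, feval=binary_error, valid_sets=[lgb_train])

raw = gbm.predict(X)                        # raw margins, because a custom objective was used
probs = 1. / (1. + np.exp(-raw))            # transform manually when probabilities are needed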

View file

@@ -2,6 +2,8 @@
 """Scikit-learn wrapper interface for LightGBM."""
 from __future__ import absolute_import

+import warnings
+
 import numpy as np

 from .basic import Dataset, LightGBMError, _ConfigAliases
@@ -812,7 +814,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
         """Docstring is inherited from the LGBMModel."""
         result = self.predict_proba(X, raw_score, num_iteration,
                                     pred_leaf, pred_contrib, **kwargs)
-        if raw_score or pred_leaf or pred_contrib:
+        if callable(self._objective) or raw_score or pred_leaf or pred_contrib:
             return result
         else:
             class_index = np.argmax(result, axis=1)
@@ -861,7 +863,12 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
         """
         result = super(LGBMClassifier, self).predict(X, raw_score, num_iteration,
                                                      pred_leaf, pred_contrib, **kwargs)
-        if self._n_classes > 2 or raw_score or pred_leaf or pred_contrib:
+        if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib):
+            warnings.warn("Cannot compute class probabilities or labels "
+                          "due to the usage of customized objective function.\n"
+                          "Returning raw scores instead.")
+            return result
+        elif self._n_classes > 2 or raw_score or pred_leaf or pred_contrib:
             return result
         else:
             return np.vstack((1. - result, result)).transpose()
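On the scikit-learn side, the effect of this change is that LGBMClassifier.predict_proba returns raw scores and emits a warning whenever the objective is a callable, instead of silently treating margins as probabilities. A sketch of what that looks like from user code and how to recover probabilities manually; the logregobj helper here is an assumption modeled on the test below, not part of the library:

import warnings
import numpy as np
import lightgbm as lgb
from sklearn.datasets import load_breast_cancer

def logregobj(y_true, y_pred):
    # sklearn-API custom objective: receives labels and raw scores, returns grad and hess
    probs = 1. / (1. + np.exp(-y_pred))
    return probs - y_true, probs * (1. - probs)

X, y = load_breast_cancer(return_X_y=True)
clf = lgb.LGBMClassifier(n_estimators=20, objective=logregobj).fit(X, y)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    raw = clf.predict_proba(X)              # raw scores, not probabilities, because objective is callable
print(caught[-1].message if caught else "no warning emitted")

probs = 1. / (1. + np.exp(-raw))            # recover probabilities manually
labels = (probs > 0.5).astype(int)          # and class labels

predict() routes through predict_proba(), so the callable-objective early return added above covers both entry points.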

View file

@@ -131,7 +131,11 @@ class TestSklearn(unittest.TestCase):
         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
         gbm = lgb.LGBMClassifier(n_estimators=50, silent=True, objective=logregobj)
         gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5, verbose=False)
-        ret = binary_error(y_test, gbm.predict(X_test))
+        # prediction result is actually not transformed (is raw) due to custom objective
+        y_pred_raw = gbm.predict_proba(X_test)
+        self.assertFalse(np.all(y_pred_raw >= 0))
+        y_pred = 1.0 / (1.0 + np.exp(-y_pred_raw))
+        ret = binary_error(y_test, y_pred)
         self.assertLess(ret, 0.05)

     def test_dart(self):