Mirror of https://github.com/microsoft/LightGBM.git

[python][R-package] warn users about untransformed values in case of custom obj (#2611)

Parent: 6129208006
Commit: 69c1c33093
@@ -47,9 +47,16 @@ logregobj <- function(preds, dtrain) {
   hess <- preds * (1.0 - preds)
   return(list(grad = grad, hess = hess))
 }
 
+# User-defined evaluation function returns a pair (metric_name, result, higher_better)
+# NOTE: when you use a customized loss function, the default prediction value is the margin
+# This may make the built-in evaluation metric calculate wrong results
+# For example, when doing logistic loss, the prediction is the score before the logistic transformation
+# Keep this in mind when you use the customization, and you may need to write a customized evaluation function
 evalerror <- function(preds, dtrain) {
   labels <- getinfo(dtrain, "label")
-  err <- as.numeric(sum(labels != (preds > 0.0))) / length(labels)
+  preds <- 1.0 / (1.0 + exp(-preds))
+  err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels)
   return(list(name = "error", value = err, higher_better = FALSE))
 }
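The change above is consistent with the old check: for the logistic link, a raw margin above 0.0 corresponds exactly to a transformed probability above 0.5, since 1 / (1 + exp(-0)) = 0.5. A minimal NumPy sketch of this equivalence (illustrative values, not part of the patch):

import numpy as np

# raw margins, as returned by a model trained with a custom objective
margins = np.array([-2.3, -0.1, 0.0, 0.4, 1.7])

# the logistic transformation applied in the updated demo
probs = 1.0 / (1.0 + np.exp(-margins))

# thresholding the margin at 0.0 gives the same labels as
# thresholding the transformed probability at 0.5
assert np.array_equal(margins > 0.0, probs > 0.5)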
@@ -28,12 +28,12 @@ logregobj <- function(preds, dtrain) {
   return(list(grad = grad, hess = hess))
 }
 
-# User defined evaluation function, return a pair metric_name, result, higher_better
+# User-defined evaluation function returns a pair (metric_name, result, higher_better)
 # NOTE: when you use a customized loss function, the default prediction value is the margin
-# This may make buildin evalution metric not function properly
+# This may make the built-in evaluation metric calculate wrong results
 # For example, when doing logistic loss, the prediction is the score before the logistic transformation
-# The buildin evaluation error assumes input is after logistic transformation
-# Take this in mind when you use the customization, and maybe you need write customized evaluation function
+# The built-in evaluation error assumes the input is after the logistic transformation
+# Keep this in mind when you use the customization, and you may need to write a customized evaluation function
 evalerror <- function(preds, dtrain) {
   labels <- getinfo(dtrain, "label")
   err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels)
@@ -14,9 +14,14 @@ logregobj <- function(preds, dtrain) {
   return(list(grad = grad, hess = hess))
 }
 
+# User-defined evaluation function returns a pair (metric_name, result, higher_better)
+# NOTE: when you use a customized loss function, the default prediction value is the margin
+# This may make the built-in evaluation metric calculate wrong results
+# Keep this in mind when you use the customization, and you may need to write a customized evaluation function
 evalerror <- function(preds, dtrain) {
   labels <- getinfo(dtrain, "label")
-  err <- as.numeric(sum(labels != (preds > 0.0))) / length(labels)
+  preds <- 1.0 / (1.0 + exp(-preds))
+  err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels)
   return(list(
     name = "error"
     , value = err
@@ -147,8 +147,13 @@ def loglikelihood(preds, train_data):
 # self-defined eval metric
 # f(preds: array, train_data: Dataset) -> name: string, eval_result: float, is_higher_better: bool
 # binary error
+# NOTE: when you use a customized loss function, the default prediction value is the margin
+# This may make the built-in evaluation metric calculate wrong results
+# For example, when doing log likelihood loss, the prediction is the score before the logistic transformation
+# Keep this in mind when you use the customization
 def binary_error(preds, train_data):
     labels = train_data.get_label()
+    preds = 1. / (1. + np.exp(-preds))
     return 'error', np.mean(labels != (preds > 0.5)), False
@@ -166,8 +171,13 @@ print('Finished 40 - 50 rounds with self-defined objective function and eval metric')
 # another self-defined eval metric
 # f(preds: array, train_data: Dataset) -> name: string, eval_result: float, is_higher_better: bool
 # accuracy
+# NOTE: when you use a customized loss function, the default prediction value is the margin
+# This may make the built-in evaluation metric calculate wrong results
+# For example, when doing log likelihood loss, the prediction is the score before the logistic transformation
+# Keep this in mind when you use the customization
 def accuracy(preds, train_data):
     labels = train_data.get_label()
+    preds = 1. / (1. + np.exp(-preds))
     return 'accuracy', np.mean(labels == (preds > 0.5)), True
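For reference, these self-defined functions are hooked into training through the fobj and feval arguments of lgb.train in the LightGBM version this commit targets. A minimal sketch, assuming the loglikelihood objective and binary_error metric defined above and an illustrative toy dataset (not from the patch):

import numpy as np
import lightgbm as lgb

# toy data, for illustration only
X = np.random.rand(500, 5)
y = (X[:, 0] > 0.5).astype(float)
lgb_train = lgb.Dataset(X, y)

# pass the custom objective and metric to training
gbm = lgb.train({}, lgb_train, num_boost_round=10,
                valid_sets=[lgb_train],
                fobj=loglikelihood, feval=binary_error)

# predictions are raw margins because of the custom objective;
# transform them before thresholding at 0.5
preds = 1. / (1. + np.exp(-gbm.predict(X)))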
@@ -2,6 +2,8 @@
 """Scikit-learn wrapper interface for LightGBM."""
 from __future__ import absolute_import
 
+import warnings
+
 import numpy as np
 
 from .basic import Dataset, LightGBMError, _ConfigAliases
@@ -812,7 +814,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
         """Docstring is inherited from the LGBMModel."""
         result = self.predict_proba(X, raw_score, num_iteration,
                                     pred_leaf, pred_contrib, **kwargs)
-        if raw_score or pred_leaf or pred_contrib:
+        if callable(self._objective) or raw_score or pred_leaf or pred_contrib:
             return result
         else:
             class_index = np.argmax(result, axis=1)
@@ -861,7 +863,12 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
         """
         result = super(LGBMClassifier, self).predict(X, raw_score, num_iteration,
                                                      pred_leaf, pred_contrib, **kwargs)
-        if self._n_classes > 2 or raw_score or pred_leaf or pred_contrib:
+        if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib):
+            warnings.warn("Cannot compute class probabilities or labels "
+                          "due to the usage of customized objective function.\n"
+                          "Returning raw scores instead.")
+            return result
+        elif self._n_classes > 2 or raw_score or pred_leaf or pred_contrib:
             return result
         else:
             return np.vstack((1. - result, result)).transpose()
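A hedged sketch of the resulting user-facing behavior (toy data and names are illustrative; logregobj mirrors the custom objective used in the test below): with a callable objective, predict and predict_proba now warn and return raw scores, and the caller applies the inverse link function:

import numpy as np
import lightgbm as lgb

def logregobj(y_true, y_pred):
    # binary logistic objective: gradient and hessian w.r.t. the raw score
    y_pred = 1.0 / (1.0 + np.exp(-y_pred))
    grad = y_pred - y_true
    hess = y_pred * (1.0 - y_pred)
    return grad, hess

X = np.random.rand(200, 4)
y = (X[:, 0] > 0.5).astype(int)

clf = lgb.LGBMClassifier(objective=logregobj, n_estimators=20)
clf.fit(X, y)

raw = clf.predict_proba(X)          # warns and returns raw scores after this patch
probs = 1.0 / (1.0 + np.exp(-raw))  # the caller recovers probabilities manually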
@@ -131,7 +131,11 @@ class TestSklearn(unittest.TestCase):
         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
         gbm = lgb.LGBMClassifier(n_estimators=50, silent=True, objective=logregobj)
         gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5, verbose=False)
-        ret = binary_error(y_test, gbm.predict(X_test))
+        # the prediction result is actually not transformed (it is raw) due to the custom objective
+        y_pred_raw = gbm.predict_proba(X_test)
+        self.assertFalse(np.all(y_pred_raw >= 0))
+        y_pred = 1.0 / (1.0 + np.exp(-y_pred_raw))
+        ret = binary_error(y_test, y_pred)
         self.assertLess(ret, 0.05)
 
     def test_dart(self):