зеркало из https://github.com/microsoft/LightGBM.git
[R-package] Fixed R implementation of upper_bound() and lower_bound() for lgb.Booster (#2785)
* [R-package] Fixed R implementation of upper_bound() and lower_bound() for lgb.Booster * [R-package] switched return type to double * fixed R tests on Booster upper_bound() and lower_bound() * fixed linting * moved numeric tolerance into a global constant
This commit is contained in:
Родитель
4adb9ff71f
Коммит
790c1e33e6
|
@ -322,9 +322,9 @@ Booster <- R6::R6Class(
|
|||
},
|
||||
|
||||
# Get upper bound
|
||||
upper_bound_ = function() {
|
||||
upper_bound = function() {
|
||||
|
||||
upper_bound <- 0L
|
||||
upper_bound <- 0.0
|
||||
lgb.call(
|
||||
"LGBM_BoosterGetUpperBoundValue_R"
|
||||
, ret = upper_bound
|
||||
|
@ -334,12 +334,12 @@ Booster <- R6::R6Class(
|
|||
},
|
||||
|
||||
# Get lower bound
|
||||
lower_bound_ = function() {
|
||||
lower_bound = function() {
|
||||
|
||||
lower_bound <- 0L
|
||||
lower_bound <- 0.0
|
||||
lgb.call(
|
||||
"LGBM_BoosterGetLowerBoundValue_R"
|
||||
, ret = upper_bound
|
||||
, ret = lower_bound
|
||||
, private$handle
|
||||
)
|
||||
|
||||
|
|
|
@ -7,6 +7,8 @@ test <- agaricus.test
|
|||
|
||||
windows_flag <- grepl("Windows", Sys.info()[["sysname"]])
|
||||
|
||||
TOLERANCE <- 1e-6
|
||||
|
||||
test_that("train and predict binary classification", {
|
||||
nrounds <- 10L
|
||||
bst <- lightgbm(
|
||||
|
@ -28,7 +30,7 @@ test_that("train and predict binary classification", {
|
|||
expect_equal(length(pred1), 6513L)
|
||||
err_pred1 <- sum((pred1 > 0.5) != train$label) / length(train$label)
|
||||
err_log <- record_results[1L]
|
||||
expect_lt(abs(err_pred1 - err_log), 10e-6)
|
||||
expect_lt(abs(err_pred1 - err_log), TOLERANCE)
|
||||
})
|
||||
|
||||
|
||||
|
@ -70,6 +72,36 @@ test_that("use of multiple eval metrics works", {
|
|||
expect_false(is.null(bst$record_evals))
|
||||
})
|
||||
|
||||
test_that("lgb.Booster.upper_bound() and lgb.Booster.lower_bound() work as expected for binary classification", {
|
||||
set.seed(708L)
|
||||
nrounds <- 10L
|
||||
bst <- lightgbm(
|
||||
data = train$data
|
||||
, label = train$label
|
||||
, num_leaves = 5L
|
||||
, nrounds = nrounds
|
||||
, objective = "binary"
|
||||
, metric = "binary_error"
|
||||
)
|
||||
expect_true(abs(bst$lower_bound() - -1.590853) < TOLERANCE)
|
||||
expect_true(abs(bst$upper_bound() - 1.871015) < TOLERANCE)
|
||||
})
|
||||
|
||||
test_that("lgb.Booster.upper_bound() and lgb.Booster.lower_bound() work as expected for regression", {
|
||||
set.seed(708L)
|
||||
nrounds <- 10L
|
||||
bst <- lightgbm(
|
||||
data = train$data
|
||||
, label = train$label
|
||||
, num_leaves = 5L
|
||||
, nrounds = nrounds
|
||||
, objective = "regression"
|
||||
, metric = "l2"
|
||||
)
|
||||
expect_true(abs(bst$lower_bound() - 0.1513859) < TOLERANCE)
|
||||
expect_true(abs(bst$upper_bound() - 0.9080349) < TOLERANCE)
|
||||
})
|
||||
|
||||
test_that("lightgbm() rejects negative or 0 value passed to nrounds", {
|
||||
dtrain <- lgb.Dataset(train$data, label = train$label)
|
||||
params <- list(objective = "regression", metric = "l2,l1")
|
||||
|
|
|
@ -336,6 +336,30 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle,
|
|||
LGBM_SE out,
|
||||
LGBM_SE call_state);
|
||||
|
||||
/*!
|
||||
* \brief Get model upper bound value.
|
||||
* \param handle Handle of booster
|
||||
* \param[out] out_results Result pointing to max value
|
||||
* \return 0 when succeed, -1 when failure happens
|
||||
*/
|
||||
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterGetUpperBoundValue_R(
|
||||
LGBM_SE handle,
|
||||
LGBM_SE out_result,
|
||||
LGBM_SE call_state
|
||||
);
|
||||
|
||||
/*!
|
||||
* \brief Get model lower bound value.
|
||||
* \param handle Handle of booster
|
||||
* \param[out] out_results Result pointing to min value
|
||||
* \return 0 when succeed, -1 when failure happens
|
||||
*/
|
||||
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterGetLowerBoundValue_R(
|
||||
LGBM_SE handle,
|
||||
LGBM_SE out_result,
|
||||
LGBM_SE call_state
|
||||
);
|
||||
|
||||
/*!
|
||||
* \brief Get Name of eval
|
||||
* \param eval_names eval names
|
||||
|
|
Загрузка…
Ссылка в новой задаче