Mirror of https://github.com/microsoft/LightGBM.git
[R-package] allow access to params in Booster (#3662)
* [R-package] allow access to params in Booster
* remove unnecessary whitespace
* fix test on resetting params
* remove pytest_cache
* Update R-package/tests/testthat/test_custom_objective.R
This commit is contained in:
Parent: d7a384fa04
Commit: 532fa914e6
R-package/R/lgb.Booster.R

@@ -6,6 +6,7 @@ Booster <- R6::R6Class(
     best_iter = -1L,
     best_score = NA_real_,
+    params = list(),
     record_evals = list(),

     # Finalize will free up the handles
@@ -134,6 +135,8 @@ Booster <- R6::R6Class(
       }

+      self$params <- params
+
     },

     # Set training data name
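Together, the new public field and the assignment at the end of initialize() make a Booster's construction-time parameters inspectable on the object. A minimal usage sketch, assuming the standard lgb.train() workflow (the exact contents printed by str() depend on what lgb.train() merges in at construction time):

    library(lightgbm)
    data(agaricus.train, package = "lightgbm")
    dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label)
    bst <- lgb.train(
      params = list(objective = "binary", max_depth = 4L)
      , data = dtrain
      , nrounds = 2L
    )
    # the parameters used to construct the Booster are now stored on it
    str(bst$params)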
@@ -187,17 +190,20 @@ Booster <- R6::R6Class(
     # Reset parameters of booster
     reset_parameter = function(params, ...) {

-      # Append parameters
-      params <- append(params, list(...))
+      if (methods::is(self$params, "list")) {
+        params <- modifyList(self$params, params)
+      }

+      params <- modifyList(params, list(...))
       params_str <- lgb.params2str(params = params)

       # Reset parameters
       lgb.call(
         fun_name = "LGBM_BoosterResetParameter_R"
         , ret = NULL
         , private$handle
         , params_str
       )
+      self$params <- params

       return(invisible(self))
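Note the merge order above: the stored self$params come first, the new params second, and any `...` overrides last. utils::modifyList() only overwrites keys that are supplied again, so parameters not mentioned in a reset keep their previous values. A standalone sketch of that semantics (plain base R, no lightgbm required):

    old_params <- list(objective = "binary", max_depth = 4L, bagging_fraction = 0.8)
    new_params <- list(bagging_fraction = 0.9)
    # keys absent from new_params survive; keys present in both take the new value
    merged <- utils::modifyList(old_params, new_params)
    # merged$bagging_fraction is 0.9; objective and max_depth are unchanged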
R-package/R/readRDS.lgb.Booster.R

@@ -44,6 +44,7 @@ readRDS.lgb.Booster <- function(file = "", refhook = NULL) {
   # Restore best iteration and recorded evaluations
   object2$best_iter <- object$best_iter
   object2$record_evals <- object$record_evals
+  object2$params <- object$params

   # Return newly loaded object
   return(object2)
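Since readRDS.lgb.Booster() now restores params alongside best_iter and record_evals, a save/load round trip preserves them. A short sketch, reusing the bst object from the sketch further above:

    model_file <- tempfile(fileext = ".rds")
    saveRDS.lgb.Booster(bst, file = model_file)
    bst_from_rds <- readRDS.lgb.Booster(file = model_file)
    # parameters survive serialization along with best_iter and record_evals
    identical(bst_from_rds$params, bst$params)  # expected: TRUE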
R-package/tests/testthat/test_lgb.Booster.R

@@ -386,6 +386,75 @@ test_that("Booster$update() throws an informative error if you provide a non-Dat
   }, regexp = "lgb.Booster.update: Only can use lgb.Dataset", fixed = TRUE)
 })

+test_that("Booster should store parameters and Booster$reset_parameter() should update them", {
+  data(agaricus.train, package = "lightgbm")
+  dtrain <- lgb.Dataset(
+    agaricus.train$data
+    , label = agaricus.train$label
+  )
+  # testing that this works for some cases that could break it:
+  # - multiple metrics
+  # - using "metric", "boosting", "num_class" in params
+  params <- list(
+    objective = "multiclass"
+    , max_depth = 4L
+    , bagging_fraction = 0.8
+    , metric = c("multi_logloss", "multi_error")
+    , boosting = "gbdt"
+    , num_class = 5L
+  )
+  bst <- Booster$new(
+    params = params
+    , train_set = dtrain
+  )
+  expect_identical(bst$params, params)
+
+  params[["bagging_fraction"]] <- 0.9
+  ret_bst <- bst$reset_parameter(params = params)
+  expect_identical(ret_bst$params, params)
+  expect_identical(bst$params, params)
+})
+
+test_that("Booster$params should include dataset params, before and after Booster$reset_parameter()", {
+  data(agaricus.train, package = "lightgbm")
+  dtrain <- lgb.Dataset(
+    agaricus.train$data
+    , label = agaricus.train$label
+    , params = list(
+      max_bin = 17L
+    )
+  )
+  params <- list(
+    objective = "binary"
+    , max_depth = 4L
+    , bagging_fraction = 0.8
+  )
+  bst <- Booster$new(
+    params = params
+    , train_set = dtrain
+  )
+  expect_identical(
+    bst$params
+    , list(
+      objective = "binary"
+      , max_depth = 4L
+      , bagging_fraction = 0.8
+      , max_bin = 17L
+    )
+  )
+
+  params[["bagging_fraction"]] <- 0.9
+  ret_bst <- bst$reset_parameter(params = params)
+  expected_params <- list(
+    objective = "binary"
+    , max_depth = 4L
+    , bagging_fraction = 0.9
+    , max_bin = 17L
+  )
+  expect_identical(ret_bst$params, expected_params)
+  expect_identical(bst$params, expected_params)
+})
+
 context("save_model")

 test_that("Saving a model with different feature importance types works", {
@@ -626,3 +695,38 @@ test_that("lgb.cv() correctly handles passing through params to the model file",
   }

 })
+
+context("saveRDS.lgb.Booster() and readRDS.lgb.Booster()")
+
+test_that("params (including dataset params) should be stored in .rds file for Booster", {
+  data(agaricus.train, package = "lightgbm")
+  dtrain <- lgb.Dataset(
+    agaricus.train$data
+    , label = agaricus.train$label
+    , params = list(
+      max_bin = 17L
+    )
+  )
+  params <- list(
+    objective = "binary"
+    , max_depth = 4L
+    , bagging_fraction = 0.8
+  )
+  bst <- Booster$new(
+    params = params
+    , train_set = dtrain
+  )
+  bst_file <- tempfile(fileext = ".rds")
+  saveRDS.lgb.Booster(bst, file = bst_file)
+
+  bst_from_file <- readRDS.lgb.Booster(file = bst_file)
+  expect_identical(
+    bst_from_file$params
+    , list(
+      objective = "binary"
+      , max_depth = 4L
+      , bagging_fraction = 0.8
+      , max_bin = 17L
+    )
+  )
+})
src/c_api.cpp

@@ -297,13 +297,15 @@ class Booster {
   void ResetConfig(const char* parameters) {
     UNIQUE_LOCK(mutex_)
     auto param = Config::Str2Map(parameters);
-    if (param.count("num_class")) {
+    Config new_config;
+    new_config.Set(param);
+    if (param.count("num_class") && new_config.num_class != config_.num_class) {
       Log::Fatal("Cannot change num_class during training");
     }
-    if (param.count("boosting")) {
+    if (param.count("boosting") && new_config.boosting != config_.boosting) {
       Log::Fatal("Cannot change boosting during training");
     }
-    if (param.count("metric")) {
+    if (param.count("metric") && new_config.metric != config_.metric) {
       Log::Fatal("Cannot change metric during training");
     }
     CheckDatasetResetConfig(config_, param);
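Because ResetConfig() now parses the incoming parameters into a full Config and compares values against the current config, a reset fails only when num_class, boosting, or metric would actually change, not merely because those keys appear in the request. From R, an attempt to truly change one of them should surface the Log::Fatal() message above as an error; a hedged sketch (dataset and parameter choices are illustrative):

    data(agaricus.train, package = "lightgbm")
    dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label)
    bst <- lgb.train(
      params = list(objective = "multiclass", num_class = 5L, metric = "multi_logloss")
      , data = dtrain
      , nrounds = 1L
    )
    # re-sending the same num_class is now accepted...
    bst$reset_parameter(params = list(num_class = 5L))
    # ...but changing it still fails: "Cannot change num_class during training"
    try(bst$reset_parameter(params = list(num_class = 3L)))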
tests/python_package_test/test_engine.py

@@ -2609,3 +2609,37 @@ class TestEngine(unittest.TestCase):
         lgb_X = lgb.Dataset(X, label=y)
         lgb.train(params, lgb_X, num_boost_round=1, valid_sets=[lgb_X], evals_result=res)
         self.assertAlmostEqual(res['training']['average_precision'][-1], 1)
+
+    def test_reset_params_works_with_metric_num_class_and_boosting(self):
+        X, y = load_breast_cancer(return_X_y=True)
+        params = {
+            'objective': 'multiclass',
+            'max_depth': 4,
+            'bagging_fraction': 0.8,
+            'metric': ['multi_logloss', 'multi_error'],
+            'boosting': 'gbdt',
+            'num_class': 5
+        }
+        dtrain = lgb.Dataset(X, y, params={"max_bin": 150})
+
+        bst = lgb.Booster(
+            params=params,
+            train_set=dtrain
+        )
+        expected_params = {
+            'objective': 'multiclass',
+            'max_depth': 4,
+            'bagging_fraction': 0.8,
+            'metric': ['multi_logloss', 'multi_error'],
+            'boosting': 'gbdt',
+            'num_class': 5,
+            'max_bin': 150
+        }
+        assert bst.params == expected_params
+
+        params['bagging_fraction'] = 0.9
+        ret_bst = bst.reset_parameter(params)
+
+        expected_params['bagging_fraction'] = 0.9
+        assert bst.params == expected_params
+        assert ret_bst.params == expected_params