[R-package] allow access to params in Booster (#3662)

* [R-package] allow access to params in Booster

* remove unnecessary whitespace

* fix test on resetting params

* remove pytest_cache

* Update R-package/tests/testthat/test_custom_objective.R
This commit is contained in:
James Lamb 2021-01-03 04:26:17 +00:00 committed by GitHub
Parent d7a384fa04
Commit 532fa914e6
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
5 changed files: 153 additions and 6 deletions

View file

@ -6,6 +6,7 @@ Booster <- R6::R6Class(
best_iter = -1L,
best_score = NA_real_,
params = list(),
record_evals = list(),
# Finalize will free up the handles
@ -134,6 +135,8 @@ Booster <- R6::R6Class(
}
self$params <- params
},
# Set training data name
@ -187,17 +190,20 @@ Booster <- R6::R6Class(
# Reset parameters of booster
reset_parameter = function(params, ...) {
# Append parameters
params <- append(params, list(...))
if (methods::is(self$params, "list")) {
params <- modifyList(self$params, params)
}
params <- modifyList(params, list(...))
params_str <- lgb.params2str(params = params)
# Reset parameters
lgb.call(
fun_name = "LGBM_BoosterResetParameter_R"
, ret = NULL
, private$handle
, params_str
)
self$params <- params
return(invisible(self))

View file

@ -44,6 +44,7 @@ readRDS.lgb.Booster <- function(file = "", refhook = NULL) {
# Restore best iteration and recorded evaluations
object2$best_iter <- object$best_iter
object2$record_evals <- object$record_evals
object2$params <- object$params
# Return newly loaded object
return(object2)

View file

@ -386,6 +386,75 @@ test_that("Booster$update() throws an informative error if you provide a non-Dat
}, regexp = "lgb.Booster.update: Only can use lgb.Dataset", fixed = TRUE)
})
test_that("Booster should store parameters and Booster$reset_parameter() should update them", {
  data(agaricus.train, package = "lightgbm")
  train_set <- lgb.Dataset(
    agaricus.train$data
    , label = agaricus.train$label
  )
  # Deliberately exercise parameters with special handling on the C++ side:
  # a multi-element metric, plus the "metric", "boosting", and "num_class" keys
  booster_params <- list(
    objective = "multiclass"
    , max_depth = 4L
    , bagging_fraction = 0.8
    , metric = c("multi_logloss", "multi_error")
    , boosting = "gbdt"
    , num_class = 5L
  )
  booster <- Booster$new(
    params = booster_params
    , train_set = train_set
  )
  # Booster should expose exactly the parameters it was constructed with
  expect_identical(booster$params, booster_params)

  # After reset_parameter(), the change should be visible both on the
  # returned object and on the original Booster (R6 reference semantics)
  booster_params[["bagging_fraction"]] <- 0.9
  returned_booster <- booster$reset_parameter(params = booster_params)
  expect_identical(returned_booster$params, booster_params)
  expect_identical(booster$params, booster_params)
})
test_that("Booster$params should include dataset params, before and after Booster$reset_parameter()", {
  data(agaricus.train, package = "lightgbm")
  # Dataset carries its own parameter (max_bin), which should be
  # merged into the Booster's stored parameters
  train_set <- lgb.Dataset(
    agaricus.train$data
    , label = agaricus.train$label
    , params = list(
      max_bin = 17L
    )
  )
  booster_params <- list(
    objective = "binary"
    , max_depth = 4L
    , bagging_fraction = 0.8
  )
  booster <- Booster$new(
    params = booster_params
    , train_set = train_set
  )
  # Booster$params should be the training params plus the dataset's max_bin
  expect_identical(
    booster$params
    , list(
      objective = "binary"
      , max_depth = 4L
      , bagging_fraction = 0.8
      , max_bin = 17L
    )
  )

  # Resetting a parameter should update the stored params on both the
  # returned object and the original, keeping the dataset param intact
  booster_params[["bagging_fraction"]] <- 0.9
  returned_booster <- booster$reset_parameter(params = booster_params)
  expected_params <- list(
    objective = "binary"
    , max_depth = 4L
    , bagging_fraction = 0.9
    , max_bin = 17L
  )
  expect_identical(returned_booster$params, expected_params)
  expect_identical(booster$params, expected_params)
})
# Section: tests for saving Booster models to disk
context("save_model")
test_that("Saving a model with different feature importance types works", {
@ -626,3 +695,38 @@ test_that("lgb.cv() correctly handles passing through params to the model file",
}
})
# Section: tests for RDS (de)serialization of Booster objects
context("saveRDS.lgb.Booster() and readRDS.lgb.Booster()")
test_that("params (including dataset params) should be stored in .rds file for Booster", {
  data(agaricus.train, package = "lightgbm")
  # Dataset-level max_bin should survive a save/load round trip
  train_set <- lgb.Dataset(
    agaricus.train$data
    , label = agaricus.train$label
    , params = list(
      max_bin = 17L
    )
  )
  booster_params <- list(
    objective = "binary"
    , max_depth = 4L
    , bagging_fraction = 0.8
  )
  booster <- Booster$new(
    params = booster_params
    , train_set = train_set
  )

  # Round-trip the Booster through an .rds file
  model_file <- tempfile(fileext = ".rds")
  saveRDS.lgb.Booster(booster, file = model_file)
  reloaded_booster <- readRDS.lgb.Booster(file = model_file)

  # The reloaded Booster should carry training params plus dataset params
  expect_identical(
    reloaded_booster$params
    , list(
      objective = "binary"
      , max_depth = 4L
      , bagging_fraction = 0.8
      , max_bin = 17L
    )
  )
})

View file

@ -297,13 +297,15 @@ class Booster {
void ResetConfig(const char* parameters) {
UNIQUE_LOCK(mutex_)
auto param = Config::Str2Map(parameters);
if (param.count("num_class")) {
Config new_config;
new_config.Set(param);
if (param.count("num_class") && new_config.num_class != config_.num_class) {
Log::Fatal("Cannot change num_class during training");
}
if (param.count("boosting")) {
if (param.count("boosting") && new_config.boosting != config_.boosting) {
Log::Fatal("Cannot change boosting during training");
}
if (param.count("metric")) {
if (param.count("metric") && new_config.metric != config_.metric) {
Log::Fatal("Cannot change metric during training");
}
CheckDatasetResetConfig(config_, param);

View file

@ -2609,3 +2609,37 @@ class TestEngine(unittest.TestCase):
lgb_X = lgb.Dataset(X, label=y)
lgb.train(params, lgb_X, num_boost_round=1, valid_sets=[lgb_X], evals_result=res)
self.assertAlmostEqual(res['training']['average_precision'][-1], 1)
def test_reset_params_works_with_metric_num_class_and_boosting(self):
    """reset_parameter() should work even for params with special C++-side
    handling: a metric list, plus 'metric', 'boosting', and 'num_class'."""
    train_params = {
        'objective': 'multiclass',
        'max_depth': 4,
        'bagging_fraction': 0.8,
        'metric': ['multi_logloss', 'multi_error'],
        'boosting': 'gbdt',
        'num_class': 5
    }
    X, y = load_breast_cancer(return_X_y=True)
    train_data = lgb.Dataset(X, y, params={"max_bin": 150})
    booster = lgb.Booster(
        params=train_params,
        train_set=train_data
    )

    # Booster.params should be the training params merged with the
    # dataset-level 'max_bin'
    expected = dict(train_params)
    expected['max_bin'] = 150
    assert booster.params == expected

    # Changing a parameter via reset_parameter() should be reflected on
    # both the original Booster and the returned one
    train_params['bagging_fraction'] = 0.9
    returned_booster = booster.reset_parameter(train_params)
    expected['bagging_fraction'] = 0.9
    assert booster.params == expected
    assert returned_booster.params == expected