[R-package] added tests on LGBM_BoosterResetTrainingData_R (#3020)

This commit is contained in:
James Lamb 2020-05-11 04:26:52 +01:00 коммит произвёл GitHub
Родитель ad7f285154
Коммит 05d89a1a1f
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 63 добавлений и 0 удалений

Просмотреть файл

@ -311,3 +311,66 @@ test_that("Booster$rollback_one_iter() should work as expected", {
logloss <- bst$eval_train()[[1L]][["value"]]
expect_equal(logloss, 0.027915146)
})
test_that("Booster$update() passing a train_set works as expected", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
nrounds <- 2L
# train with 2 rounds and then update
bst <- lightgbm(
data = as.matrix(agaricus.train$data)
, label = agaricus.train$label
, num_leaves = 4L
, learning_rate = 1.0
, nrounds = nrounds
, objective = "binary"
)
expect_true(lgb.is.Booster(bst))
expect_equal(bst$current_iter(), nrounds)
bst$update(
train_set = Dataset$new(
data = agaricus.train$data
, label = agaricus.train$label
)
)
expect_true(lgb.is.Booster(bst))
expect_equal(bst$current_iter(), nrounds + 1L)
# train with 3 rounds directlry
bst2 <- lightgbm(
data = as.matrix(agaricus.train$data)
, label = agaricus.train$label
, num_leaves = 4L
, learning_rate = 1.0
, nrounds = nrounds + 1L
, objective = "binary"
)
expect_true(lgb.is.Booster(bst2))
expect_equal(bst2$current_iter(), nrounds + 1L)
# model with 2 rounds + 1 update should be identical to 3 rounds
expect_equal(bst2$eval_train()[[1L]][["value"]], 0.04806585)
expect_equal(bst$eval_train()[[1L]][["value"]], bst2$eval_train()[[1L]][["value"]])
})
test_that("Booster$update() throws an informative error if you provide a non-Dataset to update()", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
nrounds <- 2L
# train with 2 rounds and then update
bst <- lightgbm(
data = as.matrix(agaricus.train$data)
, label = agaricus.train$label
, num_leaves = 4L
, learning_rate = 1.0
, nrounds = nrounds
, objective = "binary"
)
expect_error({
bst$update(
train_set = data.frame(x = rnorm(10L))
)
}, regexp = "lgb.Booster.update: Only can use lgb.Dataset", fixed = TRUE)
})