зеркало из https://github.com/microsoft/LightGBM.git
[R-package] added tests on LGBM_BoosterResetTrainingData_R (#3020)
This commit is contained in:
Родитель
ad7f285154
Коммит
05d89a1a1f
|
@ -311,3 +311,66 @@ test_that("Booster$rollback_one_iter() should work as expected", {
|
|||
logloss <- bst$eval_train()[[1L]][["value"]]
|
||||
expect_equal(logloss, 0.027915146)
|
||||
})
|
||||
|
||||
test_that("Booster$update() passing a train_set works as expected", {
|
||||
set.seed(708L)
|
||||
data(agaricus.train, package = "lightgbm")
|
||||
nrounds <- 2L
|
||||
|
||||
# train with 2 rounds and then update
|
||||
bst <- lightgbm(
|
||||
data = as.matrix(agaricus.train$data)
|
||||
, label = agaricus.train$label
|
||||
, num_leaves = 4L
|
||||
, learning_rate = 1.0
|
||||
, nrounds = nrounds
|
||||
, objective = "binary"
|
||||
)
|
||||
expect_true(lgb.is.Booster(bst))
|
||||
expect_equal(bst$current_iter(), nrounds)
|
||||
bst$update(
|
||||
train_set = Dataset$new(
|
||||
data = agaricus.train$data
|
||||
, label = agaricus.train$label
|
||||
)
|
||||
)
|
||||
expect_true(lgb.is.Booster(bst))
|
||||
expect_equal(bst$current_iter(), nrounds + 1L)
|
||||
|
||||
# train with 3 rounds directlry
|
||||
bst2 <- lightgbm(
|
||||
data = as.matrix(agaricus.train$data)
|
||||
, label = agaricus.train$label
|
||||
, num_leaves = 4L
|
||||
, learning_rate = 1.0
|
||||
, nrounds = nrounds + 1L
|
||||
, objective = "binary"
|
||||
)
|
||||
expect_true(lgb.is.Booster(bst2))
|
||||
expect_equal(bst2$current_iter(), nrounds + 1L)
|
||||
|
||||
# model with 2 rounds + 1 update should be identical to 3 rounds
|
||||
expect_equal(bst2$eval_train()[[1L]][["value"]], 0.04806585)
|
||||
expect_equal(bst$eval_train()[[1L]][["value"]], bst2$eval_train()[[1L]][["value"]])
|
||||
})
|
||||
|
||||
test_that("Booster$update() throws an informative error if you provide a non-Dataset to update()", {
|
||||
set.seed(708L)
|
||||
data(agaricus.train, package = "lightgbm")
|
||||
nrounds <- 2L
|
||||
|
||||
# train with 2 rounds and then update
|
||||
bst <- lightgbm(
|
||||
data = as.matrix(agaricus.train$data)
|
||||
, label = agaricus.train$label
|
||||
, num_leaves = 4L
|
||||
, learning_rate = 1.0
|
||||
, nrounds = nrounds
|
||||
, objective = "binary"
|
||||
)
|
||||
expect_error({
|
||||
bst$update(
|
||||
train_set = data.frame(x = rnorm(10L))
|
||||
)
|
||||
}, regexp = "lgb.Booster.update: Only can use lgb.Dataset", fixed = TRUE)
|
||||
})
|
||||
|
|
Загрузка…
Ссылка в новой задаче