зеркало из https://github.com/microsoft/LightGBM.git
[R-package] added tests on LGBM_BoosterResetTrainingData_R (#3020)
This commit is contained in:
Родитель
ad7f285154
Коммит
05d89a1a1f
|
@ -311,3 +311,66 @@ test_that("Booster$rollback_one_iter() should work as expected", {
|
||||||
logloss <- bst$eval_train()[[1L]][["value"]]
|
logloss <- bst$eval_train()[[1L]][["value"]]
|
||||||
expect_equal(logloss, 0.027915146)
|
expect_equal(logloss, 0.027915146)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("Booster$update() passing a train_set works as expected", {
|
||||||
|
set.seed(708L)
|
||||||
|
data(agaricus.train, package = "lightgbm")
|
||||||
|
nrounds <- 2L
|
||||||
|
|
||||||
|
# train with 2 rounds and then update
|
||||||
|
bst <- lightgbm(
|
||||||
|
data = as.matrix(agaricus.train$data)
|
||||||
|
, label = agaricus.train$label
|
||||||
|
, num_leaves = 4L
|
||||||
|
, learning_rate = 1.0
|
||||||
|
, nrounds = nrounds
|
||||||
|
, objective = "binary"
|
||||||
|
)
|
||||||
|
expect_true(lgb.is.Booster(bst))
|
||||||
|
expect_equal(bst$current_iter(), nrounds)
|
||||||
|
bst$update(
|
||||||
|
train_set = Dataset$new(
|
||||||
|
data = agaricus.train$data
|
||||||
|
, label = agaricus.train$label
|
||||||
|
)
|
||||||
|
)
|
||||||
|
expect_true(lgb.is.Booster(bst))
|
||||||
|
expect_equal(bst$current_iter(), nrounds + 1L)
|
||||||
|
|
||||||
|
# train with 3 rounds directlry
|
||||||
|
bst2 <- lightgbm(
|
||||||
|
data = as.matrix(agaricus.train$data)
|
||||||
|
, label = agaricus.train$label
|
||||||
|
, num_leaves = 4L
|
||||||
|
, learning_rate = 1.0
|
||||||
|
, nrounds = nrounds + 1L
|
||||||
|
, objective = "binary"
|
||||||
|
)
|
||||||
|
expect_true(lgb.is.Booster(bst2))
|
||||||
|
expect_equal(bst2$current_iter(), nrounds + 1L)
|
||||||
|
|
||||||
|
# model with 2 rounds + 1 update should be identical to 3 rounds
|
||||||
|
expect_equal(bst2$eval_train()[[1L]][["value"]], 0.04806585)
|
||||||
|
expect_equal(bst$eval_train()[[1L]][["value"]], bst2$eval_train()[[1L]][["value"]])
|
||||||
|
})
|
||||||
|
|
||||||
|
test_that("Booster$update() throws an informative error if you provide a non-Dataset to update()", {
|
||||||
|
set.seed(708L)
|
||||||
|
data(agaricus.train, package = "lightgbm")
|
||||||
|
nrounds <- 2L
|
||||||
|
|
||||||
|
# train with 2 rounds and then update
|
||||||
|
bst <- lightgbm(
|
||||||
|
data = as.matrix(agaricus.train$data)
|
||||||
|
, label = agaricus.train$label
|
||||||
|
, num_leaves = 4L
|
||||||
|
, learning_rate = 1.0
|
||||||
|
, nrounds = nrounds
|
||||||
|
, objective = "binary"
|
||||||
|
)
|
||||||
|
expect_error({
|
||||||
|
bst$update(
|
||||||
|
train_set = data.frame(x = rnorm(10L))
|
||||||
|
)
|
||||||
|
}, regexp = "lgb.Booster.update: Only can use lgb.Dataset", fixed = TRUE)
|
||||||
|
})
|
||||||
|
|
Загрузка…
Ссылка в новой задаче