[R-package] construct dataset earlier in lgb.train and lgb.cv (fixes #3583) (#3598)

* construct dataset earlier in lgb.train and lgb.cv

* Update R-package/tests/testthat/test_dataset.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.cv.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.train.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/tests/testthat/test_dataset.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* fixing lint issues

* styling updates

* fix failing test

Co-authored-by: James Lamb <jaylamb20@gmail.com>
This commit is contained in:
Tony Kenny 2020-12-01 02:01:02 +00:00 коммит произвёл GitHub
Родитель c02917e493
Коммит 9597326eec
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 68 добавлений и 5 удалений

Просмотреть файл

@ -164,6 +164,10 @@ lgb.cv <- function(params = list()
}
end_iteration <- begin_iteration + params[["num_iterations"]] - 1L
# Construct datasets, if needed
data$update_params(params = params)
data$construct()
# Check interaction constraints
cnames <- NULL
if (!is.null(colnames)) {
@ -194,9 +198,6 @@ lgb.cv <- function(params = list()
data$set_categorical_feature(categorical_feature)
}
# Construct datasets, if needed
data$construct()
# Check for folds
if (!is.null(folds)) {

Просмотреть файл

@ -142,6 +142,10 @@ lgb.train <- function(params = list(),
}
end_iteration <- begin_iteration + params[["num_iterations"]] - 1L
# Construct datasets, if needed
data$update_params(params = params)
data$construct()
# Check interaction constraints
cnames <- NULL
if (!is.null(colnames)) {
@ -167,8 +171,6 @@ lgb.train <- function(params = list(),
data$set_categorical_feature(categorical_feature)
}
# Construct datasets, if needed
data$construct()
valid_contain_train <- FALSE
train_data_name <- "train"
reduced_valid_sets <- list()

Просмотреть файл

@ -205,3 +205,63 @@ test_that("Dataset$update_params() works correctly for recognized Dataset parame
expect_identical(new_params[[param_name]], updated_params[[param_name]])
}
})
# A Dataset that was created from a file path is passed straight to
# lgb.train(); the test itself never calls $construct() on it.
test_that("lgb.Dataset: should be able to run lgb.train() immediately after using lgb.Dataset() on a file", {
    # training settings used for the fit below
    train_params <- list(
        objective = "binary"
        , metric = "binary_logloss"
        , num_leaves = 5L
        , learning_rate = 1.0
    )

    # build a Dataset from in-memory data and write it out to a temporary file
    in_memory_dataset <- lgb.Dataset(
        data = test_data
        , label = test_label
    )
    dataset_path <- tempfile(pattern = "lgb.Dataset_")
    lgb.Dataset.save(
        dataset = in_memory_dataset
        , fname = dataset_path
    )

    # create a second Dataset that points at the saved file
    file_backed_dataset <- lgb.Dataset(data = dataset_path)

    # training should work without any further preparation of the Dataset
    model <- lgb.train(
        params = train_params
        , data = file_backed_dataset
    )
    expect_true(lgb.is.Booster(x = model))
})
# A Dataset that was created from a file path is passed straight to
# lgb.cv(); the test itself never calls $construct() on it.
test_that("lgb.Dataset: should be able to run lgb.cv() immediately after using lgb.Dataset() on a file", {
    # settings used for the cross-validation run below
    cv_params <- list(
        objective = "binary"
        , metric = "binary_logloss"
        , num_leaves = 5L
        , learning_rate = 1.0
    )

    # build a Dataset from in-memory data and write it out to a temporary file
    in_memory_dataset <- lgb.Dataset(
        data = test_data
        , label = test_label
    )
    dataset_path <- tempfile(pattern = "lgb.Dataset_")
    lgb.Dataset.save(
        dataset = in_memory_dataset
        , fname = dataset_path
    )

    # create a second Dataset that points at the saved file
    file_backed_dataset <- lgb.Dataset(data = dataset_path)

    # cross-validation should work without any further preparation of the Dataset
    cv_result <- lgb.cv(
        params = cv_params
        , data = file_backed_dataset
    )
    expect_is(cv_result, "lgb.CVBooster")
})