Mirror of https://github.com/microsoft/LightGBM.git
[R-package] enable use of trees with linear models at leaves (fixes #3319)

* remove problematic pragmas
* fix tests
* try to fix build scripts
* try fixing pragma check
* more pragma checks
* ok fix pragma stuff for real
* empty commit
* regenerate documentation
* try skipping test
* uncomment CI
* add note on missing value types for R
* add tests on saving and re-loading booster
This commit is contained in:
Parent: 706f2af7ba
Commit: ed651e8672
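As context for the diffs below, a minimal sketch of what this change enables for R users, adapted from the tests added in this commit (data and parameter values are illustrative, not prescriptive):

    library(lightgbm)

    # a feature with a roughly linear relationship to the label
    X <- matrix(rnorm(100L), ncol = 1L)
    dtrain <- lgb.Dataset(
        data = X
        , label = 2L * X + runif(nrow(X), 0L, 0.1)
    )

    # linear_tree = TRUE fits a linear model in each leaf instead of
    # a constant, which should fit linearly-related data better
    bst <- lgb.train(
        data = dtrain
        , nrounds = 10L
        , params = list(
            objective = "regression"
            , metric = "mse"
            , num_leaves = 2L
            , linear_tree = TRUE
        )
        , valids = list("train" = dtrain)
    )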
@@ -1699,6 +1699,12 @@ CXX=`"${R_HOME}/bin/R" CMD config CXX11`

# LightGBM-specific flags
LGB_CPPFLAGS=""

#########
# Eigen #
#########

LGB_CPPFLAGS="${LGB_CPPFLAGS} -DEIGEN_MPL2_ONLY"

###############
# MM_PREFETCH #
###############
@@ -26,6 +26,12 @@ CXX=`"${R_HOME}/bin/R" CMD config CXX11`

# LightGBM-specific flags
LGB_CPPFLAGS=""

#########
# Eigen #
#########

LGB_CPPFLAGS="${LGB_CPPFLAGS} -DEIGEN_MPL2_ONLY"

###############
# MM_PREFETCH #
###############
@@ -12,6 +12,12 @@ CC=`"${R_EXE}" CMD config CC`

# LightGBM-specific flags
LGB_CPPFLAGS=""

#########
# Eigen #
#########

LGB_CPPFLAGS="${LGB_CPPFLAGS} -DEIGEN_MPL2_ONLY"

###############
# MM_PREFETCH #
###############
@@ -345,6 +345,45 @@ test_that("lightgbm.cv() gives the correct best_score and best_iter for a metric
    expect_identical(cv_bst$best_score, auc_scores[which.max(auc_scores)])
})

test_that("lgb.cv() fit on linearly-related data improves when using linear learners", {
    set.seed(708L)
    .new_dataset <- function() {
        X <- matrix(rnorm(1000L), ncol = 1L)
        return(lgb.Dataset(
            data = X
            , label = 2L * X + runif(nrow(X), 0L, 0.1)
        ))
    }

    params <- list(
        objective = "regression"
        , verbose = -1L
        , metric = "mse"
        , seed = 0L
        , num_leaves = 2L
    )

    dtrain <- .new_dataset()
    cv_bst <- lgb.cv(
        data = dtrain
        , nrounds = 10L
        , params = params
        , nfold = 5L
    )
    expect_is(cv_bst, "lgb.CVBooster")

    dtrain <- .new_dataset()
    cv_bst_linear <- lgb.cv(
        data = dtrain
        , nrounds = 10L
        , params = modifyList(params, list(linear_tree = TRUE))
        , nfold = 5L
    )
    expect_is(cv_bst_linear, "lgb.CVBooster")

    expect_true(cv_bst_linear$best_score < cv_bst$best_score)
})

context("lgb.train()")

test_that("lgb.train() works as expected with multiple eval metrics", {
@@ -1631,6 +1670,247 @@ test_that("early stopping works with lgb.cv()", {
    )
})

context("linear learner")

test_that("lgb.train() fit on linearly-related data improves when using linear learners", {
    set.seed(708L)
    .new_dataset <- function() {
        X <- matrix(rnorm(100L), ncol = 1L)
        return(lgb.Dataset(
            data = X
            , label = 2L * X + runif(nrow(X), 0L, 0.1)
        ))
    }

    params <- list(
        objective = "regression"
        , verbose = -1L
        , metric = "mse"
        , seed = 0L
        , num_leaves = 2L
    )

    dtrain <- .new_dataset()
    bst <- lgb.train(
        data = dtrain
        , nrounds = 10L
        , params = params
        , valids = list("train" = dtrain)
    )
    expect_true(lgb.is.Booster(bst))

    dtrain <- .new_dataset()
    bst_linear <- lgb.train(
        data = dtrain
        , nrounds = 10L
        , params = modifyList(params, list(linear_tree = TRUE))
        , valids = list("train" = dtrain)
    )
    expect_true(lgb.is.Booster(bst_linear))

    bst_last_mse <- bst$record_evals[["train"]][["l2"]][["eval"]][[10L]]
    bst_lin_last_mse <- bst_linear$record_evals[["train"]][["l2"]][["eval"]][[10L]]
    expect_true(bst_lin_last_mse < bst_last_mse)
})


test_that("lgb.train() w/ linear learner fails on an already-constructed Dataset with linear=false", {
    testthat::skip("Skipping this test because it causes issues for valgrind")
    set.seed(708L)
    params <- list(
        objective = "regression"
        , verbose = -1L
        , metric = "mse"
        , seed = 0L
        , num_leaves = 2L
    )

    dtrain <- lgb.Dataset(
        data = matrix(rnorm(100L), ncol = 1L)
        , label = rnorm(100L)
    )
    dtrain$construct()
    expect_error({
        bst_linear <- lgb.train(
            data = dtrain
            , nrounds = 10L
            , params = modifyList(params, list(linear_tree = TRUE))
            , valids = list("train" = dtrain)
        )
    }, regexp = "Cannot change linear_tree after constructed Dataset handle")
})

test_that("lgb.train() works with linear learners even if Dataset has missing values", {
    set.seed(708L)
    .new_dataset <- function() {
        values <- rnorm(100L)
        values[sample(seq_len(length(values)), size = 10L)] <- NA_real_
        X <- matrix(
            data = sample(values, size = 100L)
            , ncol = 1L
        )
        return(lgb.Dataset(
            data = X
            , label = 2L * X + runif(nrow(X), 0L, 0.1)
        ))
    }

    params <- list(
        objective = "regression"
        , verbose = -1L
        , metric = "mse"
        , seed = 0L
        , num_leaves = 2L
    )

    dtrain <- .new_dataset()
    bst <- lgb.train(
        data = dtrain
        , nrounds = 10L
        , params = params
        , valids = list("train" = dtrain)
    )
    expect_true(lgb.is.Booster(bst))

    dtrain <- .new_dataset()
    bst_linear <- lgb.train(
        data = dtrain
        , nrounds = 10L
        , params = modifyList(params, list(linear_tree = TRUE))
        , valids = list("train" = dtrain)
    )
    expect_true(lgb.is.Booster(bst_linear))

    bst_last_mse <- bst$record_evals[["train"]][["l2"]][["eval"]][[10L]]
    bst_lin_last_mse <- bst_linear$record_evals[["train"]][["l2"]][["eval"]][[10L]]
    expect_true(bst_lin_last_mse < bst_last_mse)
})

test_that("lgb.train() works with linear learners, bagging, and a Dataset that has missing values", {
    set.seed(708L)
    .new_dataset <- function() {
        values <- rnorm(100L)
        values[sample(seq_len(length(values)), size = 10L)] <- NA_real_
        X <- matrix(
            data = sample(values, size = 100L)
            , ncol = 1L
        )
        return(lgb.Dataset(
            data = X
            , label = 2L * X + runif(nrow(X), 0L, 0.1)
        ))
    }

    params <- list(
        objective = "regression"
        , verbose = -1L
        , metric = "mse"
        , seed = 0L
        , num_leaves = 2L
        , bagging_freq = 1L
        , subsample = 0.8
    )

    dtrain <- .new_dataset()
    bst <- lgb.train(
        data = dtrain
        , nrounds = 10L
        , params = params
        , valids = list("train" = dtrain)
    )
    expect_true(lgb.is.Booster(bst))

    dtrain <- .new_dataset()
    bst_linear <- lgb.train(
        data = dtrain
        , nrounds = 10L
        , params = modifyList(params, list(linear_tree = TRUE))
        , valids = list("train" = dtrain)
    )
    expect_true(lgb.is.Booster(bst_linear))

    bst_last_mse <- bst$record_evals[["train"]][["l2"]][["eval"]][[10L]]
    bst_lin_last_mse <- bst_linear$record_evals[["train"]][["l2"]][["eval"]][[10L]]
    expect_true(bst_lin_last_mse < bst_last_mse)
})

test_that("lgb.train() works with linear learners and data where a feature has only 1 non-NA value", {
    set.seed(708L)
    .new_dataset <- function() {
        values <- rep(NA_real_, 100L)
        values[18L] <- rnorm(1L)
        X <- matrix(
            data = values
            , ncol = 1L
        )
        return(lgb.Dataset(
            data = X
            , label = 2L * X + runif(nrow(X), 0L, 0.1)
        ))
    }

    params <- list(
        objective = "regression"
        , verbose = -1L
        , metric = "mse"
        , seed = 0L
        , num_leaves = 2L
    )

    dtrain <- .new_dataset()
    bst_linear <- lgb.train(
        data = dtrain
        , nrounds = 10L
        , params = modifyList(params, list(linear_tree = TRUE))
        , valids = list("train" = dtrain)
    )
    expect_true(lgb.is.Booster(bst_linear))
})

test_that("lgb.train() works with linear learners when Dataset has categorical features", {
    set.seed(708L)
    .new_dataset <- function() {
        X <- matrix(numeric(200L), nrow = 100L, ncol = 2L)
        X[, 1L] <- rnorm(100L)
        X[, 2L] <- sample(seq_len(4L), size = 100L, replace = TRUE)
        return(lgb.Dataset(
            data = X
            , label = 2L * X[, 1L] + runif(nrow(X), 0L, 0.1)
        ))
    }

    params <- list(
        objective = "regression"
        , verbose = -1L
        , metric = "mse"
        , seed = 0L
        , num_leaves = 2L
        , categorical_feature = 1L
    )

    dtrain <- .new_dataset()
    bst <- lgb.train(
        data = dtrain
        , nrounds = 10L
        , params = params
        , valids = list("train" = dtrain)
    )
    expect_true(lgb.is.Booster(bst))

    dtrain <- .new_dataset()
    bst_linear <- lgb.train(
        data = dtrain
        , nrounds = 10L
        , params = modifyList(params, list(linear_tree = TRUE))
        , valids = list("train" = dtrain)
    )
    expect_true(lgb.is.Booster(bst_linear))

    bst_last_mse <- bst$record_evals[["train"]][["l2"]][["eval"]][[10L]]
    bst_lin_last_mse <- bst_linear$record_evals[["train"]][["l2"]][["eval"]][[10L]]
    expect_true(bst_lin_last_mse < bst_last_mse)
})

context("interaction constraints")

test_that("lgb.train() throws an informative error if interaction_constraints is not a list", {
@@ -135,7 +135,7 @@ test_that("lgb.load() gives the expected error messages given different incorrec

})

-test_that("Loading a Booster from a file works", {
+test_that("Loading a Booster from a text file works", {
    set.seed(708L)
    data(agaricus.train, package = "lightgbm")
    data(agaricus.test, package = "lightgbm")
@@ -168,6 +168,47 @@ test_that("Loading a Booster from a file works", {
    expect_identical(pred, pred2)
})

test_that("boosters with linear models at leaves can be written to text file and re-loaded successfully", {
    X <- matrix(rnorm(100L), ncol = 1L)
    labels <- 2L * X + runif(nrow(X), 0L, 0.1)
    dtrain <- lgb.Dataset(
        data = X
        , label = labels
    )

    params <- list(
        objective = "regression"
        , verbose = -1L
        , metric = "mse"
        , seed = 0L
        , num_leaves = 2L
    )

    bst <- lgb.train(
        data = dtrain
        , nrounds = 10L
        , params = params
        , valids = list("train" = dtrain)
    )
    expect_true(lgb.is.Booster(bst))

    # save predictions, then write the model to a file and destroy it in R
    preds <- predict(bst, X)
    model_file <- tempfile(fileext = ".model")
    lgb.save(bst, model_file)
    bst$finalize()
    expect_null(bst$.__enclos_env__$private$handle)
    rm(bst)

    # load the booster and make predictions...should be the same
    bst2 <- lgb.load(
        filename = model_file
    )
    pred2 <- predict(bst2, X)
    expect_identical(preds, pred2)
})


test_that("Loading a Booster from a string works", {
    set.seed(708L)
    data(agaricus.train, package = "lightgbm")
@@ -730,3 +771,41 @@ test_that("params (including dataset params) should be stored in .rds file for B
        )
    )
})

test_that("boosters with linear models at leaves can be written to RDS and re-loaded successfully", {
    X <- matrix(rnorm(100L), ncol = 1L)
    labels <- 2L * X + runif(nrow(X), 0L, 0.1)
    dtrain <- lgb.Dataset(
        data = X
        , label = labels
    )

    params <- list(
        objective = "regression"
        , verbose = -1L
        , metric = "mse"
        , seed = 0L
        , num_leaves = 2L
    )

    bst <- lgb.train(
        data = dtrain
        , nrounds = 10L
        , params = params
        , valids = list("train" = dtrain)
    )
    expect_true(lgb.is.Booster(bst))

    # save predictions, then write the model to a file and destroy it in R
    preds <- predict(bst, X)
    model_file <- tempfile(fileext = ".rds")
    saveRDS.lgb.Booster(bst, file = model_file)
    bst$finalize()
    expect_null(bst$.__enclos_env__$private$handle)
    rm(bst)

    # load the booster and make predictions...should be the same
    bst2 <- readRDS.lgb.Booster(file = model_file)
    pred2 <- predict(bst2, X)
    expect_identical(preds, pred2)
})
@@ -37,6 +37,27 @@ cp \
    external_libs/fmt/include/fmt/*.h \
    ${TEMP_R_DIR}/src/include/LightGBM/fmt/

# including only specific files from Eigen, to keep the R package
# small and avoid redistributing code with licenses incompatible with
# LightGBM's license
EIGEN_R_DIR=${TEMP_R_DIR}/src/include/Eigen
mkdir -p ${EIGEN_R_DIR}

modules="Cholesky Core Dense Eigenvalues Geometry Householder Jacobi LU QR SVD"
for eigen_module in ${modules}; do
    cp eigen/Eigen/${eigen_module} ${EIGEN_R_DIR}/${eigen_module}
    if [ ${eigen_module} != "Dense" ]; then
        mkdir -p ${EIGEN_R_DIR}/src/${eigen_module}/
        cp -R eigen/Eigen/src/${eigen_module}/* ${EIGEN_R_DIR}/src/${eigen_module}/
    fi
done

mkdir -p ${EIGEN_R_DIR}/src/misc
cp -R eigen/Eigen/src/misc/* ${EIGEN_R_DIR}/src/misc/

mkdir -p ${EIGEN_R_DIR}/src/plugins
cp -R eigen/Eigen/src/plugins/* ${EIGEN_R_DIR}/src/plugins/

cd ${TEMP_R_DIR}

# Remove files not needed for CRAN
@@ -69,6 +90,9 @@ cd ${TEMP_R_DIR}

for file in $(find . -name '*.h' -o -name '*.hpp' -o -name '*.cpp'); do
    sed \
        -i.bak \
        -e 's/^.*#pragma clang diagnostic.*$//' \
        -e 's/^.*#pragma diag_suppress.*$//' \
        -e 's/^.*#pragma GCC diagnostic.*$//' \
        -e 's/^.*#pragma region.*$//' \
        -e 's/^.*#pragma endregion.*$//' \
        -e 's/^.*#pragma warning.*$//' \
build_r.R
@@ -156,6 +156,72 @@ if (USING_GPU) {
    .handle_result(result)
}

EIGEN_R_DIR <- file.path(TEMP_SOURCE_DIR, "include", "Eigen")
dir.create(EIGEN_R_DIR)

eigen_modules <- c(
    "Cholesky"
    , "Core"
    , "Dense"
    , "Eigenvalues"
    , "Geometry"
    , "Householder"
    , "Jacobi"
    , "LU"
    , "QR"
    , "SVD"
)
for (eigen_module in eigen_modules) {
    result <- file.copy(
        from = file.path("eigen", "Eigen", eigen_module)
        , to = EIGEN_R_DIR
        , recursive = FALSE
        , overwrite = TRUE
    )
    .handle_result(result)
}

dir.create(file.path(EIGEN_R_DIR, "src"))

for (eigen_module in c(eigen_modules, "misc", "plugins")) {
    if (eigen_module == "Dense") {
        next
    }
    module_dir <- file.path(EIGEN_R_DIR, "src", eigen_module)
    dir.create(module_dir, recursive = TRUE)
    result <- file.copy(
        from = sprintf("%s/", file.path("eigen", "Eigen", "src", eigen_module))
        , to = sprintf("%s/", file.path(EIGEN_R_DIR, "src"))
        , recursive = TRUE
        , overwrite = TRUE
    )
    .handle_result(result)
}

.replace_pragmas <- function(filepath) {
    pragma_patterns <- c(
        "^.*#pragma clang diagnostic.*$"
        , "^.*#pragma diag_suppress.*$"
        , "^.*#pragma GCC diagnostic.*$"
        , "^.*#pragma region.*$"
        , "^.*#pragma endregion.*$"
        , "^.*#pragma warning.*$"
    )
    content <- readLines(filepath)
    for (pragma_pattern in pragma_patterns) {
        content <- content[!grepl(pragma_pattern, content)]
    }
    writeLines(content, filepath)
}

# remove pragmas that suppress warnings, to appease R CMD check
.replace_pragmas(
    file.path(EIGEN_R_DIR, "src", "Core", "arch", "SSE", "Complex.h")
)
.replace_pragmas(
    file.path(EIGEN_R_DIR, "src", "Core", "util", "DisableStupidWarnings.h")
)

result <- file.copy(
    from = "CMakeLists.txt"
    , to = file.path(TEMP_R_DIR, "inst", "bin/")
@@ -127,14 +127,12 @@ Core Parameters

   - categorical features are used for splits as normal but are not used in the linear models

-  - missing values must be encoded as ``np.nan`` (Python) or ``NA`` (CLI), not ``0``
+  - missing values should not be encoded as ``0``. Use ``np.nan`` for Python, ``NA`` for the CLI, and ``NA``, ``NA_real_``, or ``NA_integer_`` for R

   - it is recommended to rescale data before training so that features have similar mean and standard deviation

   - **Note**: only works with CPU and ``serial`` tree learner

-  - **Note**: not yet supported in R-package
-
   - **Note**: ``regression_l1`` objective is not supported with linear tree boosting

   - **Note**: setting ``linear_tree=true`` significantly increases the memory use of LightGBM
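To make the missing-value and rescaling guidance above concrete for R users, a small hedged sketch (variable names and the label are illustrative only):

    # encode missing values as NA_real_, never as 0
    X <- matrix(rnorm(100L), ncol = 1L)
    X[sample(seq_len(nrow(X)), size = 10L), 1L] <- NA_real_

    # rescale so features have similar mean and standard deviation;
    # scale() centers and scales each column, ignoring NAs
    X <- scale(X)

    dtrain <- lightgbm::lgb.Dataset(data = X, label = rnorm(nrow(X)))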
@@ -152,10 +152,9 @@ struct Config {
  // descl2 = tree splits are chosen in the usual way, but the model at each leaf is linear instead of constant
  // descl2 = the linear model at each leaf includes all the numerical features in that leaf's branch
  // descl2 = categorical features are used for splits as normal but are not used in the linear models
-  // descl2 = missing values must be encoded as ``np.nan`` (Python) or ``NA`` (CLI), not ``0``
+  // descl2 = missing values should not be encoded as ``0``. Use ``np.nan`` for Python, ``NA`` for the CLI, and ``NA``, ``NA_real_``, or ``NA_integer_`` for R
  // descl2 = it is recommended to rescale data before training so that features have similar mean and standard deviation
  // descl2 = **Note**: only works with CPU and ``serial`` tree learner
-  // descl2 = **Note**: not yet supported in R-package
  // descl2 = **Note**: ``regression_l1`` objective is not supported with linear tree boosting
  // descl2 = **Note**: setting ``linear_tree=true`` significantly increases the memory use of LightGBM
  bool linear_tree = false;
@@ -4,11 +4,9 @@
 */
#include "linear_tree_learner.h"

-#include <algorithm>
-
-#ifndef LGB_R_BUILD
#include <Eigen/Dense>
-#endif  // !LGB_R_BUILD

+#include <algorithm>
+
namespace LightGBM {
@@ -170,12 +168,7 @@ void LinearTreeLearner::GetLeafMap(Tree* tree) const {
  }
}

-#ifdef LGB_R_BUILD
-template<bool HAS_NAN>
-void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const {
-  Log::Fatal("Linear tree learner does not work with R package.");
-}
-#else
-
template<bool HAS_NAN>
void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const {
  tree->SetIsLinear(true);
@@ -385,5 +378,4 @@ void LinearTreeLearner::CalculateLinear(Tree* tree, bool is_refit, const score_t
    }
  }
}
-#endif  // LGB_R_BUILD
}  // namespace LightGBM
@@ -17,9 +17,6 @@ TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, con
  if (device_type == std::string("cpu")) {
    if (learner_type == std::string("serial")) {
      if (config->linear_tree) {
-#ifdef LGB_R_BUILD
-        Log::Fatal("Linear tree learner does not work with R package.");
-#endif  // LGB_R_BUILD
        return new LinearTreeLearner(config);
      } else {
        return new SerialTreeLearner(config);