зеркало из https://github.com/microsoft/LightGBM.git
[R-package] Interface for interaction constraints (#3136)
* Add interaction constraints functionality. * Minor fixes. * Minor fixes. * Change lambda to function. * Fix gpu bug, remove extra blank lines. * Fix gpu bug. * Fix style issues. * Try to fix segfault on MACOS. * Fix bug. * Fix bug. * Fix bugs. * Change parameter format for R. * Fix R style issues. * Change string formatting code. * Change docs to say R package not supported. * Refactor check_interaction_constraints into separate function, add validation. * Fix error messages. * Add tests. * Update docs. * Fix tests, minor refactoring. * Fix style issues. * Fix R style issue. * Remove old code. * Fix existing test and add new one. * Fix R lint error.
This commit is contained in:
Родитель
cfc5e4fe8b
Коммит
4f8c32d9a6
|
@ -148,6 +148,15 @@ lgb.cv <- function(params = list()
|
|||
end_iteration <- begin_iteration + nrounds - 1L
|
||||
}
|
||||
|
||||
# Check interaction constraints
|
||||
cnames <- NULL
|
||||
if (!is.null(colnames)) {
|
||||
cnames <- colnames
|
||||
} else if (!is.null(data$get_colnames())) {
|
||||
cnames <- data$get_colnames()
|
||||
}
|
||||
params[["interaction_constraints"]] <- lgb.check_interaction_constraints(params, cnames)
|
||||
|
||||
# Check for weights
|
||||
if (!is.null(weight)) {
|
||||
data$setinfo("weight", weight)
|
||||
|
|
|
@ -124,9 +124,14 @@ lgb.train <- function(params = list(),
|
|||
end_iteration <- begin_iteration + nrounds - 1L
|
||||
}
|
||||
|
||||
if (!is.null(params[["interaction_constraints"]])) {
|
||||
stop("lgb.train: interaction_constraints is not implemented")
|
||||
# Check interaction constraints
|
||||
cnames <- NULL
|
||||
if (!is.null(colnames)) {
|
||||
cnames <- colnames
|
||||
} else if (!is.null(data$get_colnames())) {
|
||||
cnames <- data$get_colnames()
|
||||
}
|
||||
params[["interaction_constraints"]] <- lgb.check_interaction_constraints(params, cnames)
|
||||
|
||||
# Update parameters with parsed parameters
|
||||
data$update_params(params)
|
||||
|
|
|
@ -167,6 +167,65 @@ lgb.params2str <- function(params, ...) {
|
|||
|
||||
}
|
||||
|
||||
# [description] Validate params[["interaction_constraints"]] and convert every
#               constraint into a string of 0-based feature indices.
# [inputs]
#   params        list of parameters. Only the "interaction_constraints" entry is
#                 read. When present it must be a list whose elements are character
#                 vectors (feature names) or numeric vectors (1-based feature indices).
#   column_names  character vector of feature names. Names in a constraint are
#                 looked up in it, and its length is the upper bound for numeric indices.
# [returns] a list with one string per constraint, e.g. list("[0,1]", "[2]").
#           An empty list is returned when no interaction constraints were supplied.
# [errors]  stops if the entry is not a list, if an element is neither character nor
#           numeric, if a feature name is unknown, or if a numeric index is missing,
#           smaller than 1 or larger than the number of features.
lgb.check_interaction_constraints <- function(params, column_names) {

  interaction_constraints <- params[["interaction_constraints"]]

  # Nothing to convert when the parameter was not supplied
  if (is.null(interaction_constraints)) {
    return(list())
  }

  # validation
  if (!methods::is(interaction_constraints, "list")) {
    stop("interaction_constraints must be a list")
  }
  # vapply() always yields a logical vector, whatever the input looks like
  is_supported_type <- vapply(
    X = interaction_constraints
    , FUN = function(x) {
      return(is.character(x) || is.numeric(x))
    }
    , FUN.VALUE = logical(1L)
  )
  if (!all(is_supported_type)) {
    stop("every element in interaction_constraints must be a character vector or numeric vector")
  }

  num_features <- length(column_names)

  # Pre-allocate the result instead of growing it inside the loop
  string_constraints <- vector(mode = "list", length = length(interaction_constraints))

  for (i in seq_along(interaction_constraints)) {

    constraint <- interaction_constraints[[i]]

    if (is.character(constraint)) {

      # Feature names: look up their position and shift to 0-based indexing
      constraint_indices <- as.integer(match(constraint, column_names) - 1L)

      # match() returns NA for every name that is not a known feature
      is_unknown <- is.na(constraint_indices)
      if (any(is_unknown)) {
        stop(
          "supplied an unknown feature in interaction_constraints "
          # stop() concatenates its arguments without a separator, so join the names here
          , paste(sQuote(constraint[is_unknown]), collapse = ", ")
        )
      }

    } else {

      # Indices are 1-based: 0, negative and missing values cannot refer to a feature
      # and would otherwise be silently converted to invalid 0-based indices
      is_invalid <- is.na(constraint) | constraint < 1L
      if (any(is_invalid)) {
        stop(
          "interaction_constraints must use 1-based feature indices, but found: "
          , paste(constraint[is_invalid], collapse = ", ")
        )
      }

      # Check that constraint indices are at most number of features
      # (the length guard avoids the warning max() emits for an empty vector)
      if (length(constraint) > 0L && max(constraint) > num_features) {
        stop(
          "supplied a too large value in interaction_constraints: "
          , max(constraint)
          , " but only "
          , num_features
          , " features"
        )
      }

      # Store indices as [0, n-1] indexed instead of [1, n] indexed
      constraint_indices <- as.integer(constraint - 1L)

    }

    # Convert constraint to string
    string_constraints[[i]] <- paste0("[", paste0(constraint_indices, collapse = ","), "]")

  }

  return(string_constraints)

}
|
||||
|
||||
lgb.c_str <- function(x) {
|
||||
|
||||
# Perform character to raw conversion
|
||||
|
|
|
@ -1030,3 +1030,103 @@ test_that("using lightgbm() without early stopping, best_iter and best_score com
|
|||
expect_identical(bst$best_iter, which.max(auc_scores))
|
||||
expect_identical(bst$best_score, auc_scores[which.max(auc_scores)])
|
||||
})
|
||||
|
||||
test_that("lgb.train() throws an informative error if interaction_constraints is not a list", {
  # interaction_constraints is passed as a single string instead of a list
  bad_params <- list(
    objective = "regression"
    , interaction_constraints = "[1,2],[3]"
  )
  train_set <- lgb.Dataset(train$data, label = train$label)
  expect_error(
    object = lightgbm(data = train_set, params = bad_params, nrounds = 2L)
    , regexp = "interaction_constraints must be a list"
  )
})
|
||||
|
||||
test_that(paste0("lgb.train() throws an informative error if the members of interaction_constraints ",
                 "are not character or numeric vectors"), {
  # Each member is itself a list, which is neither a character nor a numeric vector
  nested_constraints <- list(list(1L, 2L), list(3L))
  bad_params <- list(
    objective = "regression"
    , interaction_constraints = nested_constraints
  )
  train_set <- lgb.Dataset(train$data, label = train$label)
  expect_error(
    object = lightgbm(data = train_set, params = bad_params, nrounds = 2L)
    , regexp = "every element in interaction_constraints must be a character vector or numeric vector"
  )
})
|
||||
|
||||
test_that("lgb.train() throws an informative error if interaction_constraints contains a too large index", {
  # One index points just past the last named column of the training data
  index_past_end <- length(colnames(train$data)) + 1L
  bad_params <- list(
    objective = "regression"
    , interaction_constraints = list(c(1L, index_past_end), 3L)
  )
  train_set <- lgb.Dataset(train$data, label = train$label)
  expect_error(
    object = lightgbm(data = train_set, params = bad_params, nrounds = 2L)
    , regexp = "supplied a too large value in interaction_constraints"
  )
})
|
||||
|
||||
test_that(paste0("lgb.train() gives same result when interaction_constraints is specified as a list of ",
                 "character vectors, numeric vectors, or a combination"), {
  set.seed(1L)
  dtrain <- lgb.Dataset(train$data, label = train$label)
  cnames <- colnames(train$data)

  # Train a 2-round regression model with the given constraints and predict on the test data
  predict_with_constraints <- function(constraints) {
    model <- lightgbm(
      data = dtrain
      , params = list(objective = "regression", interaction_constraints = constraints)
      , nrounds = 2L
    )
    return(model$predict(test$data))
  }

  # The same two constraints expressed as indices, as names, and as a mix of both
  preds_by_index <- predict_with_constraints(list(c(1L, 2L), 3L))
  preds_by_name <- predict_with_constraints(list(c(cnames[[1L]], cnames[[2L]]), cnames[[3L]]))
  preds_mixed <- predict_with_constraints(list(c(cnames[[1L]], cnames[[2L]]), 3L))

  expect_equal(preds_by_index, preds_by_name)
  expect_equal(preds_by_name, preds_mixed)
})
|
||||
|
||||
test_that(paste0("lgb.train() gives same results when using interaction_constraints and specifying colnames"), {
  set.seed(1L)
  dtrain <- lgb.Dataset(train$data, label = train$label)

  # Train a 2-round regression model; extra arguments are forwarded to lightgbm()
  predict_with_constraints <- function(constraints, ...) {
    model <- lightgbm(
      data = dtrain
      , params = list(objective = "regression", interaction_constraints = constraints)
      , nrounds = 2L
      , ...
    )
    return(model$predict(test$data))
  }

  # Baseline: constraints given as 1-based indices, no colnames argument
  preds_by_index <- predict_with_constraints(list(c(1L, 2L), 3L))

  # Same constraints given by name, using renamed columns supplied through colnames
  new_colnames <- paste0(colnames(train$data), "_x")
  preds_by_new_name <- predict_with_constraints(
    list(c(new_colnames[1L], new_colnames[2L]), new_colnames[3L])
    , colnames = new_colnames
  )

  expect_equal(preds_by_index, preds_by_new_name)
})
|
||||
|
|
|
@ -548,7 +548,7 @@ Learning Control Parameters
|
|||
|
||||
- for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
|
||||
|
||||
- for R-package, **not yet supported**
|
||||
- for R-package, list of character or numeric vectors, e.g. ``list(c("var1", "var2", "var3"), c("var3", "var4"))`` or ``list(c(1L, 2L, 3L), c(3L, 4L))``. Numeric vectors should use 1-based indexing, where ``1L`` is the first feature, ``2L`` is the second feature, etc
|
||||
|
||||
- any two features can appear in the same branch only if there exists a constraint containing both features
|
||||
|
||||
|
|
|
@ -509,7 +509,7 @@ struct Config {
|
|||
// desc = by default interaction constraints are disabled, to enable them you can specify
|
||||
// descl2 = for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``
|
||||
// descl2 = for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
|
||||
// descl2 = for R-package, **not yet supported**
|
||||
// descl2 = for R-package, list of character or numeric vectors, e.g. ``list(c("var1", "var2", "var3"), c("var3", "var4"))`` or ``list(c(1L, 2L, 3L), c(3L, 4L))``. Numeric vectors should use 1-based indexing, where ``1L`` is the first feature, ``2L`` is the second feature, etc
|
||||
// desc = any two features can appear in the same branch only if there exists a constraint containing both features
|
||||
std::string interaction_constraints = "";
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче