зеркало из https://github.com/microsoft/LightGBM.git
[R-package] Disabled early stopping when using 'dart' boosting strategy (#2443)
This commit is contained in:
Родитель
fc991c9d7e
Коммит
85be04a62a
|
@ -0,0 +1,42 @@
|
|||
# Central location for parameter aliases.
|
||||
# See https://lightgbm.readthedocs.io/en/latest/Parameters.html#core-parameters
|
||||
|
||||
# [description] List of respected parameter aliases. Wrapped in a function to take advantage of
|
||||
# lazy evaluation (so it doesn't matter what order R sources files during installation).
|
||||
# [return] A named list, where each key is a main LightGBM parameter and each value is a character
|
||||
# vector of corresponding aliases.
|
||||
.PARAMETER_ALIASES <- function(){
|
||||
return(list(
|
||||
"boosting" = c(
|
||||
"boosting"
|
||||
, "boost"
|
||||
, "boosting_type"
|
||||
)
|
||||
, "early_stopping_round" = c(
|
||||
"early_stopping_round"
|
||||
, "early_stopping_rounds"
|
||||
, "early_stopping"
|
||||
, "n_iter_no_change"
|
||||
)
|
||||
, "metric" = c(
|
||||
"metric"
|
||||
, "metrics"
|
||||
, "metric_types"
|
||||
)
|
||||
, "num_class" = c(
|
||||
"num_class"
|
||||
, "num_classes"
|
||||
)
|
||||
, "num_iterations" = c(
|
||||
"num_iterations"
|
||||
, "num_iteration"
|
||||
, "n_iter"
|
||||
, "num_tree"
|
||||
, "num_trees"
|
||||
, "num_round"
|
||||
, "num_rounds"
|
||||
, "num_boost_round"
|
||||
, "n_estimators"
|
||||
)
|
||||
))
|
||||
}
|
|
@ -37,7 +37,11 @@ cb.reset.parameters <- function(new_params) {
|
|||
|
||||
# Some parameters are not allowed to be changed,
|
||||
# since changing them would simply wreck some chaos
|
||||
not_allowed <- c("num_class", "metric", "boosting_type")
|
||||
not_allowed <- c(
|
||||
.PARAMETER_ALIASES()[["num_class"]]
|
||||
, .PARAMETER_ALIASES()[["metric"]]
|
||||
, .PARAMETER_ALIASES()[["boosting"]]
|
||||
)
|
||||
if (any(pnames %in% not_allowed)) {
|
||||
stop(
|
||||
"Parameters "
|
||||
|
|
|
@ -136,17 +136,7 @@ lgb.cv <- function(params = list(),
|
|||
begin_iteration <- predictor$current_iter() + 1
|
||||
}
|
||||
# Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one
|
||||
n_trees <- c(
|
||||
"num_iterations"
|
||||
, "num_iteration"
|
||||
, "n_iter"
|
||||
, "num_tree"
|
||||
, "num_trees"
|
||||
, "num_round"
|
||||
, "num_rounds"
|
||||
, "num_boost_round"
|
||||
, "n_estimators"
|
||||
)
|
||||
n_trees <- .PARAMETER_ALIASES()[["num_iterations"]]
|
||||
if (any(names(params) %in% n_trees)) {
|
||||
end_iteration <- begin_iteration + params[[which(names(params) %in% n_trees)[1]]] - 1
|
||||
} else {
|
||||
|
@ -225,30 +215,52 @@ lgb.cv <- function(params = list(),
|
|||
callbacks <- add.cb(callbacks, cb.record.evaluation())
|
||||
}
|
||||
|
||||
# Check for early stopping passed as parameter when adding early stopping callback
|
||||
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping", "n_iter_no_change")
|
||||
if (any(names(params) %in% early_stop)) {
|
||||
if (params[[which(names(params) %in% early_stop)[1]]] > 0) {
|
||||
callbacks <- add.cb(
|
||||
callbacks
|
||||
, cb.early.stop(
|
||||
params[[which(names(params) %in% early_stop)[1]]]
|
||||
, verbose = verbose
|
||||
)
|
||||
)
|
||||
}
|
||||
} else {
|
||||
if (!is.null(early_stopping_rounds)) {
|
||||
if (early_stopping_rounds > 0) {
|
||||
callbacks <- add.cb(
|
||||
callbacks
|
||||
, cb.early.stop(
|
||||
early_stopping_rounds
|
||||
, verbose = verbose
|
||||
)
|
||||
)
|
||||
# If early stopping was passed as a parameter in params(), prefer that to keyword argument
|
||||
# early_stopping_rounds by overwriting the value in 'early_stopping_rounds'
|
||||
early_stop <- .PARAMETER_ALIASES()[["early_stopping_round"]]
|
||||
early_stop_param_indx <- names(params) %in% early_stop
|
||||
if (any(early_stop_param_indx)) {
|
||||
first_early_stop_param <- which(early_stop_param_indx)[[1]]
|
||||
first_early_stop_param_name <- names(params)[[first_early_stop_param]]
|
||||
early_stopping_rounds <- params[[first_early_stop_param_name]]
|
||||
}
|
||||
|
||||
# Did user pass parameters that indicate they want to use early stopping?
|
||||
using_early_stopping_via_args <- !is.null(early_stopping_rounds)
|
||||
|
||||
boosting_param_names <- .PARAMETER_ALIASES()[["boosting"]]
|
||||
using_dart <- any(
|
||||
sapply(
|
||||
X = boosting_param_names
|
||||
, FUN = function(param){
|
||||
identical(params[[param]], 'dart')
|
||||
}
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Cannot use early stopping with 'dart' boosting
|
||||
if (using_dart){
|
||||
warning("Early stopping is not available in 'dart' mode.")
|
||||
using_early_stopping_via_args <- FALSE
|
||||
|
||||
# Remove the cb.early.stop() function if it was passed in to callbacks
|
||||
callbacks <- Filter(
|
||||
f = function(cb_func){
|
||||
!identical(attr(cb_func, "name"), "cb.early.stop")
|
||||
}
|
||||
, x = callbacks
|
||||
)
|
||||
}
|
||||
|
||||
# If user supplied early_stopping_rounds, add the early stopping callback
|
||||
if (using_early_stopping_via_args){
|
||||
callbacks <- add.cb(
|
||||
callbacks
|
||||
, cb.early.stop(
|
||||
stopping_rounds = early_stopping_rounds
|
||||
, verbose = verbose
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
# Categorize callbacks
|
||||
|
|
|
@ -108,24 +108,13 @@ lgb.train <- function(params = list(),
|
|||
begin_iteration <- predictor$current_iter() + 1
|
||||
}
|
||||
# Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one
|
||||
n_rounds <- c(
|
||||
"num_iterations"
|
||||
, "num_iteration"
|
||||
, "n_iter"
|
||||
, "num_tree"
|
||||
, "num_trees"
|
||||
, "num_round"
|
||||
, "num_rounds"
|
||||
, "num_boost_round"
|
||||
, "n_estimators"
|
||||
)
|
||||
if (any(names(params) %in% n_rounds)) {
|
||||
end_iteration <- begin_iteration + params[[which(names(params) %in% n_rounds)[1]]] - 1
|
||||
n_trees <- .PARAMETER_ALIASES()[["num_iterations"]]
|
||||
if (any(names(params) %in% n_trees)) {
|
||||
end_iteration <- begin_iteration + params[[which(names(params) %in% n_trees)[1]]] - 1
|
||||
} else {
|
||||
end_iteration <- begin_iteration + nrounds - 1
|
||||
}
|
||||
|
||||
|
||||
# Check for training dataset type correctness
|
||||
if (!lgb.is.Dataset(data)) {
|
||||
stop("lgb.train: data only accepts lgb.Dataset object")
|
||||
|
@ -207,30 +196,52 @@ lgb.train <- function(params = list(),
|
|||
callbacks <- add.cb(callbacks, cb.record.evaluation())
|
||||
}
|
||||
|
||||
# Check for early stopping passed as parameter when adding early stopping callback
|
||||
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping", "n_iter_no_change")
|
||||
if (any(names(params) %in% early_stop)) {
|
||||
if (params[[which(names(params) %in% early_stop)[1]]] > 0) {
|
||||
callbacks <- add.cb(
|
||||
callbacks
|
||||
, cb.early.stop(
|
||||
params[[which(names(params) %in% early_stop)[1]]]
|
||||
, verbose = verbose
|
||||
)
|
||||
)
|
||||
}
|
||||
} else {
|
||||
if (!is.null(early_stopping_rounds)) {
|
||||
if (early_stopping_rounds > 0) {
|
||||
callbacks <- add.cb(
|
||||
callbacks
|
||||
, cb.early.stop(
|
||||
early_stopping_rounds
|
||||
, verbose = verbose
|
||||
)
|
||||
)
|
||||
# If early stopping was passed as a parameter in params(), prefer that to keyword argument
|
||||
# early_stopping_rounds by overwriting the value in 'early_stopping_rounds'
|
||||
early_stop <- .PARAMETER_ALIASES()[["early_stopping_round"]]
|
||||
early_stop_param_indx <- names(params) %in% early_stop
|
||||
if (any(early_stop_param_indx)) {
|
||||
first_early_stop_param <- which(early_stop_param_indx)[[1]]
|
||||
first_early_stop_param_name <- names(params)[[first_early_stop_param]]
|
||||
early_stopping_rounds <- params[[first_early_stop_param_name]]
|
||||
}
|
||||
|
||||
# Did user pass parameters that indicate they want to use early stopping?
|
||||
using_early_stopping_via_args <- !is.null(early_stopping_rounds)
|
||||
|
||||
boosting_param_names <- .PARAMETER_ALIASES()[["boosting"]]
|
||||
using_dart <- any(
|
||||
sapply(
|
||||
X = boosting_param_names
|
||||
, FUN = function(param){
|
||||
identical(params[[param]], 'dart')
|
||||
}
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Cannot use early stopping with 'dart' boosting
|
||||
if (using_dart){
|
||||
warning("Early stopping is not available in 'dart' mode.")
|
||||
using_early_stopping_via_args <- FALSE
|
||||
|
||||
# Remove the cb.early.stop() function if it was passed in to callbacks
|
||||
callbacks <- Filter(
|
||||
f = function(cb_func){
|
||||
!identical(attr(cb_func, "name"), "cb.early.stop")
|
||||
}
|
||||
, x = callbacks
|
||||
)
|
||||
}
|
||||
|
||||
# If user supplied early_stopping_rounds, add the early stopping callback
|
||||
if (using_early_stopping_via_args){
|
||||
callbacks <- add.cb(
|
||||
callbacks
|
||||
, cb.early.stop(
|
||||
stopping_rounds = early_stopping_rounds
|
||||
, verbose = verbose
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
# "Categorize" callbacks
|
||||
|
|
|
@ -43,3 +43,35 @@ test_that("Feature penalties work properly", {
|
|||
# Ensure that feature is not used when feature_penalty = 0
|
||||
expect_length(var_gain[[length(var_gain)]], 0)
|
||||
})
|
||||
|
||||
expect_true(".PARAMETER_ALIASES() returns a named list", {
|
||||
param_aliases <- .PARAMETER_ALIASES()
|
||||
expect_true(is.list(param_aliases))
|
||||
expect_true(is.character(names(param_aliases)))
|
||||
expect_true(is.character(param_aliases[["boosting"]]))
|
||||
expect_true(is.character(param_aliases[["early_stopping_round"]]))
|
||||
expect_true(is.character(param_aliases[["metric"]]))
|
||||
expect_true(is.character(param_aliases[["num_class"]]))
|
||||
expect_true(is.character(param_aliases[["num_iterations"]]))
|
||||
})
|
||||
|
||||
expect_true("training should warn if you use 'dart' boosting, specified with 'boosting' or aliases", {
|
||||
for (boosting_param in .PARAMETER_ALIASES()[["boosting"]]){
|
||||
expect_warning({
|
||||
result <- lightgbm(
|
||||
data = train$data
|
||||
, label = train$label
|
||||
, num_leaves = 5
|
||||
, learning_rate = 0.05
|
||||
, nrounds = 5
|
||||
, objective = "binary"
|
||||
, metric = "binary_error"
|
||||
, verbose = -1
|
||||
, params = stats::setNames(
|
||||
object = "dart"
|
||||
, nm = boosting_param
|
||||
)
|
||||
)
|
||||
}, regexp = "Early stopping is not available in 'dart' mode")
|
||||
}
|
||||
})
|
||||
|
|
Загрузка…
Ссылка в новой задаче