[R-package] Disabled early stopping when using 'dart' boosting strategy (#2443)

This commit is contained in:
James Lamb 2019-10-25 09:12:47 -07:00 коммит произвёл GitHub
Родитель fc991c9d7e
Коммит 85be04a62a
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 173 добавлений и 72 удалений

42
R-package/R/aliases.R Normal file
Просмотреть файл

@ -0,0 +1,42 @@
# Central location for parameter aliases.
# See https://lightgbm.readthedocs.io/en/latest/Parameters.html#core-parameters
# [description] List of respected parameter aliases. Wrapped in a function to take advantage of
# lazy evaluation (so it doesn't matter what order R sources files during installation).
# [return] A named list, where each key is a main LightGBM parameter and each value is a character
# vector of corresponding aliases.
.PARAMETER_ALIASES <- function(){
return(list(
"boosting" = c(
"boosting"
, "boost"
, "boosting_type"
)
, "early_stopping_round" = c(
"early_stopping_round"
, "early_stopping_rounds"
, "early_stopping"
, "n_iter_no_change"
)
, "metric" = c(
"metric"
, "metrics"
, "metric_types"
)
, "num_class" = c(
"num_class"
, "num_classes"
)
, "num_iterations" = c(
"num_iterations"
, "num_iteration"
, "n_iter"
, "num_tree"
, "num_trees"
, "num_round"
, "num_rounds"
, "num_boost_round"
, "n_estimators"
)
))
}

Просмотреть файл

@ -37,7 +37,11 @@ cb.reset.parameters <- function(new_params) {
# Some parameters are not allowed to be changed,
# since changing them would simply wreck some chaos
not_allowed <- c("num_class", "metric", "boosting_type")
not_allowed <- c(
.PARAMETER_ALIASES()[["num_class"]]
, .PARAMETER_ALIASES()[["metric"]]
, .PARAMETER_ALIASES()[["boosting"]]
)
if (any(pnames %in% not_allowed)) {
stop(
"Parameters "

Просмотреть файл

@ -136,17 +136,7 @@ lgb.cv <- function(params = list(),
begin_iteration <- predictor$current_iter() + 1
}
# Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one
n_trees <- c(
"num_iterations"
, "num_iteration"
, "n_iter"
, "num_tree"
, "num_trees"
, "num_round"
, "num_rounds"
, "num_boost_round"
, "n_estimators"
)
n_trees <- .PARAMETER_ALIASES()[["num_iterations"]]
if (any(names(params) %in% n_trees)) {
end_iteration <- begin_iteration + params[[which(names(params) %in% n_trees)[1]]] - 1
} else {
@ -225,30 +215,52 @@ lgb.cv <- function(params = list(),
callbacks <- add.cb(callbacks, cb.record.evaluation())
}
# Check for early stopping passed as parameter when adding early stopping callback
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping", "n_iter_no_change")
if (any(names(params) %in% early_stop)) {
if (params[[which(names(params) %in% early_stop)[1]]] > 0) {
callbacks <- add.cb(
callbacks
, cb.early.stop(
params[[which(names(params) %in% early_stop)[1]]]
, verbose = verbose
)
)
}
} else {
if (!is.null(early_stopping_rounds)) {
if (early_stopping_rounds > 0) {
callbacks <- add.cb(
callbacks
, cb.early.stop(
early_stopping_rounds
, verbose = verbose
)
)
# If early stopping was passed as a parameter in params(), prefer that to keyword argument
# early_stopping_rounds by overwriting the value in 'early_stopping_rounds'
early_stop <- .PARAMETER_ALIASES()[["early_stopping_round"]]
early_stop_param_indx <- names(params) %in% early_stop
if (any(early_stop_param_indx)) {
first_early_stop_param <- which(early_stop_param_indx)[[1]]
first_early_stop_param_name <- names(params)[[first_early_stop_param]]
early_stopping_rounds <- params[[first_early_stop_param_name]]
}
# Did user pass parameters that indicate they want to use early stopping?
using_early_stopping_via_args <- !is.null(early_stopping_rounds)
boosting_param_names <- .PARAMETER_ALIASES()[["boosting"]]
using_dart <- any(
sapply(
X = boosting_param_names
, FUN = function(param){
identical(params[[param]], 'dart')
}
}
)
)
# Cannot use early stopping with 'dart' boosting
if (using_dart){
warning("Early stopping is not available in 'dart' mode.")
using_early_stopping_via_args <- FALSE
# Remove the cb.early.stop() function if it was passed in to callbacks
callbacks <- Filter(
f = function(cb_func){
!identical(attr(cb_func, "name"), "cb.early.stop")
}
, x = callbacks
)
}
# If user supplied early_stopping_rounds, add the early stopping callback
if (using_early_stopping_via_args){
callbacks <- add.cb(
callbacks
, cb.early.stop(
stopping_rounds = early_stopping_rounds
, verbose = verbose
)
)
}
# Categorize callbacks

Просмотреть файл

@ -108,24 +108,13 @@ lgb.train <- function(params = list(),
begin_iteration <- predictor$current_iter() + 1
}
# Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one
n_rounds <- c(
"num_iterations"
, "num_iteration"
, "n_iter"
, "num_tree"
, "num_trees"
, "num_round"
, "num_rounds"
, "num_boost_round"
, "n_estimators"
)
if (any(names(params) %in% n_rounds)) {
end_iteration <- begin_iteration + params[[which(names(params) %in% n_rounds)[1]]] - 1
n_trees <- .PARAMETER_ALIASES()[["num_iterations"]]
if (any(names(params) %in% n_trees)) {
end_iteration <- begin_iteration + params[[which(names(params) %in% n_trees)[1]]] - 1
} else {
end_iteration <- begin_iteration + nrounds - 1
}
# Check for training dataset type correctness
if (!lgb.is.Dataset(data)) {
stop("lgb.train: data only accepts lgb.Dataset object")
@ -207,30 +196,52 @@ lgb.train <- function(params = list(),
callbacks <- add.cb(callbacks, cb.record.evaluation())
}
# Check for early stopping passed as parameter when adding early stopping callback
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping", "n_iter_no_change")
if (any(names(params) %in% early_stop)) {
if (params[[which(names(params) %in% early_stop)[1]]] > 0) {
callbacks <- add.cb(
callbacks
, cb.early.stop(
params[[which(names(params) %in% early_stop)[1]]]
, verbose = verbose
)
)
}
} else {
if (!is.null(early_stopping_rounds)) {
if (early_stopping_rounds > 0) {
callbacks <- add.cb(
callbacks
, cb.early.stop(
early_stopping_rounds
, verbose = verbose
)
)
# If early stopping was passed as a parameter in params(), prefer that to keyword argument
# early_stopping_rounds by overwriting the value in 'early_stopping_rounds'
early_stop <- .PARAMETER_ALIASES()[["early_stopping_round"]]
early_stop_param_indx <- names(params) %in% early_stop
if (any(early_stop_param_indx)) {
first_early_stop_param <- which(early_stop_param_indx)[[1]]
first_early_stop_param_name <- names(params)[[first_early_stop_param]]
early_stopping_rounds <- params[[first_early_stop_param_name]]
}
# Did user pass parameters that indicate they want to use early stopping?
using_early_stopping_via_args <- !is.null(early_stopping_rounds)
boosting_param_names <- .PARAMETER_ALIASES()[["boosting"]]
using_dart <- any(
sapply(
X = boosting_param_names
, FUN = function(param){
identical(params[[param]], 'dart')
}
}
)
)
# Cannot use early stopping with 'dart' boosting
if (using_dart){
warning("Early stopping is not available in 'dart' mode.")
using_early_stopping_via_args <- FALSE
# Remove the cb.early.stop() function if it was passed in to callbacks
callbacks <- Filter(
f = function(cb_func){
!identical(attr(cb_func, "name"), "cb.early.stop")
}
, x = callbacks
)
}
# If user supplied early_stopping_rounds, add the early stopping callback
if (using_early_stopping_via_args){
callbacks <- add.cb(
callbacks
, cb.early.stop(
stopping_rounds = early_stopping_rounds
, verbose = verbose
)
)
}
# "Categorize" callbacks

Просмотреть файл

@ -43,3 +43,35 @@ test_that("Feature penalties work properly", {
# Ensure that feature is not used when feature_penalty = 0
expect_length(var_gain[[length(var_gain)]], 0)
})
expect_true(".PARAMETER_ALIASES() returns a named list", {
param_aliases <- .PARAMETER_ALIASES()
expect_true(is.list(param_aliases))
expect_true(is.character(names(param_aliases)))
expect_true(is.character(param_aliases[["boosting"]]))
expect_true(is.character(param_aliases[["early_stopping_round"]]))
expect_true(is.character(param_aliases[["metric"]]))
expect_true(is.character(param_aliases[["num_class"]]))
expect_true(is.character(param_aliases[["num_iterations"]]))
})
expect_true("training should warn if you use 'dart' boosting, specified with 'boosting' or aliases", {
for (boosting_param in .PARAMETER_ALIASES()[["boosting"]]){
expect_warning({
result <- lightgbm(
data = train$data
, label = train$label
, num_leaves = 5
, learning_rate = 0.05
, nrounds = 5
, objective = "binary"
, metric = "binary_error"
, verbose = -1
, params = stats::setNames(
object = "dart"
, nm = boosting_param
)
)
}, regexp = "Early stopping is not available in 'dart' mode")
}
})