diff --git a/R-package/R/aliases.R b/R-package/R/aliases.R new file mode 100644 index 000000000..76d00aff1 --- /dev/null +++ b/R-package/R/aliases.R @@ -0,0 +1,42 @@ +# Central location for parameter aliases. +# See https://lightgbm.readthedocs.io/en/latest/Parameters.html#core-parameters + +# [description] List of respected parameter aliases. Wrapped in a function to take advantage of +# lazy evaluation (so it doesn't matter what order R sources files during installation). +# [return] A named list, where each key is a main LightGBM parameter and each value is a character +# vector of corresponding aliases. +.PARAMETER_ALIASES <- function(){ + return(list( + "boosting" = c( + "boosting" + , "boost" + , "boosting_type" + ) + , "early_stopping_round" = c( + "early_stopping_round" + , "early_stopping_rounds" + , "early_stopping" + , "n_iter_no_change" + ) + , "metric" = c( + "metric" + , "metrics" + , "metric_types" + ) + , "num_class" = c( + "num_class" + , "num_classes" + ) + , "num_iterations" = c( + "num_iterations" + , "num_iteration" + , "n_iter" + , "num_tree" + , "num_trees" + , "num_round" + , "num_rounds" + , "num_boost_round" + , "n_estimators" + ) + )) +} diff --git a/R-package/R/callback.R b/R-package/R/callback.R index 92bd9c035..1b5f4f456 100644 --- a/R-package/R/callback.R +++ b/R-package/R/callback.R @@ -37,7 +37,11 @@ cb.reset.parameters <- function(new_params) { # Some parameters are not allowed to be changed, # since changing them would simply wreck some chaos - not_allowed <- c("num_class", "metric", "boosting_type") + not_allowed <- c( + .PARAMETER_ALIASES()[["num_class"]] + , .PARAMETER_ALIASES()[["metric"]] + , .PARAMETER_ALIASES()[["boosting"]] + ) if (any(pnames %in% not_allowed)) { stop( "Parameters " diff --git a/R-package/R/lgb.cv.R b/R-package/R/lgb.cv.R index 594b323df..c810432b0 100644 --- a/R-package/R/lgb.cv.R +++ b/R-package/R/lgb.cv.R @@ -136,17 +136,7 @@ lgb.cv <- function(params = list(), begin_iteration <- predictor$current_iter() + 1 } # Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one - n_trees <- c( - "num_iterations" - , "num_iteration" - , "n_iter" - , "num_tree" - , "num_trees" - , "num_round" - , "num_rounds" - , "num_boost_round" - , "n_estimators" - ) + n_trees <- .PARAMETER_ALIASES()[["num_iterations"]] if (any(names(params) %in% n_trees)) { end_iteration <- begin_iteration + params[[which(names(params) %in% n_trees)[1]]] - 1 } else { @@ -225,30 +215,52 @@ lgb.cv <- function(params = list(), callbacks <- add.cb(callbacks, cb.record.evaluation()) } - # Check for early stopping passed as parameter when adding early stopping callback - early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping", "n_iter_no_change") - if (any(names(params) %in% early_stop)) { - if (params[[which(names(params) %in% early_stop)[1]]] > 0) { - callbacks <- add.cb( - callbacks - , cb.early.stop( - params[[which(names(params) %in% early_stop)[1]]] - , verbose = verbose - ) - ) - } - } else { - if (!is.null(early_stopping_rounds)) { - if (early_stopping_rounds > 0) { - callbacks <- add.cb( - callbacks - , cb.early.stop( - early_stopping_rounds - , verbose = verbose - ) - ) + # If early stopping was passed as a parameter in params(), prefer that to keyword argument + # early_stopping_rounds by overwriting the value in 'early_stopping_rounds' + early_stop <- .PARAMETER_ALIASES()[["early_stopping_round"]] + early_stop_param_indx <- names(params) %in% early_stop + if (any(early_stop_param_indx)) { + first_early_stop_param <- which(early_stop_param_indx)[[1]] + first_early_stop_param_name <- names(params)[[first_early_stop_param]] + early_stopping_rounds <- params[[first_early_stop_param_name]] + } + + # Did user pass parameters that indicate they want to use early stopping? + using_early_stopping_via_args <- !is.null(early_stopping_rounds) + + boosting_param_names <- .PARAMETER_ALIASES()[["boosting"]] + using_dart <- any( + sapply( + X = boosting_param_names + , FUN = function(param){ + identical(params[[param]], 'dart') } - } + ) + ) + + # Cannot use early stopping with 'dart' boosting + if (using_dart){ + warning("Early stopping is not available in 'dart' mode.") + using_early_stopping_via_args <- FALSE + + # Remove the cb.early.stop() function if it was passed in to callbacks + callbacks <- Filter( + f = function(cb_func){ + !identical(attr(cb_func, "name"), "cb.early.stop") + } + , x = callbacks + ) + } + + # If user supplied early_stopping_rounds, add the early stopping callback + if (using_early_stopping_via_args){ + callbacks <- add.cb( + callbacks + , cb.early.stop( + stopping_rounds = early_stopping_rounds + , verbose = verbose + ) + ) } # Categorize callbacks diff --git a/R-package/R/lgb.train.R b/R-package/R/lgb.train.R index 79c6b4e48..eb4bef407 100644 --- a/R-package/R/lgb.train.R +++ b/R-package/R/lgb.train.R @@ -108,24 +108,13 @@ lgb.train <- function(params = list(), begin_iteration <- predictor$current_iter() + 1 } # Check for number of rounds passed as parameter - in case there are multiple ones, take only the first one - n_rounds <- c( - "num_iterations" - , "num_iteration" - , "n_iter" - , "num_tree" - , "num_trees" - , "num_round" - , "num_rounds" - , "num_boost_round" - , "n_estimators" - ) - if (any(names(params) %in% n_rounds)) { - end_iteration <- begin_iteration + params[[which(names(params) %in% n_rounds)[1]]] - 1 + n_trees <- .PARAMETER_ALIASES()[["num_iterations"]] + if (any(names(params) %in% n_trees)) { + end_iteration <- begin_iteration + params[[which(names(params) %in% n_trees)[1]]] - 1 } else { end_iteration <- begin_iteration + nrounds - 1 } - # Check for training dataset type correctness if (!lgb.is.Dataset(data)) { stop("lgb.train: data only accepts lgb.Dataset object") @@ -207,30 +196,52 @@ lgb.train <- function(params = list(), callbacks <- add.cb(callbacks, cb.record.evaluation()) } - # Check for early stopping passed as parameter when adding early stopping callback - early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping", "n_iter_no_change") - if (any(names(params) %in% early_stop)) { - if (params[[which(names(params) %in% early_stop)[1]]] > 0) { - callbacks <- add.cb( - callbacks - , cb.early.stop( - params[[which(names(params) %in% early_stop)[1]]] - , verbose = verbose - ) - ) - } - } else { - if (!is.null(early_stopping_rounds)) { - if (early_stopping_rounds > 0) { - callbacks <- add.cb( - callbacks - , cb.early.stop( - early_stopping_rounds - , verbose = verbose - ) - ) + # If early stopping was passed as a parameter in params(), prefer that to keyword argument + # early_stopping_rounds by overwriting the value in 'early_stopping_rounds' + early_stop <- .PARAMETER_ALIASES()[["early_stopping_round"]] + early_stop_param_indx <- names(params) %in% early_stop + if (any(early_stop_param_indx)) { + first_early_stop_param <- which(early_stop_param_indx)[[1]] + first_early_stop_param_name <- names(params)[[first_early_stop_param]] + early_stopping_rounds <- params[[first_early_stop_param_name]] + } + + # Did user pass parameters that indicate they want to use early stopping? + using_early_stopping_via_args <- !is.null(early_stopping_rounds) + + boosting_param_names <- .PARAMETER_ALIASES()[["boosting"]] + using_dart <- any( + sapply( + X = boosting_param_names + , FUN = function(param){ + identical(params[[param]], 'dart') } - } + ) + ) + + # Cannot use early stopping with 'dart' boosting + if (using_dart){ + warning("Early stopping is not available in 'dart' mode.") + using_early_stopping_via_args <- FALSE + + # Remove the cb.early.stop() function if it was passed in to callbacks + callbacks <- Filter( + f = function(cb_func){ + !identical(attr(cb_func, "name"), "cb.early.stop") + } + , x = callbacks + ) + } + + # If user supplied early_stopping_rounds, add the early stopping callback + if (using_early_stopping_via_args){ + callbacks <- add.cb( + callbacks + , cb.early.stop( + stopping_rounds = early_stopping_rounds + , verbose = verbose + ) + ) } # "Categorize" callbacks diff --git a/R-package/tests/testthat/test_parameters.R b/R-package/tests/testthat/test_parameters.R index 60a762de2..83aa4ce9f 100644 --- a/R-package/tests/testthat/test_parameters.R +++ b/R-package/tests/testthat/test_parameters.R @@ -43,3 +43,35 @@ test_that("Feature penalties work properly", { # Ensure that feature is not used when feature_penalty = 0 expect_length(var_gain[[length(var_gain)]], 0) }) + +expect_true(".PARAMETER_ALIASES() returns a named list", { + param_aliases <- .PARAMETER_ALIASES() + expect_true(is.list(param_aliases)) + expect_true(is.character(names(param_aliases))) + expect_true(is.character(param_aliases[["boosting"]])) + expect_true(is.character(param_aliases[["early_stopping_round"]])) + expect_true(is.character(param_aliases[["metric"]])) + expect_true(is.character(param_aliases[["num_class"]])) + expect_true(is.character(param_aliases[["num_iterations"]])) +}) + +expect_true("training should warn if you use 'dart' boosting, specified with 'boosting' or aliases", { + for (boosting_param in .PARAMETER_ALIASES()[["boosting"]]){ + expect_warning({ + result <- lightgbm( + data = train$data + , label = train$label + , num_leaves = 5 + , learning_rate = 0.05 + , nrounds = 5 + , objective = "binary" + , metric = "binary_error" + , verbose = -1 + , params = stats::setNames( + object = "dart" + , nm = boosting_param + ) + ) + }, regexp = "Early stopping is not available in 'dart' mode") + } +})