зеркало из https://github.com/microsoft/LightGBM.git
[R-package] Update remaining internal function calls to use keyword arguments (#3617)
* update lgb.convert_with_rules.R to use keyword arguments
* update lgb.cv.R to use keyword arguments
* update lgb.importance.R to use keyword arguments
* update lgb.interprete.R to use keyword arguments
* update lgb.plot.interpretation.R to use keyword arguments
* update more internal function calls to use keyword arguments
* update more internal function calls to use keyword arguments
This commit is contained in:
Родитель
9597326eec
Коммит
ab0d71d699
|
@ -318,7 +318,7 @@ cb.early.stop <- function(stopping_rounds, first_metric_only = FALSE, verbose =
|
|||
|
||||
# Check for empty evaluation
|
||||
if (is.null(eval_len)) {
|
||||
init(env)
|
||||
init(env = env)
|
||||
}
|
||||
|
||||
# Store iteration
|
||||
|
|
|
@ -12,10 +12,10 @@ Booster <- R6::R6Class(
|
|||
finalize = function() {
|
||||
|
||||
# Check the need for freeing handle
|
||||
if (!lgb.is.null.handle(private$handle)) {
|
||||
if (!lgb.is.null.handle(x = private$handle)) {
|
||||
|
||||
# Freeing up handle
|
||||
lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
|
||||
lgb.call(fun_name = "LGBM_BoosterFree_R", ret = NULL, private$handle)
|
||||
private$handle <- NULL
|
||||
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ Booster <- R6::R6Class(
|
|||
params_str <- lgb.params2str(params = params)
|
||||
# Store booster handle
|
||||
handle <- lgb.call(
|
||||
"LGBM_BoosterCreate_R"
|
||||
fun_name = "LGBM_BoosterCreate_R"
|
||||
, ret = handle
|
||||
, train_set_handle
|
||||
, params_str
|
||||
|
@ -64,7 +64,7 @@ Booster <- R6::R6Class(
|
|||
|
||||
# Merge booster
|
||||
lgb.call(
|
||||
"LGBM_BoosterMerge_R"
|
||||
fun_name = "LGBM_BoosterMerge_R"
|
||||
, ret = NULL
|
||||
, handle
|
||||
, private$init_predictor$.__enclos_env__$private$handle
|
||||
|
@ -86,7 +86,7 @@ Booster <- R6::R6Class(
|
|||
handle <- lgb.call(
|
||||
fun_name = "LGBM_BoosterCreateFromModelfile_R"
|
||||
, ret = handle
|
||||
, lgb.c_str(modelfile)
|
||||
, lgb.c_str(x = modelfile)
|
||||
)
|
||||
|
||||
} else if (!is.null(model_str)) {
|
||||
|
@ -100,7 +100,7 @@ Booster <- R6::R6Class(
|
|||
handle <- lgb.call(
|
||||
fun_name = "LGBM_BoosterLoadModelFromString_R"
|
||||
, ret = handle
|
||||
, lgb.c_str(model_str)
|
||||
, lgb.c_str(x = model_str)
|
||||
)
|
||||
|
||||
} else {
|
||||
|
@ -116,7 +116,7 @@ Booster <- R6::R6Class(
|
|||
})
|
||||
|
||||
# Check whether the handle was created properly if it was not stopped earlier by a stop call
|
||||
if (isTRUE(lgb.is.null.handle(handle))) {
|
||||
if (isTRUE(lgb.is.null.handle(x = handle))) {
|
||||
|
||||
stop("lgb.Booster: cannot create Booster handle")
|
||||
|
||||
|
@ -168,7 +168,7 @@ Booster <- R6::R6Class(
|
|||
|
||||
# Add validation data to booster
|
||||
lgb.call(
|
||||
"LGBM_BoosterAddValidData_R"
|
||||
fun_name = "LGBM_BoosterAddValidData_R"
|
||||
, ret = NULL
|
||||
, private$handle
|
||||
, data$.__enclos_env__$private$get_handle()
|
||||
|
@ -193,7 +193,7 @@ Booster <- R6::R6Class(
|
|||
|
||||
# Reset parameters
|
||||
lgb.call(
|
||||
"LGBM_BoosterResetParameter_R"
|
||||
fun_name = "LGBM_BoosterResetParameter_R"
|
||||
, ret = NULL
|
||||
, private$handle
|
||||
, params_str
|
||||
|
@ -227,7 +227,7 @@ Booster <- R6::R6Class(
|
|||
|
||||
# Reset training data on booster
|
||||
lgb.call(
|
||||
"LGBM_BoosterResetTrainingData_R"
|
||||
fun_name = "LGBM_BoosterResetTrainingData_R"
|
||||
, ret = NULL
|
||||
, private$handle
|
||||
, train_set$.__enclos_env__$private$get_handle()
|
||||
|
@ -245,7 +245,11 @@ Booster <- R6::R6Class(
|
|||
stop("lgb.Booster.update: cannot update due to null objective function")
|
||||
}
|
||||
# Boost iteration from known objective
|
||||
ret <- lgb.call("LGBM_BoosterUpdateOneIter_R", ret = NULL, private$handle)
|
||||
ret <- lgb.call(
|
||||
fun_name = "LGBM_BoosterUpdateOneIter_R"
|
||||
, ret = NULL
|
||||
, private$handle
|
||||
)
|
||||
|
||||
} else {
|
||||
|
||||
|
@ -292,7 +296,7 @@ Booster <- R6::R6Class(
|
|||
|
||||
# Return one iteration behind
|
||||
lgb.call(
|
||||
"LGBM_BoosterRollbackOneIter_R"
|
||||
fun_name = "LGBM_BoosterRollbackOneIter_R"
|
||||
, ret = NULL
|
||||
, private$handle
|
||||
)
|
||||
|
@ -438,7 +442,7 @@ Booster <- R6::R6Class(
|
|||
, private$handle
|
||||
, as.integer(num_iteration)
|
||||
, as.integer(feature_importance_type)
|
||||
, lgb.c_str(filename)
|
||||
, lgb.c_str(x = filename)
|
||||
)
|
||||
|
||||
return(invisible(self))
|
||||
|
@ -581,7 +585,7 @@ Booster <- R6::R6Class(
|
|||
|
||||
# Use buffer
|
||||
private$predict_buffer[[data_name]] <- lgb.call(
|
||||
"LGBM_BoosterGetPredict_R"
|
||||
fun_name = "LGBM_BoosterGetPredict_R"
|
||||
, ret = private$predict_buffer[[data_name]]
|
||||
, private$handle
|
||||
, as.integer(idx - 1L)
|
||||
|
@ -765,7 +769,7 @@ predict.lgb.Booster <- function(object,
|
|||
reshape = FALSE,
|
||||
...) {
|
||||
|
||||
if (!lgb.is.Booster(object)) {
|
||||
if (!lgb.is.Booster(x = object)) {
|
||||
stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
|
||||
}
|
||||
|
||||
|
@ -877,7 +881,7 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
|
|||
#' @export
|
||||
lgb.save <- function(booster, filename, num_iteration = NULL) {
|
||||
|
||||
if (!lgb.is.Booster(booster)) {
|
||||
if (!lgb.is.Booster(x = booster)) {
|
||||
stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
|
||||
}
|
||||
|
||||
|
@ -926,7 +930,7 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
|
|||
#' @export
|
||||
lgb.dump <- function(booster, num_iteration = NULL) {
|
||||
|
||||
if (!lgb.is.Booster(booster)) {
|
||||
if (!lgb.is.Booster(x = booster)) {
|
||||
stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
|
||||
}
|
||||
|
||||
|
@ -981,7 +985,7 @@ lgb.dump <- function(booster, num_iteration = NULL) {
|
|||
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {
|
||||
|
||||
# Check if booster is booster
|
||||
if (!lgb.is.Booster(booster)) {
|
||||
if (!lgb.is.Booster(x = booster)) {
|
||||
stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
|
||||
}
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ Dataset <- R6::R6Class(
|
|||
finalize = function() {
|
||||
|
||||
# Check the need for freeing handle
|
||||
if (!lgb.is.null.handle(private$handle)) {
|
||||
if (!lgb.is.null.handle(x = private$handle)) {
|
||||
|
||||
# Freeing up handle
|
||||
lgb.call(fun_name = "LGBM_DatasetFree_R", ret = NULL, private$handle)
|
||||
|
@ -113,7 +113,7 @@ Dataset <- R6::R6Class(
|
|||
construct = function() {
|
||||
|
||||
# Check for handle null
|
||||
if (!lgb.is.null.handle(private$handle)) {
|
||||
if (!lgb.is.null.handle(x = private$handle)) {
|
||||
return(invisible(self))
|
||||
}
|
||||
|
||||
|
@ -196,7 +196,7 @@ Dataset <- R6::R6Class(
|
|||
handle <- lgb.call(
|
||||
fun_name = "LGBM_DatasetCreateFromFile_R"
|
||||
, ret = handle
|
||||
, lgb.c_str(private$raw_data)
|
||||
, lgb.c_str(x = private$raw_data)
|
||||
, params_str
|
||||
, ref_handle
|
||||
)
|
||||
|
@ -260,7 +260,7 @@ Dataset <- R6::R6Class(
|
|||
)
|
||||
|
||||
}
|
||||
if (lgb.is.null.handle(handle)) {
|
||||
if (lgb.is.null.handle(x = handle)) {
|
||||
stop("lgb.Dataset.construct: cannot create Dataset handle")
|
||||
}
|
||||
# Setup class and private type
|
||||
|
@ -269,7 +269,7 @@ Dataset <- R6::R6Class(
|
|||
|
||||
# Set feature names
|
||||
if (!is.null(private$colnames)) {
|
||||
self$set_colnames(private$colnames)
|
||||
self$set_colnames(colnames = private$colnames)
|
||||
}
|
||||
|
||||
# Load init score if requested
|
||||
|
@ -307,7 +307,7 @@ Dataset <- R6::R6Class(
|
|||
}
|
||||
|
||||
# Get label information existence
|
||||
if (is.null(self$getinfo("label"))) {
|
||||
if (is.null(self$getinfo(name = "label"))) {
|
||||
stop("lgb.Dataset.construct: label should be set")
|
||||
}
|
||||
|
||||
|
@ -319,7 +319,7 @@ Dataset <- R6::R6Class(
|
|||
dim = function() {
|
||||
|
||||
# Check for handle
|
||||
if (!lgb.is.null.handle(private$handle)) {
|
||||
if (!lgb.is.null.handle(x = private$handle)) {
|
||||
|
||||
num_row <- 0L
|
||||
num_col <- 0L
|
||||
|
@ -360,7 +360,7 @@ Dataset <- R6::R6Class(
|
|||
get_colnames = function() {
|
||||
|
||||
# Check for handle
|
||||
if (!lgb.is.null.handle(private$handle)) {
|
||||
if (!lgb.is.null.handle(x = private$handle)) {
|
||||
|
||||
# Get feature names and write them
|
||||
cnames <- lgb.call.return.str(
|
||||
|
@ -403,7 +403,7 @@ Dataset <- R6::R6Class(
|
|||
|
||||
# Write column names
|
||||
private$colnames <- colnames
|
||||
if (!lgb.is.null.handle(private$handle)) {
|
||||
if (!lgb.is.null.handle(x = private$handle)) {
|
||||
|
||||
# Merge names with tab separation
|
||||
merged_name <- paste0(as.list(private$colnames), collapse = "\t")
|
||||
|
@ -411,7 +411,7 @@ Dataset <- R6::R6Class(
|
|||
fun_name = "LGBM_DatasetSetFeatureNames_R"
|
||||
, ret = NULL
|
||||
, private$handle
|
||||
, lgb.c_str(merged_name)
|
||||
, lgb.c_str(x = merged_name)
|
||||
)
|
||||
|
||||
}
|
||||
|
@ -434,7 +434,7 @@ Dataset <- R6::R6Class(
|
|||
# Check for info name and handle
|
||||
if (is.null(private$info[[name]])) {
|
||||
|
||||
if (lgb.is.null.handle(private$handle)) {
|
||||
if (lgb.is.null.handle(x = private$handle)) {
|
||||
stop("Cannot perform getinfo before constructing Dataset.")
|
||||
}
|
||||
|
||||
|
@ -444,7 +444,7 @@ Dataset <- R6::R6Class(
|
|||
fun_name = "LGBM_DatasetGetFieldSize_R"
|
||||
, ret = info_len
|
||||
, private$handle
|
||||
, lgb.c_str(name)
|
||||
, lgb.c_str(x = name)
|
||||
)
|
||||
|
||||
# Check if info is not empty
|
||||
|
@ -462,7 +462,7 @@ Dataset <- R6::R6Class(
|
|||
fun_name = "LGBM_DatasetGetField_R"
|
||||
, ret = ret
|
||||
, private$handle
|
||||
, lgb.c_str(name)
|
||||
, lgb.c_str(x = name)
|
||||
)
|
||||
|
||||
private$info[[name]] <- ret
|
||||
|
@ -495,7 +495,7 @@ Dataset <- R6::R6Class(
|
|||
# Store information privately
|
||||
private$info[[name]] <- info
|
||||
|
||||
if (!lgb.is.null.handle(private$handle) && !is.null(info)) {
|
||||
if (!lgb.is.null.handle(x = private$handle) && !is.null(info)) {
|
||||
|
||||
if (length(info) > 0L) {
|
||||
|
||||
|
@ -503,7 +503,7 @@ Dataset <- R6::R6Class(
|
|||
fun_name = "LGBM_DatasetSetField_R"
|
||||
, ret = NULL
|
||||
, private$handle
|
||||
, lgb.c_str(name)
|
||||
, lgb.c_str(x = name)
|
||||
, info
|
||||
, length(info)
|
||||
)
|
||||
|
@ -542,7 +542,7 @@ Dataset <- R6::R6Class(
|
|||
if (length(params) == 0L) {
|
||||
return(invisible(self))
|
||||
}
|
||||
if (lgb.is.null.handle(private$handle)) {
|
||||
if (lgb.is.null.handle(x = private$handle)) {
|
||||
private$params <- modifyList(private$params, params)
|
||||
} else {
|
||||
call_state <- 0L
|
||||
|
@ -608,9 +608,9 @@ Dataset <- R6::R6Class(
|
|||
set_reference = function(reference) {
|
||||
|
||||
# Set known references
|
||||
self$set_categorical_feature(reference$.__enclos_env__$private$categorical_feature)
|
||||
self$set_colnames(reference$get_colnames())
|
||||
private$set_predictor(reference$.__enclos_env__$private$predictor)
|
||||
self$set_categorical_feature(categorical_feature = reference$.__enclos_env__$private$categorical_feature)
|
||||
self$set_colnames(colnames = reference$get_colnames())
|
||||
private$set_predictor(predictor = reference$.__enclos_env__$private$predictor)
|
||||
|
||||
# Check for identical references
|
||||
if (identical(private$reference, reference)) {
|
||||
|
@ -653,7 +653,7 @@ Dataset <- R6::R6Class(
|
|||
fun_name = "LGBM_DatasetSaveBinary_R"
|
||||
, ret = NULL
|
||||
, private$handle
|
||||
, lgb.c_str(fname)
|
||||
, lgb.c_str(x = fname)
|
||||
)
|
||||
return(invisible(self))
|
||||
}
|
||||
|
@ -676,7 +676,7 @@ Dataset <- R6::R6Class(
|
|||
get_handle = function() {
|
||||
|
||||
# Get handle and construct if needed
|
||||
if (lgb.is.null.handle(private$handle)) {
|
||||
if (lgb.is.null.handle(x = private$handle)) {
|
||||
self$construct()
|
||||
}
|
||||
private$handle
|
||||
|
@ -791,7 +791,7 @@ lgb.Dataset <- function(data,
|
|||
lgb.Dataset.create.valid <- function(dataset, data, info = list(), ...) {
|
||||
|
||||
# Check if dataset is not a dataset
|
||||
if (!lgb.is.Dataset(dataset)) {
|
||||
if (!lgb.is.Dataset(x = dataset)) {
|
||||
stop("lgb.Dataset.create.valid: input data should be an lgb.Dataset object")
|
||||
}
|
||||
|
||||
|
@ -817,7 +817,7 @@ lgb.Dataset.create.valid <- function(dataset, data, info = list(), ...) {
|
|||
lgb.Dataset.construct <- function(dataset) {
|
||||
|
||||
# Check if dataset is not a dataset
|
||||
if (!lgb.is.Dataset(dataset)) {
|
||||
if (!lgb.is.Dataset(x = dataset)) {
|
||||
stop("lgb.Dataset.construct: input data should be an lgb.Dataset object")
|
||||
}
|
||||
|
||||
|
@ -852,7 +852,7 @@ lgb.Dataset.construct <- function(dataset) {
|
|||
dim.lgb.Dataset <- function(x, ...) {
|
||||
|
||||
# Check if dataset is not a dataset
|
||||
if (!lgb.is.Dataset(x)) {
|
||||
if (!lgb.is.Dataset(x = x)) {
|
||||
stop("dim.lgb.Dataset: input data should be an lgb.Dataset object")
|
||||
}
|
||||
|
||||
|
@ -888,7 +888,7 @@ dim.lgb.Dataset <- function(x, ...) {
|
|||
dimnames.lgb.Dataset <- function(x) {
|
||||
|
||||
# Check if dataset is not a dataset
|
||||
if (!lgb.is.Dataset(x)) {
|
||||
if (!lgb.is.Dataset(x = x)) {
|
||||
stop("dimnames.lgb.Dataset: input data should be an lgb.Dataset object")
|
||||
}
|
||||
|
||||
|
@ -914,7 +914,7 @@ dimnames.lgb.Dataset <- function(x) {
|
|||
|
||||
if (is.null(value[[2L]])) {
|
||||
|
||||
x$set_colnames(NULL)
|
||||
x$set_colnames(colnames = NULL)
|
||||
return(x)
|
||||
|
||||
}
|
||||
|
@ -931,7 +931,7 @@ dimnames.lgb.Dataset <- function(x) {
|
|||
}
|
||||
|
||||
# Set column names properly, and return
|
||||
x$set_colnames(value[[2L]])
|
||||
x$set_colnames(colnames = value[[2L]])
|
||||
x
|
||||
|
||||
}
|
||||
|
@ -965,7 +965,7 @@ slice <- function(dataset, ...) {
|
|||
slice.lgb.Dataset <- function(dataset, idxset, ...) {
|
||||
|
||||
# Check if dataset is not a dataset
|
||||
if (!lgb.is.Dataset(dataset)) {
|
||||
if (!lgb.is.Dataset(x = dataset)) {
|
||||
stop("slice.lgb.Dataset: input dataset should be an lgb.Dataset object")
|
||||
}
|
||||
|
||||
|
@ -1016,11 +1016,11 @@ getinfo <- function(dataset, ...) {
|
|||
getinfo.lgb.Dataset <- function(dataset, name, ...) {
|
||||
|
||||
# Check if dataset is not a dataset
|
||||
if (!lgb.is.Dataset(dataset)) {
|
||||
if (!lgb.is.Dataset(x = dataset)) {
|
||||
stop("getinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")
|
||||
}
|
||||
|
||||
dataset$getinfo(name)
|
||||
dataset$getinfo(name = name)
|
||||
|
||||
}
|
||||
|
||||
|
@ -1069,7 +1069,7 @@ setinfo <- function(dataset, ...) {
|
|||
#' @export
|
||||
setinfo.lgb.Dataset <- function(dataset, name, info, ...) {
|
||||
|
||||
if (!lgb.is.Dataset(dataset)) {
|
||||
if (!lgb.is.Dataset(x = dataset)) {
|
||||
stop("setinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")
|
||||
}
|
||||
|
||||
|
@ -1101,12 +1101,12 @@ setinfo.lgb.Dataset <- function(dataset, name, info, ...) {
|
|||
#' @export
|
||||
lgb.Dataset.set.categorical <- function(dataset, categorical_feature) {
|
||||
|
||||
if (!lgb.is.Dataset(dataset)) {
|
||||
if (!lgb.is.Dataset(x = dataset)) {
|
||||
stop("lgb.Dataset.set.categorical: input dataset should be an lgb.Dataset object")
|
||||
}
|
||||
|
||||
# Set categoricals
|
||||
invisible(dataset$set_categorical_feature(categorical_feature))
|
||||
invisible(dataset$set_categorical_feature(categorical_feature = categorical_feature))
|
||||
|
||||
}
|
||||
|
||||
|
@ -1133,12 +1133,12 @@ lgb.Dataset.set.categorical <- function(dataset, categorical_feature) {
|
|||
lgb.Dataset.set.reference <- function(dataset, reference) {
|
||||
|
||||
# Check if dataset is not a dataset
|
||||
if (!lgb.is.Dataset(dataset)) {
|
||||
if (!lgb.is.Dataset(x = dataset)) {
|
||||
stop("lgb.Dataset.set.reference: input dataset should be an lgb.Dataset object")
|
||||
}
|
||||
|
||||
# Set reference
|
||||
invisible(dataset$set_reference(reference))
|
||||
invisible(dataset$set_reference(reference = reference))
|
||||
}
|
||||
|
||||
#' @name lgb.Dataset.save
|
||||
|
@ -1161,7 +1161,7 @@ lgb.Dataset.set.reference <- function(dataset, reference) {
|
|||
lgb.Dataset.save <- function(dataset, fname) {
|
||||
|
||||
# Check if dataset is not a dataset
|
||||
if (!lgb.is.Dataset(dataset)) {
|
||||
if (!lgb.is.Dataset(x = dataset)) {
|
||||
stop("lgb.Dataset.set: input dataset should be an lgb.Dataset object")
|
||||
}
|
||||
|
||||
|
@ -1171,5 +1171,5 @@ lgb.Dataset.save <- function(dataset, fname) {
|
|||
}
|
||||
|
||||
# Store binary
|
||||
invisible(dataset$save_binary(fname))
|
||||
invisible(dataset$save_binary(fname = fname))
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ Predictor <- R6::R6Class(
|
|||
finalize = function() {
|
||||
|
||||
# Check the need for freeing handle
|
||||
if (private$need_free_handle && !lgb.is.null.handle(private$handle)) {
|
||||
if (private$need_free_handle && !lgb.is.null.handle(x = private$handle)) {
|
||||
|
||||
# Freeing up handle
|
||||
lgb.call(
|
||||
|
@ -28,7 +28,7 @@ Predictor <- R6::R6Class(
|
|||
# Initialize will create a starter model
|
||||
initialize = function(modelfile, ...) {
|
||||
params <- list(...)
|
||||
private$params <- lgb.params2str(params)
|
||||
private$params <- lgb.params2str(params = params)
|
||||
# Create new lgb handle
|
||||
handle <- lgb.null.handle()
|
||||
|
||||
|
@ -39,7 +39,7 @@ Predictor <- R6::R6Class(
|
|||
handle <- lgb.call(
|
||||
fun_name = "LGBM_BoosterCreateFromModelfile_R"
|
||||
, ret = handle
|
||||
, lgb.c_str(modelfile)
|
||||
, lgb.c_str(x = modelfile)
|
||||
)
|
||||
private$need_free_handle <- TRUE
|
||||
|
||||
|
@ -114,7 +114,7 @@ Predictor <- R6::R6Class(
|
|||
, as.integer(start_iteration)
|
||||
, as.integer(num_iteration)
|
||||
, private$params
|
||||
, lgb.c_str(tmp_filename)
|
||||
, lgb.c_str(x = tmp_filename)
|
||||
)
|
||||
|
||||
# Get predictions from file
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# and lgb.convert_with_rules() to warn if more action is needed by users
|
||||
# before a dataset can be converted to a lgb.Dataset.
|
||||
.warn_for_unconverted_columns <- function(df, function_name) {
|
||||
column_classes <- .get_column_classes(df)
|
||||
column_classes <- .get_column_classes(df = df)
|
||||
unconverted_columns <- column_classes[!(column_classes %in% c("numeric", "integer"))]
|
||||
if (length(unconverted_columns) > 0L) {
|
||||
col_detail_string <- paste0(
|
||||
|
@ -109,13 +109,13 @@
|
|||
#' @export
|
||||
lgb.convert_with_rules <- function(data, rules = NULL) {
|
||||
|
||||
column_classes <- .get_column_classes(data)
|
||||
column_classes <- .get_column_classes(df = data)
|
||||
|
||||
is_char <- which(column_classes == "character")
|
||||
is_factor <- which(column_classes == "factor")
|
||||
is_logical <- which(column_classes == "logical")
|
||||
|
||||
is_data_table <- data.table::is.data.table(data)
|
||||
is_data_table <- data.table::is.data.table(x = data)
|
||||
is_data_frame <- is.data.frame(data)
|
||||
|
||||
if (!(is_data_table || is_data_frame)) {
|
||||
|
@ -166,7 +166,7 @@ lgb.convert_with_rules <- function(data, rules = NULL) {
|
|||
}
|
||||
if (is_data_table) {
|
||||
data.table::set(
|
||||
data
|
||||
x = data
|
||||
, j = col_name
|
||||
, value = unname(rules[[col_name]][data[[col_name]]])
|
||||
)
|
||||
|
|
|
@ -96,18 +96,18 @@ lgb.cv <- function(params = list()
|
|||
}
|
||||
|
||||
# If 'data' is not an lgb.Dataset, try to construct one using 'label'
|
||||
if (!lgb.is.Dataset(data)) {
|
||||
if (!lgb.is.Dataset(x = data)) {
|
||||
if (is.null(label)) {
|
||||
stop("'label' must be provided for lgb.cv if 'data' is not an 'lgb.Dataset'")
|
||||
}
|
||||
data <- lgb.Dataset(data, label = label)
|
||||
data <- lgb.Dataset(data = data, label = label)
|
||||
}
|
||||
|
||||
# Setup temporary variables
|
||||
params <- append(params, list(...))
|
||||
params$verbose <- verbose
|
||||
params <- lgb.check.obj(params, obj)
|
||||
params <- lgb.check.eval(params, eval)
|
||||
params <- lgb.check.obj(params = params, obj = obj)
|
||||
params <- lgb.check.eval(params = params, eval = eval)
|
||||
fobj <- NULL
|
||||
eval_functions <- list(NULL)
|
||||
|
||||
|
@ -152,8 +152,8 @@ lgb.cv <- function(params = list()
|
|||
|
||||
# Check for boosting from a trained model
|
||||
if (is.character(init_model)) {
|
||||
predictor <- Predictor$new(init_model)
|
||||
} else if (lgb.is.Booster(init_model)) {
|
||||
predictor <- Predictor$new(modelfile = init_model)
|
||||
} else if (lgb.is.Booster(x = init_model)) {
|
||||
predictor <- init_model$to_predictor()
|
||||
}
|
||||
|
||||
|
@ -175,27 +175,27 @@ lgb.cv <- function(params = list()
|
|||
} else if (!is.null(data$get_colnames())) {
|
||||
cnames <- data$get_colnames()
|
||||
}
|
||||
params[["interaction_constraints"]] <- lgb.check_interaction_constraints(params, cnames)
|
||||
params[["interaction_constraints"]] <- lgb.check_interaction_constraints(params = params, column_names = cnames)
|
||||
|
||||
# Check for weights
|
||||
if (!is.null(weight)) {
|
||||
data$setinfo("weight", weight)
|
||||
data$setinfo(name = "weight", info = weight)
|
||||
}
|
||||
|
||||
# Update parameters with parsed parameters
|
||||
data$update_params(params)
|
||||
data$update_params(params = params)
|
||||
|
||||
# Create the predictor set
|
||||
data$.__enclos_env__$private$set_predictor(predictor)
|
||||
data$.__enclos_env__$private$set_predictor(predictor = predictor)
|
||||
|
||||
# Write column names
|
||||
if (!is.null(colnames)) {
|
||||
data$set_colnames(colnames)
|
||||
data$set_colnames(colnames = colnames)
|
||||
}
|
||||
|
||||
# Write categorical features
|
||||
if (!is.null(categorical_feature)) {
|
||||
data$set_categorical_feature(categorical_feature)
|
||||
data$set_categorical_feature(categorical_feature = categorical_feature)
|
||||
}
|
||||
|
||||
# Check for folds
|
||||
|
@ -221,8 +221,8 @@ lgb.cv <- function(params = list()
|
|||
nfold = nfold
|
||||
, nrows = nrow(data)
|
||||
, stratified = stratified
|
||||
, label = getinfo(data, "label")
|
||||
, group = getinfo(data, "group")
|
||||
, label = getinfo(dataset = data, name = "label")
|
||||
, group = getinfo(dataset = data, name = "group")
|
||||
, params = params
|
||||
)
|
||||
|
||||
|
@ -230,12 +230,12 @@ lgb.cv <- function(params = list()
|
|||
|
||||
# Add printing log callback
|
||||
if (verbose > 0L && eval_freq > 0L) {
|
||||
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
|
||||
callbacks <- add.cb(cb_list = callbacks, cb = cb.print.evaluation(period = eval_freq))
|
||||
}
|
||||
|
||||
# Add evaluation log callback
|
||||
if (record) {
|
||||
callbacks <- add.cb(callbacks, cb.record.evaluation())
|
||||
callbacks <- add.cb(cb_list = callbacks, cb = cb.record.evaluation())
|
||||
}
|
||||
|
||||
# Did user pass parameters that indicate they want to use early stopping?
|
||||
|
@ -268,8 +268,8 @@ lgb.cv <- function(params = list()
|
|||
# If user supplied early_stopping_rounds, add the early stopping callback
|
||||
if (using_early_stopping) {
|
||||
callbacks <- add.cb(
|
||||
callbacks
|
||||
, cb.early.stop(
|
||||
cb_list = callbacks
|
||||
, cb = cb.early.stop(
|
||||
stopping_rounds = early_stopping_rounds
|
||||
, first_metric_only = isTRUE(params[["first_metric_only"]])
|
||||
, verbose = verbose
|
||||
|
@ -277,7 +277,7 @@ lgb.cv <- function(params = list()
|
|||
)
|
||||
}
|
||||
|
||||
cb <- categorize.callbacks(callbacks)
|
||||
cb <- categorize.callbacks(cb_list = callbacks)
|
||||
|
||||
# Construct booster for each fold. The data.table() code below is used to
|
||||
# guarantee that indices are sorted while keeping init_score and weight together
|
||||
|
@ -296,8 +296,8 @@ lgb.cv <- function(params = list()
|
|||
if (folds_have_group) {
|
||||
test_indices <- folds[[k]]$fold
|
||||
test_group_indices <- folds[[k]]$group
|
||||
test_groups <- getinfo(data, "group")[test_group_indices]
|
||||
train_groups <- getinfo(data, "group")[-test_group_indices]
|
||||
test_groups <- getinfo(dataset = data, name = "group")[test_group_indices]
|
||||
train_groups <- getinfo(dataset = data, name = "group")[-test_group_indices]
|
||||
} else {
|
||||
test_indices <- folds[[k]]
|
||||
}
|
||||
|
@ -306,32 +306,32 @@ lgb.cv <- function(params = list()
|
|||
# set up test set
|
||||
indexDT <- data.table::data.table(
|
||||
indices = test_indices
|
||||
, weight = getinfo(data, "weight")[test_indices]
|
||||
, init_score = getinfo(data, "init_score")[test_indices]
|
||||
, weight = getinfo(dataset = data, name = "weight")[test_indices]
|
||||
, init_score = getinfo(dataset = data, name = "init_score")[test_indices]
|
||||
)
|
||||
data.table::setorderv(indexDT, "indices", order = 1L)
|
||||
data.table::setorderv(x = indexDT, cols = "indices", order = 1L)
|
||||
dtest <- slice(data, indexDT$indices)
|
||||
setinfo(dtest, "weight", indexDT$weight)
|
||||
setinfo(dtest, "init_score", indexDT$init_score)
|
||||
setinfo(dataset = dtest, name = "weight", info = indexDT$weight)
|
||||
setinfo(dataset = dtest, name = "init_score", info = indexDT$init_score)
|
||||
|
||||
# set up training set
|
||||
indexDT <- data.table::data.table(
|
||||
indices = train_indices
|
||||
, weight = getinfo(data, "weight")[train_indices]
|
||||
, init_score = getinfo(data, "init_score")[train_indices]
|
||||
, weight = getinfo(data = data, name = "weight")[train_indices]
|
||||
, init_score = getinfo(data = data, name = "init_score")[train_indices]
|
||||
)
|
||||
data.table::setorderv(indexDT, "indices", order = 1L)
|
||||
data.table::setorderv(x = indexDT, cols = "indices", order = 1L)
|
||||
dtrain <- slice(data, indexDT$indices)
|
||||
setinfo(dtrain, "weight", indexDT$weight)
|
||||
setinfo(dtrain, "init_score", indexDT$init_score)
|
||||
setinfo(dataset = dtrain, name = "weight", info = indexDT$weight)
|
||||
setinfo(dataset = dtrain, name = "init_score", info = indexDT$init_score)
|
||||
|
||||
if (folds_have_group) {
|
||||
setinfo(dtest, "group", test_groups)
|
||||
setinfo(dtrain, "group", train_groups)
|
||||
setinfo(dataset = dtest, name = "group", info = test_groups)
|
||||
setinfo(dataset = dtrain, name = "group", info = train_groups)
|
||||
}
|
||||
|
||||
booster <- Booster$new(params, dtrain)
|
||||
booster$add_valid(dtest, "valid")
|
||||
booster <- Booster$new(params = params, train_set = dtrain)
|
||||
booster$add_valid(data = dtest, name = "valid")
|
||||
return(
|
||||
list(booster = booster)
|
||||
)
|
||||
|
@ -339,7 +339,7 @@ lgb.cv <- function(params = list()
|
|||
)
|
||||
|
||||
# Create new booster
|
||||
cv_booster <- CVBooster$new(bst_folds)
|
||||
cv_booster <- CVBooster$new(x = bst_folds)
|
||||
|
||||
# Callback env
|
||||
env <- CB_ENV$new()
|
||||
|
@ -369,7 +369,7 @@ lgb.cv <- function(params = list()
|
|||
})
|
||||
|
||||
# Prepare collection of evaluation results
|
||||
merged_msg <- lgb.merge.cv.result(msg)
|
||||
merged_msg <- lgb.merge.cv.result(msg = msg)
|
||||
|
||||
# Write evaluation result in environment
|
||||
env$eval_list <- merged_msg$eval_list
|
||||
|
@ -447,7 +447,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
|
|||
|
||||
y <- label[rnd_idx]
|
||||
y <- as.factor(y)
|
||||
folds <- lgb.stratified.folds(y, nfold)
|
||||
folds <- lgb.stratified.folds(y = y, k = nfold)
|
||||
|
||||
} else {
|
||||
|
||||
|
|
|
@ -39,12 +39,12 @@
|
|||
lgb.importance <- function(model, percentage = TRUE) {
|
||||
|
||||
# Check if model is a lightgbm model
|
||||
if (!lgb.is.Booster(model)) {
|
||||
if (!lgb.is.Booster(x = model)) {
|
||||
stop("'model' has to be an object of class lgb.Booster")
|
||||
}
|
||||
|
||||
# Setup importance
|
||||
tree_dt <- lgb.model.dt.tree(model)
|
||||
tree_dt <- lgb.model.dt.tree(model = model)
|
||||
|
||||
# Extract elements
|
||||
tree_imp_dt <- tree_dt[
|
||||
|
@ -54,7 +54,7 @@ lgb.importance <- function(model, percentage = TRUE) {
|
|||
]
|
||||
|
||||
data.table::setnames(
|
||||
tree_imp_dt
|
||||
x = tree_imp_dt
|
||||
, old = "split_feature"
|
||||
, new = "Feature"
|
||||
)
|
||||
|
|
|
@ -48,7 +48,7 @@ lgb.interprete <- function(model,
|
|||
num_iteration = NULL) {
|
||||
|
||||
# Get tree model
|
||||
tree_dt <- lgb.model.dt.tree(model, num_iteration)
|
||||
tree_dt <- lgb.model.dt.tree(model = model, num_iteration = num_iteration)
|
||||
|
||||
# Check number of classes
|
||||
num_class <- model$.__enclos_env__$private$num_class
|
||||
|
@ -59,12 +59,12 @@ lgb.interprete <- function(model,
|
|||
# Get parsed predictions of data
|
||||
pred_mat <- t(
|
||||
model$predict(
|
||||
data[idxset, , drop = FALSE]
|
||||
data = data[idxset, , drop = FALSE]
|
||||
, num_iteration = num_iteration
|
||||
, predleaf = TRUE
|
||||
)
|
||||
)
|
||||
leaf_index_dt <- data.table::as.data.table(pred_mat)
|
||||
leaf_index_dt <- data.table::as.data.table(x = pred_mat)
|
||||
leaf_index_mat_list <- lapply(
|
||||
X = leaf_index_dt
|
||||
, FUN = function(x) matrix(x, ncol = num_class, byrow = TRUE)
|
||||
|
@ -81,10 +81,10 @@ lgb.interprete <- function(model,
|
|||
# Sequence over idxset
|
||||
for (i in seq_along(idxset)) {
|
||||
tree_interpretation_dt_list[[i]] <- single.row.interprete(
|
||||
tree_dt
|
||||
, num_class
|
||||
, tree_index_mat_list[[i]]
|
||||
, leaf_index_mat_list[[i]]
|
||||
tree_dt = tree_dt
|
||||
, num_class = num_class
|
||||
, tree_index_mat = tree_index_mat_list[[i]]
|
||||
, leaf_index_mat = leaf_index_mat_list[[i]]
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -122,14 +122,20 @@ single.tree.interprete <- function(tree_dt,
|
|||
# Not null means existing node
|
||||
this_node <- node_dt[split_index == parent_id, ]
|
||||
feature_seq <<- c(this_node[["split_feature"]], feature_seq)
|
||||
leaf_to_root(this_node[["node_parent"]], this_node[["internal_value"]])
|
||||
leaf_to_root(
|
||||
parent_id = this_node[["node_parent"]]
|
||||
, current_value = this_node[["internal_value"]]
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
# Perform leaf to root conversion
|
||||
leaf_to_root(leaf_dt[["leaf_parent"]], leaf_dt[["leaf_value"]])
|
||||
leaf_to_root(
|
||||
parent_id = leaf_dt[["leaf_parent"]]
|
||||
, current_value = leaf_dt[["leaf_value"]]
|
||||
)
|
||||
|
||||
data.table::data.table(
|
||||
Feature = feature_seq
|
||||
|
@ -191,7 +197,7 @@ single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index
|
|||
|
||||
if (num_class > 1L) {
|
||||
data.table::setnames(
|
||||
next_interp_dt
|
||||
x = next_interp_dt
|
||||
, old = "Contribution"
|
||||
, new = paste("Class", i - 1L)
|
||||
)
|
||||
|
@ -221,7 +227,7 @@ single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index
|
|||
for (j in 2L:ncol(tree_interpretation_dt)) {
|
||||
|
||||
data.table::set(
|
||||
tree_interpretation_dt
|
||||
x = tree_interpretation_dt
|
||||
, i = which(is.na(tree_interpretation_dt[[j]]))
|
||||
, j = j
|
||||
, value = 0.0
|
||||
|
|
|
@ -87,7 +87,7 @@ lgb.plot.interpretation <- function(tree_interpretation_dt,
|
|||
|
||||
# Only one class, plot straight away
|
||||
multiple.tree.plot.interpretation(
|
||||
tree_interpretation_dt
|
||||
tree_interpretation = tree_interpretation_dt
|
||||
, top_n = top_n
|
||||
, title = NULL
|
||||
, cex = cex
|
||||
|
@ -111,12 +111,12 @@ lgb.plot.interpretation <- function(tree_interpretation_dt,
|
|||
# Prepare interpretation, perform T, get the names, and plot straight away
|
||||
plot_dt <- tree_interpretation_dt[, c(1L, i + 1L), with = FALSE]
|
||||
data.table::setnames(
|
||||
plot_dt
|
||||
x = plot_dt
|
||||
, old = names(plot_dt)
|
||||
, new = c("Feature", "Contribution")
|
||||
)
|
||||
multiple.tree.plot.interpretation(
|
||||
plot_dt
|
||||
tree_interpretation = plot_dt
|
||||
, top_n = top_n
|
||||
, title = paste("Class", i - 1L)
|
||||
, cex = cex
|
||||
|
|
|
@ -67,7 +67,7 @@ lgb.train <- function(params = list(),
|
|||
if (nrounds <= 0L) {
|
||||
stop("nrounds should be greater than zero")
|
||||
}
|
||||
if (!lgb.is.Dataset(data)) {
|
||||
if (!lgb.is.Dataset(x = data)) {
|
||||
stop("lgb.train: data must be an lgb.Dataset instance")
|
||||
}
|
||||
if (length(valids) > 0L) {
|
||||
|
@ -131,7 +131,7 @@ lgb.train <- function(params = list(),
|
|||
# Check for boosting from a trained model
|
||||
if (is.character(init_model)) {
|
||||
predictor <- Predictor$new(init_model)
|
||||
} else if (lgb.is.Booster(init_model)) {
|
||||
} else if (lgb.is.Booster(x = init_model)) {
|
||||
predictor <- init_model$to_predictor()
|
||||
}
|
||||
|
||||
|
@ -153,7 +153,10 @@ lgb.train <- function(params = list(),
|
|||
} else if (!is.null(data$get_colnames())) {
|
||||
cnames <- data$get_colnames()
|
||||
}
|
||||
params[["interaction_constraints"]] <- lgb.check_interaction_constraints(params, cnames)
|
||||
params[["interaction_constraints"]] <- lgb.check_interaction_constraints(
|
||||
params = params
|
||||
, column_names = cnames
|
||||
)
|
||||
|
||||
# Update parameters with parsed parameters
|
||||
data$update_params(params)
|
||||
|
@ -202,12 +205,12 @@ lgb.train <- function(params = list(),
|
|||
|
||||
# Add printing log callback
|
||||
if (verbose > 0L && eval_freq > 0L) {
|
||||
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
|
||||
callbacks <- add.cb(cb_list = callbacks, cb = cb.print.evaluation(period = eval_freq))
|
||||
}
|
||||
|
||||
# Add evaluation log callback
|
||||
if (record && length(valids) > 0L) {
|
||||
callbacks <- add.cb(callbacks, cb.record.evaluation())
|
||||
callbacks <- add.cb(cb_list = callbacks, cb = cb.record.evaluation())
|
||||
}
|
||||
|
||||
# Did user pass parameters that indicate they want to use early stopping?
|
||||
|
|
|
@ -121,7 +121,7 @@ lightgbm <- function(data,
|
|||
dtrain <- data
|
||||
|
||||
# Check whether data is lgb.Dataset, if not then create lgb.Dataset manually
|
||||
if (!lgb.is.Dataset(dtrain)) {
|
||||
if (!lgb.is.Dataset(x = dtrain)) {
|
||||
dtrain <- lgb.Dataset(data = data, label = label, weight = weight)
|
||||
}
|
||||
|
||||
|
|
|
@ -95,13 +95,13 @@ lgb.call.return.str <- function(fun_name, ...) {
|
|||
buf <- raw(buf_len)
|
||||
|
||||
# Call buffer
|
||||
buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len)
|
||||
buf <- lgb.call(fun_name = fun_name, ret = buf, ..., buf_len, act_len)
|
||||
|
||||
# Check for buffer content
|
||||
if (act_len > buf_len) {
|
||||
buf_len <- act_len
|
||||
buf <- raw(buf_len)
|
||||
buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len)
|
||||
buf <- lgb.call(fun_name = fun_name, ret = buf, ..., buf_len, act_len)
|
||||
}
|
||||
|
||||
return(lgb.encode.char(arr = buf, len = act_len))
|
||||
|
|
Загрузка…
Ссылка в новой задаче