зеркало из https://github.com/microsoft/LightGBM.git
R package (#168)
* finish R's c_api * clean code * fix sizeof pointer in 32bit system. * add predictor class * add Dataset class * format code * add booster * add type check for expose function * add a simple callback * add all callbacks * finish the basic training logic * update docs * add an simple training interface * add basic test * adapt the changes in c_api * add test for Dataset * add test for custom obj/eval functions * fix python test * fix bug in metadata init * fix R CMD check
This commit is contained in:
Родитель
acbd4f34d4
Коммит
551d59ca71
|
@ -0,0 +1,37 @@
|
|||
Package: lightgbm
|
||||
Type: Package
|
||||
Title: Light Gradient Boosting Machine
|
||||
Version: 0.1
|
||||
Date: 2016-12-29
|
||||
Author: Guolin Ke <guolin.ke@microsoft.com>
|
||||
Maintainer: Guolin Ke <guolin.ke@microsoft.com>
|
||||
Description: LightGBM is a gradient boosting framework that uses tree based learning algorithms.
|
||||
It is designed to be distributed and efficient with the following advantages:
|
||||
1. Faster training speed and higher efficiency.
|
||||
2. Lower memory usage.
|
||||
3. Better accuracy.
|
||||
4. Parallel learning supported.
|
||||
5. Capable of handling large-scale data.
|
||||
License: The MIT License (MIT) | file LICENSE
|
||||
URL: https://github.com/Microsoft/LightGBM
|
||||
BugReports: https://github.com/Microsoft/LightGBM/issues
|
||||
VignetteBuilder: knitr
|
||||
Suggests:
|
||||
knitr,
|
||||
rmarkdown,
|
||||
ggplot2 (>= 1.0.1),
|
||||
DiagrammeR (>= 0.8.1),
|
||||
Ckmeans.1d.dp (>= 3.3.1),
|
||||
vcd (>= 1.3),
|
||||
testthat,
|
||||
igraph (>= 1.0.1),
|
||||
methods,
|
||||
data.table (>= 1.9.6),
|
||||
magrittr (>= 1.5),
|
||||
stringi (>= 0.5.2)
|
||||
Depends:
|
||||
R (>= 3.0),
|
||||
R6
|
||||
Imports:
|
||||
Matrix (>= 1.1-0)
|
||||
RoxygenNote: 5.0.1
|
|
@ -0,0 +1,22 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) Microsoft Corporation
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
# Generated by roxygen2: do not edit by hand
|
||||
|
||||
S3method("dimnames<-",lgb.Dataset)
|
||||
S3method(dim,lgb.Dataset)
|
||||
S3method(dimnames,lgb.Dataset)
|
||||
S3method(getinfo,lgb.Dataset)
|
||||
S3method(predict,lgb.Booster)
|
||||
S3method(setinfo,lgb.Dataset)
|
||||
S3method(slice,lgb.Dataset)
|
||||
export(getinfo)
|
||||
export(lgb.Dataset)
|
||||
export(lgb.Dataset.construct)
|
||||
export(lgb.Dataset.create.valid)
|
||||
export(lgb.Dataset.save)
|
||||
export(lgb.Dataset.set.categorical)
|
||||
export(lgb.Dataset.set.reference)
|
||||
export(lgb.dump)
|
||||
export(lgb.get.eval.result)
|
||||
export(lgb.load)
|
||||
export(lgb.save)
|
||||
export(lgb.train)
|
||||
export(lightgbm)
|
||||
export(setinfo)
|
||||
export(slice)
|
||||
importFrom(R6,R6Class)
|
||||
useDynLib(lightgbm)
|
|
@ -0,0 +1,249 @@
|
|||
# Shared mutable state handed to every training callback on each round.
CB_ENV <- R6Class(
  "lgb.cb_env",
  cloneable = FALSE,
  public = list(
    model           = NULL,    # the booster being trained
    iteration       = NULL,    # current boosting round
    begin_iteration = NULL,    # first round of this training run
    end_iteration   = NULL,    # last round of this training run
    eval_list       = list(),  # evaluation results for the current round
    eval_err_list   = list(),  # matching evaluation errors (may stay empty)
    best_iter       = -1,      # best round reported by early stopping
    met_early_stop  = FALSE    # set to TRUE once early stopping triggers
  )
)
|
||||
|
||||
# Build a pre-iteration callback that resets booster parameters each round.
#
# new_params: named list; each element is either a vector with one value per
#   boosting round, or a function(iteration, nrounds) returning the value.
# Returns: a callback function carrying 'name'/'is_pre_iteration' attributes.
cb.reset.parameters <- function(new_params) {
  if (typeof(new_params) != "list")
    stop("'new_params' must be a list")
  pnames <- gsub("\\.", "_", names(new_params))
  nrounds <- NULL

  # run some checks in the beginning
  init <- function(env) {
    nrounds <<- env$end_iteration - env$begin_iteration + 1

    if (is.null(env$model))
      stop("Env should has 'model'")

    # Some parameters are not allowed to be changed,
    # since changing them would simply wreck some chaos
    not_allowed <- pnames %in%
      c('num_class', 'metric', 'boosting_type')
    if (any(not_allowed))
      # FIX: collapse so multiple offending names print readably
      stop('Parameters ', paste(pnames[not_allowed], collapse = ", "),
           " cannot be changed during boosting.")

    for (n in pnames) {
      p <- new_params[[n]]
      if (is.function(p)) {
        if (length(formals(p)) != 2)
          stop("Parameter '", n, "' is a function but not of two arguments")
      } else if (is.numeric(p) || is.character(p)) {
        if (length(p) != nrounds)
          stop("Length of '", n, "' has to be equal to 'nrounds'")
      } else {
        stop("Parameter '", n, "' is not a function or a vector")
      }
    }
  }

  callback <- function(env) {
    if (is.null(nrounds))
      init(env)

    # FIX: 1-based round index; the previous 0-based value made p[i]
    # empty on the very first iteration
    i <- env$iteration - env$begin_iteration + 1
    pars <- lapply(new_params, function(p) {
      if (is.function(p))
        return(p(i, nrounds))
      p[i]
    })
    # to-do check pars
    if (!is.null(env$model)) {
      env$model$reset_parameter(pars)
    }
  }
  attr(callback, 'call') <- match.call()
  attr(callback, 'is_pre_iteration') <- TRUE
  attr(callback, 'name') <- 'cb.reset.parameters'
  return(callback)
}
|
||||
|
||||
# Render one evaluation result as "dataname's metric:value[+err]".
format.eval.string <- function(eval_res, eval_err = NULL) {
  if (is.null(eval_res) || length(eval_res) == 0)
    stop('no evaluation results')
  if (is.null(eval_err)) {
    return(sprintf("%s's %s:%g", eval_res$data_name, eval_res$name, eval_res$value))
  }
  sprintf("%s's %s:%g+%g", eval_res$data_name, eval_res$name, eval_res$value, eval_err)
}
|
||||
|
||||
# Join "[iter]:" and every formatted evaluation result, tab separated.
# Returns "" when there is nothing to report.
merge.eval.string <- function(env) {
  if (length(env$eval_list) <= 0) {
    return("")
  }
  has_err <- length(env$eval_err_list) > 0
  parts <- list(sprintf('[%d]:', env$iteration))
  for (idx in seq_along(env$eval_list)) {
    err <- if (has_err) env$eval_err_list[[idx]] else NULL
    parts <- c(parts, format.eval.string(env$eval_list[[idx]], err))
  }
  paste0(parts, collapse = '\t')
}
|
||||
|
||||
# Build a post-iteration callback printing the merged evaluation string
# every `period` rounds (and always on the first and the last round).
# period <= 0 disables printing entirely.
cb.print.evaluation <- function(period = 1) {
  callback <- function(env) {
    if (period > 0) {
      iter <- env$iteration
      due <- ((iter - 1) %% period == 0) |
        (iter == env$begin_iteration) |
        (iter == env$end_iteration)
      if (due) {
        cat(merge.eval.string(env), "\n")
      }
    }
  }
  attr(callback, 'call') <- match.call()
  attr(callback, 'name') <- 'cb.print.evaluation'
  return(callback)
}
|
||||
|
||||
# Build a post-iteration callback that appends every round's evaluation
# results (and errors, when present) into env$model$record_evals, keyed
# by data name then metric name.
cb.record.evaluation <- function() {
  callback <- function(env) {
    if (length(env$eval_list) <= 0) return()
    has_err <- length(env$eval_err_list) > 0
    # Lazily create the nested record structure on the first invocation.
    if (length(env$model$record_evals) == 0) {
      for (idx in seq_along(env$eval_list)) {
        dname <- env$eval_list[[idx]]$data_name
        mname <- env$eval_list[[idx]]$name
        env$model$record_evals$start_iter <- env$begin_iteration
        if (is.null(env$model$record_evals[[dname]])) {
          env$model$record_evals[[dname]] <- list()
        }
        env$model$record_evals[[dname]][[mname]] <-
          list(eval = list(), eval_err = list())
      }
    }
    # Append this round's values; c(x, NULL) leaves x untouched, so a
    # missing error is simply skipped.
    for (idx in seq_along(env$eval_list)) {
      res <- env$eval_list[[idx]]
      err <- if (has_err) env$eval_err_list[[idx]] else NULL
      dname <- res$data_name
      mname <- res$name
      env$model$record_evals[[dname]][[mname]]$eval <-
        c(env$model$record_evals[[dname]][[mname]]$eval, res$value)
      env$model$record_evals[[dname]][[mname]]$eval_err <-
        c(env$model$record_evals[[dname]][[mname]]$eval_err, err)
    }
  }
  attr(callback, 'call') <- match.call()
  attr(callback, 'name') <- 'cb.record.evaluation'
  return(callback)
}
|
||||
|
||||
# Build a post-iteration callback implementing early stopping: training
# is flagged to stop once no tracked metric improved for `stopping_rounds`
# consecutive rounds.
cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
  # lazily-initialized per-metric state
  factor_to_bigger_better <- NULL  # +1/-1 so "bigger is better" always holds
  best_iter <- NULL
  best_score <- NULL
  best_msg <- NULL
  eval_len <- NULL

  init <- function(env) {
    eval_len <<- length(env$eval_list)
    if (eval_len == 0)
      stop("For early stopping, valids must have at least one element")

    if (verbose)
      cat("Will train until hasn't improved in ",
          stopping_rounds, " rounds.\n\n", sep = '')

    factor_to_bigger_better <<- rep(1.0, eval_len)
    best_iter <<- rep(-1, eval_len)
    best_score <<- rep(-Inf, eval_len)
    best_msg <<- list()
    for (idx in 1:eval_len) {
      best_msg <<- c(best_msg, "")
      if (!env$eval_list[[idx]]$higher_better) {
        factor_to_bigger_better[idx] <<- -1.0
      }
    }
  }

  callback <- function(env, finalize = FALSE) {
    if (is.null(eval_len))
      init(env)
    cur_iter <- env$iteration
    for (idx in 1:eval_len) {
      # normalized so a larger score is always an improvement
      score <- env$eval_list[[idx]]$value * factor_to_bigger_better[idx]
      if (score > best_score[idx]) {
        best_score[idx] <<- score
        best_iter[idx] <<- cur_iter
        if (verbose) {
          best_msg[[idx]] <<- as.character(merge.eval.string(env))
        }
      } else if (cur_iter - best_iter[idx] >= stopping_rounds) {
        if (!is.null(env$model)) {
          env$model$best_iter <- best_iter[idx]
        }
        if (verbose) {
          cat('Early stopping, best iteration is:', "\n")
          cat(best_msg[[idx]], "\n")
        }
        env$best_iter <- best_iter[idx]
        env$met_early_stop <- TRUE
      }
    }
  }
  attr(callback, 'call') <- match.call()
  attr(callback, 'name') <- 'cb.early.stop'
  return(callback)
}
|
||||
|
||||
# Collect the 'name' attribute from every callback in the list.
callback.names <- function(cb_list) {
  unlist(lapply(cb_list, attr, which = 'name'))
}
|
||||
|
||||
# Append a callback, then move cb.early.stop and cb.cv.predict (in that
# relative order) to the end so they always run last.
add.cb <- function(cb_list, cb) {
  cb_list <- c(cb_list, cb)
  names(cb_list) <- callback.names(cb_list)
  for (tail_name in c('cb.early.stop', 'cb.cv.predict')) {
    if (tail_name %in% names(cb_list)) {
      cb_list <- c(cb_list, cb_list[tail_name])
      # assigning NULL by name drops only the first occurrence
      cb_list[tail_name] <- NULL
    }
  }
  cb_list
}
|
||||
|
||||
# Split callbacks into those that run before each boosting iteration and
# those that run after, based on their 'is_pre_iteration' attribute.
categorize.callbacks <- function(cb_list) {
  runs_before <- function(x) {
    flag <- attr(x, 'is_pre_iteration')
    !is.null(flag) && flag
  }
  list(
    pre_iter  = Filter(runs_before, cb_list),
    post_iter = Filter(Negate(runs_before), cb_list)
  )
}
|
|
@ -0,0 +1,500 @@
|
|||
# R6 wrapper around a LightGBM C booster handle.
# FIXES: reset_parameter called undefined `algb.params2str`; update()
# read `train_set$predictor`, which is a private Dataset field and is
# therefore always NULL through the public interface.
Booster <- R6Class(
  "lgb.Booster",
  cloneable = FALSE,
  public = list(
    best_iter = -1,        # best iteration found by early stopping (-1 = unset)
    record_evals = list(), # nested evaluation history, filled by callbacks

    # Release the underlying C handle when the object is collected.
    finalize = function() {
      if (!lgb.is.null.handle(private$handle)) {
        print("free booster handle")
        lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
        private$handle <- NULL
      }
    },

    # Create a booster from a training lgb.Dataset or from a saved model
    # file; exactly one of `train_set` / `modelfile` must be supplied.
    # Extra named arguments in `...` are merged into `params`.
    initialize = function(params = list(),
                          train_set = NULL,
                          modelfile = NULL,
                          ...) {
      params <- append(params, list(...))
      params_str <- lgb.params2str(params)
      handle <- lgb.new.handle()
      if (!is.null(train_set)) {
        if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
          stop("lgb.Booster: Only can use lgb.Dataset as training data")
        }
        handle <-
          lgb.call("LGBM_BoosterCreate_R", ret = handle,
                   train_set$.__enclos_env__$private$get_handle(), params_str)
        private$train_set <- train_set
        private$num_dataset <- 1
        private$init_predictor <- train_set$.__enclos_env__$private$predictor
        if (!is.null(private$init_predictor)) {
          # continued training: merge the initial predictor's model in
          lgb.call("LGBM_BoosterMerge_R", ret = NULL,
                   handle,
                   private$init_predictor$.__enclos_env__$private$handle)
        }
        private$is_predicted_cur_iter <-
          c(private$is_predicted_cur_iter, FALSE)
      } else if (!is.null(modelfile)) {
        if (!is.character(modelfile)) {
          stop("lgb.Booster: Only can use string as model file path")
        }
        handle <-
          lgb.call("LGBM_BoosterCreateFromModelfile_R",
                   ret = handle,
                   lgb.c_str(modelfile))
      } else {
        stop(
          "lgb.Booster: Need at least one training dataset or model file to create booster instance"
        )
      }
      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle
      private$num_class <- as.integer(1)
      private$num_class <-
        lgb.call("LGBM_BoosterGetNumClasses_R", ret = private$num_class, private$handle)
    },

    # Set the display name used for the training set in eval output.
    set_train_data_name = function(name) {
      private$name_train_set <- name
      return(self)
    },

    # Register an additional validation dataset under `name`.
    add_valid = function(data, name) {
      if (!lgb.check.r6.class(data, "lgb.Dataset")) {
        stop("lgb.Booster.add_valid: Only can use lgb.Dataset as validation data")
      }
      if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
        stop(
          "lgb.Booster.add_valid: Add validation data failed, you should use same predictor for these data"
        )
      }
      if (!is.character(name)) {
        stop("only can use character as data name")
      }
      lgb.call("LGBM_BoosterAddValidData_R", ret = NULL, private$handle,
               data$.__enclos_env__$private$get_handle())
      private$valid_sets <- c(private$valid_sets, data)
      private$name_valid_sets <- c(private$name_valid_sets, name)
      private$num_dataset <- private$num_dataset + 1
      private$is_predicted_cur_iter <-
        c(private$is_predicted_cur_iter, FALSE)
      return(self)
    },

    # Reset booster parameters (e.g. learning_rate) between iterations.
    reset_parameter = function(params, ...) {
      params <- append(params, list(...))
      # FIX: was `algb.params2str`, an undefined function
      params_str <- lgb.params2str(params)
      lgb.call("LGBM_BoosterResetParameter_R", ret = NULL,
               private$handle,
               params_str)
      return(self)
    },

    # Run one boosting iteration. `fobj(preds, train_set)` may implement
    # a custom objective and must return list(grad = ..., hess = ...).
    update = function(train_set = NULL, fobj = NULL) {
      if (!is.null(train_set)) {
        if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
          stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
        }
        # FIX: `predictor` is private on lgb.Dataset; the old
        # `train_set$predictor` always evaluated to NULL
        if (!identical(train_set$.__enclos_env__$private$predictor,
                       private$init_predictor)) {
          stop(
            "lgb.Booster.update: Change train_set failed, you should use same predictor for these data"
          )
        }
        lgb.call("LGBM_BoosterResetTrainingData_R", ret = NULL,
                 private$handle,
                 train_set$.__enclos_env__$private$get_handle())
        private$train_set <- train_set
      }
      if (is.null(fobj)) {
        ret <-
          lgb.call("LGBM_BoosterUpdateOneIter_R", ret = NULL, private$handle)
      } else {
        if (typeof(fobj) != 'closure') {
          stop("lgb.Booster.update: fobj should be a function")
        }
        gpair <- fobj(private$inner_predict(1), private$train_set)
        ret <-
          lgb.call(
            "LGBM_BoosterUpdateOneIterCustom_R", ret = NULL,
            private$handle,
            gpair$grad,
            gpair$hess,
            length(gpair$grad)
          )
      }
      # cached per-dataset predictions are stale after an update
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }
      return(ret)
    },

    # Undo the most recent boosting iteration.
    rollback_one_iter = function() {
      lgb.call("LGBM_BoosterRollbackOneIter_R", ret = NULL, private$handle)
      for (i in seq_along(private$is_predicted_cur_iter)) {
        private$is_predicted_cur_iter[[i]] <- FALSE
      }
      return(self)
    },

    # Number of completed boosting iterations.
    current_iter = function() {
      cur_iter <- as.integer(0)
      return(lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = cur_iter, private$handle))
    },

    # Evaluate `data`; unseen datasets are first registered as validation
    # set `name`. `feval(preds, data)` may provide a custom metric.
    eval = function(data, name, feval = NULL) {
      if (!lgb.check.r6.class(data, "lgb.Dataset")) {
        stop("lgb.Booster.eval: only can use lgb.Dataset to eval")
      }
      data_idx <- 0
      if (identical(data, private$train_set)) {
        data_idx <- 1
      } else if (length(private$valid_sets) > 0) {
        for (i in seq_along(private$valid_sets)) {
          if (identical(data, private$valid_sets[[i]])) {
            data_idx <- i + 1
            break
          }
        }
      }
      if (data_idx == 0) {
        self$add_valid(data, name)
        data_idx <- private$num_dataset
      }
      return(private$inner_eval(name, data_idx, feval))
    },

    # Evaluate the training set only.
    eval_train = function(feval = NULL) {
      return(private$inner_eval(private$name_train_set, 1, feval))
    },

    # Evaluate every registered validation set.
    eval_valid = function(feval = NULL) {
      ret <- list()
      if (length(private$valid_sets) <= 0) return(ret)
      for (i in seq_along(private$valid_sets)) {
        ret <-
          append(ret, private$inner_eval(private$name_valid_sets[[i]], i + 1, feval))
      }
      return(ret)
    },

    # Save the model to `filename`; NULL num_iteration means best_iter.
    save_model = function(filename, num_iteration = NULL) {
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }
      lgb.call(
        "LGBM_BoosterSaveModel_R",
        ret = NULL,
        private$handle,
        as.integer(num_iteration),
        lgb.c_str(filename)
      )
      return(self)
    },

    # Dump the model as a JSON string; NULL num_iteration means best_iter.
    dump_model = function(num_iteration = NULL) {
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }
      return(
        lgb.call.return.str(
          "LGBM_BoosterDumpModel_R",
          private$handle,
          as.integer(num_iteration)
        )
      )
    },

    # Predict on new data via a fresh Predictor bound to this handle.
    predict = function(data,
                       num_iteration = NULL,
                       rawscore = FALSE,
                       predleaf = FALSE,
                       header = FALSE,
                       reshape = FALSE) {
      if (is.null(num_iteration)) {
        num_iteration <- self$best_iter
      }
      predictor <- Predictor$new(private$handle)
      return(predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape))
    },

    # Expose this booster as a Predictor (e.g. for Dataset init_score).
    to_predictor = function() {
      Predictor$new(private$handle)
    }
  ),
  private = list(
    handle = NULL,
    train_set = NULL,
    name_train_set = "training",
    valid_sets = list(),
    name_valid_sets = list(),
    predict_buffer = list(),
    is_predicted_cur_iter = list(),
    num_class = 1,
    num_dataset = 0,
    init_predictor = NULL,
    eval_names = NULL,
    higher_better_inner_eval = NULL,

    # Return (cached) predictions for dataset `idx` (1 = train, 2.. = valids).
    inner_predict = function(idx) {
      data_name <- private$name_train_set
      if (idx > 1) {
        data_name <- private$name_valid_sets[[idx - 1]]
      }
      if (idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }
      if (is.null(private$predict_buffer[[data_name]])) {
        npred <- as.integer(0)
        npred <-
          lgb.call("LGBM_BoosterGetNumPredict_R",
                   ret = npred,
                   private$handle,
                   as.integer(idx - 1))
        private$predict_buffer[[data_name]] <- rep(0.0, npred)
      }
      if (!private$is_predicted_cur_iter[[idx]]) {
        private$predict_buffer[[data_name]] <-
          lgb.call(
            "LGBM_BoosterGetPredict_R",
            ret = private$predict_buffer[[data_name]],
            private$handle,
            as.integer(idx - 1)
          )
        private$is_predicted_cur_iter[[idx]] <- TRUE
      }
      return(private$predict_buffer[[data_name]])
    },

    # Lazily fetch metric names and whether higher values are better.
    get_eval_info = function() {
      if (is.null(private$eval_names)) {
        names <-
          lgb.call.return.str("LGBM_BoosterGetEvalNames_R", private$handle)
        if (nchar(names) > 0) {
          names <- strsplit(names, "\t")[[1]]
          private$eval_names <- names
          # only AUC and NDCG are maximized; other built-ins are minimized
          private$higher_better_inner_eval <-
            startsWith(names, "auc") | startsWith(names, "ndcg")
        }
      }
      return(private$eval_names)
    },

    # Collect built-in metric values for dataset `data_idx`, then append
    # the custom metric's result when `feval` is given.
    inner_eval = function(data_name, data_idx, feval = NULL) {
      if (data_idx > private$num_dataset) {
        stop("data_idx should not be greater than num_dataset")
      }
      private$get_eval_info()
      ret <- list()
      if (length(private$eval_names) > 0) {
        tmp_vals <- rep(0.0, length(private$eval_names))
        tmp_vals <-
          lgb.call("LGBM_BoosterGetEval_R", ret = tmp_vals,
                   private$handle,
                   as.integer(data_idx - 1))
        for (i in seq_along(private$eval_names)) {
          res <- list()
          res$data_name <- data_name
          res$name <- private$eval_names[i]
          res$value <- tmp_vals[i]
          res$higher_better <- private$higher_better_inner_eval[i]
          ret <- append(ret, list(res))
        }
      }
      if (!is.null(feval)) {
        if (typeof(feval) != 'closure') {
          stop("lgb.Booster.eval: feval should be a function")
        }
        data <- private$train_set
        if (data_idx > 1) {
          data <- private$valid_sets[[data_idx - 1]]
        }
        res <- feval(private$inner_predict(data_idx), data)
        res$data_name <- data_name
        ret <- append(ret, list(res))
      }
      return(ret)
    }
  )
)
|
||||
|
||||
# internal helper: TRUE when `x` is an lgb.Booster R6 object
lgb.is.Booster <- function(x) {
  if (lgb.check.r6.class(x, "lgb.Booster")) {
    return(TRUE)
  }
  FALSE
}
|
||||
|
||||
#' Predict method for LightGBM model
#'
#' Predicted values based on class \code{lgb.Booster}
#'
#' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' @param rawscore whether the prediction should be returned in the for of original untransformed
#' sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE} for
#' logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether predict leaf index instead.
#' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#' prediction outputs per case.
#' @return
#' For regression or binary classification, it returns a vector of length \code{nrows(data)}.
#' For multiclass classification, either a \code{num_class * nrows(data)} vector or
#' a \code{(nrows(data), num_class)} dimension matrix is returned, depending on
#' the \code{reshape} value.
#'
#' When \code{predleaf = TRUE}, the output is a matrix object with the
#' number of columns corresponding to the number of trees.
#' @examples
#' library(lightgbm)
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' data(agaricus.test, package='lightgbm')
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' params <- list(objective="regression", metric="l2")
#' valids <- list(test=dtest)
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' preds <- predict(model, test$data)
#'
#' @rdname predict.lgb.Booster
#' @export
predict.lgb.Booster <- function(object,
                                data,
                                num_iteration = NULL,
                                rawscore = FALSE,
                                predleaf = FALSE,
                                header = FALSE,
                                reshape = FALSE) {
  # S3 dispatch target; all real work happens in the R6 method.
  if (!lgb.is.Booster(object)) {
    stop("predict.lgb.Booster: should input lgb.Booster object")
  }
  object$predict(data, num_iteration, rawscore, predleaf, header, reshape)
}
|
||||
|
||||
#' Load LightGBM model
#'
#' Load LightGBM model from saved model file
#'
#' @param filename path of model file
#'
#' @return booster
#' @examples
#' library(lightgbm)
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' data(agaricus.test, package='lightgbm')
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' params <- list(objective="regression", metric="l2")
#' valids <- list(test=dtest)
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' lgb.save(model, "model.txt")
#' load_booster <- lgb.load("model.txt")
#' @rdname lgb.load
#' @export
lgb.load <- function(filename) {
  # validate before touching the C API
  if (!is.character(filename)) {
    stop("lgb.load: filename should be character")
  }
  Booster$new(modelfile = filename)
}
|
||||
|
||||
#' Save LightGBM model
#'
#' Save LightGBM model
#'
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#'
#' @return booster
#' @examples
#' library(lightgbm)
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' data(agaricus.test, package='lightgbm')
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' params <- list(objective="regression", metric="l2")
#' valids <- list(test=dtest)
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' lgb.save(model, "model.txt")
#' @rdname lgb.save
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {
  # validate both arguments before delegating to the R6 method
  if (!lgb.is.Booster(booster)) {
    stop("lgb.save: should input lgb.Booster object")
  }
  if (!is.character(filename)) {
    stop("lgb.save: filename should be character")
  }
  booster$save_model(filename, num_iteration)
}
|
||||
|
||||
#' Dump LightGBM model to json
#'
#' Dump LightGBM model to json
#'
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#'
#' @return json format of model
#' @examples
#' library(lightgbm)
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' data(agaricus.test, package='lightgbm')
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' params <- list(objective="regression", metric="l2")
#' valids <- list(test=dtest)
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' json_model <- lgb.dump(model)
#' @rdname lgb.dump
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {
  # thin validated wrapper over Booster$dump_model
  if (!lgb.is.Booster(booster)) {
    stop("lgb.dump: should input lgb.Booster object")
  }
  booster$dump_model(num_iteration)
}
|
||||
|
||||
#' Get record evaluation result from booster
#'
#' Get record evaluation result from booster
#' @param booster Object of class \code{lgb.Booster}
#' @param data_name name of dataset
#' @param eval_name name of evaluation
#' @param iters iterations, NULL will return all
#' @param is_err TRUE will return evaluation error instead
#' @return vector of evaluation result
#'
#' @rdname lgb.get.eval.result
#' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {
  if (!lgb.is.Booster(booster)) {
    stop("lgb.get.eval.result: only can use booster to get eval result")
  }
  if (!is.character(data_name) | !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be character")
  }
  record <- booster$record_evals[[data_name]]
  if (is.null(record)) {
    stop("lgb.get.eval.result: wrong data name")
  }
  entry <- record[[eval_name]]
  if (is.null(entry)) {
    stop("lgb.get.eval.result: wrong eval name")
  }
  result <- if (is_err) entry$eval_err else entry$eval
  if (is.null(iters)) {
    return(as.numeric(result))
  }
  # record positions are 1-based starting at start_iter; shift accordingly
  offset <- booster$record_evals$start_iter - 1
  as.numeric(result[as.integer(iters) - offset])
}
|
||||
|
|
@ -0,0 +1,795 @@
|
|||
# Dataset: internal R6 class wrapping a native LightGBM dataset handle.
# Users normally create instances through lgb.Dataset() rather than directly.
# The native handle is created lazily by $construct(); until then the raw
# R-side data (matrix / dgCMatrix / filename) is kept in private$raw_data.
Dataset <- R6Class(
  "lgb.Dataset",
  cloneable=FALSE,
  public = list(
    # Destructor: release the native dataset handle, if one was created.
    finalize = function() {
      if (!lgb.is.null.handle(private$handle)) {
        print("free dataset handle")
        lgb.call("LGBM_DatasetFree_R", ret = NULL, private$handle)
        private$handle <- NULL
      }
    },
    # Store construction arguments; no native work happens here.
    # Extra named arguments in ... are routed to info (for the reserved
    # fields label/weight/init_score/group) or to params (everything else).
    initialize = function(data,
                          params = list(),
                          reference = NULL,
                          colnames = NULL,
                          categorical_feature = NULL,
                          predictor = NULL,
                          free_raw_data = TRUE,
                          used_indices = NULL,
                          info = list(),
                          ...) {
      addiction_params <- list(...)
      for (key in names(addiction_params)) {
        if (key %in% c('label', 'weight', 'init_score', 'group')) {
          info[[key]] <- addiction_params[[key]]
        } else {
          params[[key]] <- addiction_params[[key]]
        }
      }
      # reference must itself be an lgb.Dataset (used for bin mappers)
      if (!is.null(reference)) {
        if (!lgb.check.r6.class(reference, "lgb.Dataset")) {
          stop("lgb.Dataset: Only can use lgb.Dataset as reference")
        }
      }
      # predictor (if any) supplies init scores for continued training
      if (!is.null(predictor)) {
        if (!lgb.check.r6.class(predictor, "lgb.Predictor")) {
          stop("lgb.Dataset: Only can use lgb.Predictor as predictor")
        }
      }
      private$raw_data <- data
      private$params <- params
      private$reference <- reference
      private$colnames <- colnames
      private$categorical_feature <- categorical_feature
      private$predictor <- predictor
      private$free_raw_data <- free_raw_data
      private$used_indices <- used_indices
      private$info <- info
    },
    # Build a validation dataset that shares this dataset's params,
    # column names, categorical features and predictor, with self as reference.
    create_valid = function(data, info = list(), ...) {
      ret <- Dataset$new(
        data,
        private$params,
        self,
        private$colnames,
        private$categorical_feature,
        private$predictor,
        private$free_raw_data,
        NULL,
        info,
        ...
      )
      return(ret)
    },
    # Materialize the native dataset handle from raw_data (or from the
    # reference when this object is a row subset). Idempotent.
    construct = function() {
      if (!lgb.is.null.handle(private$handle)) {
        return(self)
      }
      # Get feature names
      cnames <- NULL
      if (is.matrix(private$raw_data) |
          class(private$raw_data) == "dgCMatrix") {
        cnames <- colnames(private$raw_data)
      }
      # set feature names if not exist
      if (is.null(private$colnames) & !is.null(cnames)) {
        private$colnames <- as.character(cnames)
      }
      # Get categorical feature index
      if (!is.null(private$categorical_feature)) {
        # map feature name -> zero-based index for the C API
        fname_dict <- list()
        if (!is.null(private$colnames)) {
          fname_dict <-
            as.list(setNames(0:(length(
              private$colnames
            ) - 1), private$colnames))
        }
        cate_indices <- list()
        for (key in private$categorical_feature) {
          if (is.character(key)) {
            idx <- fname_dict[[key]]
            if (is.null(idx)) {
              stop(paste("lgb.self.get.handle: cannot find feature name ", key))
            }
            cate_indices <- c(cate_indices, idx)
          } else {
            # one-based indices to zero-based
            idx <- as.integer(key - 1)
            cate_indices <- c(cate_indices, idx)
          }
        }
        private$params$categorical_feature <- cate_indices
      }
      # Check has header or not (only relevant when loading from a text file)
      has_header <- FALSE
      if (!is.null(private$params$has_header) |
          !is.null(private$params$header)) {
        if (tolower(as.character(private$params$has_header)) == "true" |
            tolower(as.character(private$params$header)) == "true") {
          has_header <- TRUE
        }
      }
      # Generate parameter str
      params_str <- lgb.params2str(private$params)
      # get handle of reference dataset
      ref_handle <- NULL
      if (!is.null(private$reference)) {
        ref_handle <- private$reference$.__enclos_env__$private$get_handle()
      }
      handle <- lgb.new.handle()
      # not subset: create from file / dense matrix / CSC sparse matrix
      if (is.null(private$used_indices)) {
        if (typeof(private$raw_data) == "character") {
          handle <-
            lgb.call(
              "LGBM_DatasetCreateFromFile_R",
              ret = handle,
              lgb.c_str(private$raw_data),
              params_str,
              ref_handle
            )
        } else if (is.matrix(private$raw_data)) {
          handle <-
            lgb.call(
              "LGBM_DatasetCreateFromMat_R",
              ret = handle,
              private$raw_data,
              nrow(private$raw_data),
              ncol(private$raw_data),
              params_str,
              ref_handle
            )
        } else if (class(private$raw_data) == "dgCMatrix") {
          handle <- lgb.call(
            "LGBM_DatasetCreateFromCSC_R",
            ret = handle,
            private$raw_data@p,
            private$raw_data@i,
            private$raw_data@x,
            length(private$raw_data@p),
            length(private$raw_data@x),
            nrow(private$raw_data),
            params_str,
            ref_handle
          )
        } else {
          stop(paste(
            "lgb.Dataset.construct: does not support to construct from ",
            typeof(private$raw_data)
          ))
        }
      } else {
        # construct subset: rows are taken from the (already built) reference
        if (is.null(private$reference)) {
          stop("lgb.Dataset.construct: reference cannot be NULL if construct subset")
        }
        handle <-
          lgb.call(
            "LGBM_DatasetGetSubset_R",
            ret = handle,
            ref_handle,
            private$used_indices,
            length(private$used_indices),
            params_str
          )
      }
      class(handle) <- "lgb.Dataset.handle"
      private$handle <- handle
      # set feature names
      if (!is.null(private$colnames)) {
        self$set_colnames(private$colnames)
      }
      # load init score from the predictor (continued training); skipped for
      # subsets, which inherit scores through the reference
      if (!is.null(private$predictor) &
          is.null(private$used_indices)) {
        init_score <-
          private$predictor$predict(private$raw_data,
                                    rawscore = TRUE,
                                    reshape = TRUE)
        # not need to transpose, for is col_marjor
        init_score <- as.vector(init_score)
        private$info$init_score <- init_score
      }
      # drop the R-side copy once the native dataset owns the data
      # (file paths are kept so the dataset can be rebuilt)
      if (private$free_raw_data & !is.character(private$raw_data)) {
        private$raw_data <- NULL
      }
      if (length(private$info) > 0) {
        # set infos: push each stored field (label/weight/...) into the handle
        for (i in 1:length(private$info)) {
          p <- private$info[i]
          self$setinfo(names(p), p[[1]])
        }
      }
      if (is.null(self$getinfo("label"))) {
        stop("lgb.Dataset.construct: label should be set")
      }
      return(self)
    },
    # c(num_rows, num_features); falls back to the raw data when the
    # native handle has not been constructed yet.
    dim = function() {
      if (!lgb.is.null.handle(private$handle)) {
        num_row <- as.integer(0)
        num_col <- as.integer(0)
        return(c(
          lgb.call("LGBM_DatasetGetNumData_R", ret = num_row, private$handle),
          lgb.call("LGBM_DatasetGetNumFeature_R", ret = num_col, private$handle)
        ))
      } else if (is.matrix(private$raw_data) |
                 class(private$raw_data) == "dgCMatrix") {
        return(dim(private$raw_data))
      } else {
        stop(
          "dim: cannot get Dimensions before dataset constructed, please call lgb.Dataset.construct explicit"
        )
      }
    },
    # Feature names, read from the native handle when available
    # (the C API returns them tab-separated in one string).
    get_colnames = function() {
      if (!lgb.is.null.handle(private$handle)) {
        cnames <- lgb.call.return.str("LGBM_DatasetGetFeatureNames_R",
                                      private$handle)
        private$colnames <- as.character(strsplit(cnames, "\t")[[1]])
        return(private$colnames)
      } else if (is.matrix(private$raw_data) |
                 class(private$raw_data) == "dgCMatrix") {
        return(colnames(private$raw_data))
      } else {
        stop(
          "colnames: cannot get colnames before dataset constructed, please call lgb.Dataset.construct explicit"
        )
      }
    },
    # Set feature names; pushed to the native handle when it exists.
    set_colnames = function(colnames) {
      if (is.null(colnames)) return(self)
      colnames <- as.character(colnames)
      if (length(colnames) == 0) return(self)
      private$colnames <- colnames
      if (!lgb.is.null.handle(private$handle)) {
        # C API takes names as a single tab-separated string
        merged_name <- paste0(as.list(private$colnames), collapse = "\t")
        lgb.call("LGBM_DatasetSetFeatureNames_R",
                 ret = NULL,
                 private$handle,
                 lgb.c_str(merged_name))
      }
      return(self)
    },
    # Read an info field ('label', 'weight', 'init_score', 'group');
    # lazily fetched from the native handle and cached in private$info.
    getinfo = function(name) {
      if (typeof(name) != "character" ||
          length(name) != 1 ||
          !name %in% c('label', 'weight', 'init_score', 'group')) {
        stop(
          "getinfo: name must one of the following\n",
          " 'label', 'weight', 'init_score', 'group'"
        )
      }
      if (is.null(private$info[[name]]) &
          !lgb.is.null.handle(private$handle)) {
        info_len <- as.integer(0)
        info_len <-
          lgb.call("LGBM_DatasetGetFieldSize_R",
                   ret = info_len,
                   private$handle,
                   lgb.c_str(name))
        if (info_len > 0) {
          # 'group' is integer-typed in the C API, other fields are float
          ret <- NULL
          if (name == "group") {
            ret <- integer(info_len)
          } else {
            ret <- rep(0.0, info_len)
          }
          ret <-
            lgb.call("LGBM_DatasetGetField_R",
                     ret = ret,
                     private$handle,
                     lgb.c_str(name))
          private$info[[name]] <- ret
        }
      }
      return(private$info[[name]])
    },
    # Write an info field; stored locally and, when the native handle
    # exists, also pushed into it.
    setinfo = function(name, info) {
      if (typeof(name) != "character" ||
          length(name) != 1 ||
          !name %in% c('label', 'weight', 'init_score', 'group')) {
        stop(
          "setinfo: name must one of the following\n",
          " 'label', 'weight', 'init_score', 'group'"
        )
      }
      if (name == "group") {
        info <- as.integer(info)
      } else {
        info <- as.numeric(info)
      }
      private$info[[name]] <- info
      if (!lgb.is.null.handle(private$handle) & !is.null(info)) {
        if (length(info) > 0) {
          lgb.call(
            "LGBM_DatasetSetField_R",
            ret = NULL,
            private$handle,
            lgb.c_str(name),
            info,
            length(info)
          )
        }
      }
      return(self)
    },
    # New Dataset holding only the rows in idxset, with self as reference.
    slice = function(idxset, ...) {
      ret <- Dataset$new(
        NULL,
        private$params,
        self,
        private$colnames,
        private$categorical_feature,
        private$predictor,
        private$free_raw_data,
        idxset,
        NULL,
        ...
      )
      return(ret)
    },
    # Merge new parameters over the stored ones.
    update_params = function(params) {
      private$params <- modifyList(private$params, params)
    },
    # Change categorical features; requires raw data so the native
    # dataset can be rebuilt (finalize drops the stale handle).
    set_categorical_feature = function(categorical_feature) {
      if (identical(private$categorical_feature, categorical_feature)) {
        return(self)
      }
      if (is.null(private$raw_data)) {
        stop(
          "set_categorical_feature: cannot set categorical feature after free raw data,
          please set free_raw_data=FALSE when construct lgb.Dataset"
        )
      }
      private$categorical_feature <- categorical_feature
      self$finalize()
      return(self)
    },
    # Use another dataset as reference (shared bin mappers); also adopts
    # its categorical features, column names and predictor.
    set_reference = function(reference) {
      self$set_categorical_feature(reference$.__enclos_env__$private$categorical_feature)
      self$set_colnames(reference$get_colnames())
      private$set_predictor(reference$.__enclos_env__$private$predictor)
      if (identical(private$reference, reference)) {
        return(self)
      }
      if (is.null(private$raw_data)) {
        stop(
          "set_reference: cannot set reference after free raw data,
          please set free_raw_data=FALSE when construct lgb.Dataset"
        )
      }
      if (!is.null(reference)) {
        if (!lgb.check.r6.class(reference, "lgb.Dataset")) {
          stop("set_reference: Only can use lgb.Dataset as reference")
        }
      }
      private$reference <- reference
      self$finalize()
      return(self)
    },
    # Save the constructed dataset to a LightGBM binary file.
    save_binary = function(fname) {
      self$construct()
      lgb.call("LGBM_DatasetSaveBinary_R",
               ret = NULL,
               private$handle,
               lgb.c_str(fname))
      return(self)
    }
  ),
  private = list(
    handle = NULL,               # native lgb.Dataset.handle, NULL until construct()
    raw_data = NULL,             # matrix / dgCMatrix / filename; may be freed
    params = list(),             # dataset parameters passed to the C API
    reference = NULL,            # lgb.Dataset supplying bin mappers
    colnames = NULL,             # feature names
    categorical_feature = NULL,  # names or one-based indices of categoricals
    predictor = NULL,            # lgb.Predictor supplying init scores
    free_raw_data = TRUE,        # drop raw_data after construction?
    used_indices = NULL,         # row indices when this is a subset
    info = NULL,                 # label / weight / init_score / group cache
    # Construct on demand and return the native handle.
    get_handle = function() {
      if (lgb.is.null.handle(private$handle)) {
        self$construct()
      }
      return(private$handle)
    },
    # Change the predictor (init-score source); requires raw data so the
    # dataset can be rebuilt with the new scores.
    set_predictor = function(predictor) {
      if (identical(private$predictor, predictor)) {
        return(self)
      }
      if (is.null(private$raw_data)) {
        stop(
          "set_predictor: cannot set predictor after free raw data,
          please set free_raw_data=FALSE when construct lgb.Dataset"
        )
      }
      if (!is.null(predictor)) {
        if (!lgb.check.r6.class(predictor, "lgb.Predictor")) {
          stop("set_predictor: Only can use lgb.Predictor as predictor")
        }
      }
      private$predictor <- predictor
      self$finalize()
      return(self)
    }
  )
)
|
||||
#' Construct lgb.Dataset object
#'
#' Construct lgb.Dataset object
#'
#' Construct lgb.Dataset object from dense matrix, sparse matrix
#' or local file (that was created previously by saving an \code{lgb.Dataset}).
#'
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param params a list of parameters
#' @param reference reference dataset
#' @param colnames names of columns
#' @param categorical_feature categorical features
#' @param free_raw_data TRUE for need to free raw data after construct
#' @param info a list of information of the lgb.Dataset object
#' @param ... other information to pass to \code{info} or parameters pass to \code{params}
#' @return constructed dataset
#' @examples
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' lgb.Dataset.save(dtrain, 'lgb.Dataset.data')
#' dtrain <- lgb.Dataset('lgb.Dataset.data')
#' lgb.Dataset.construct(dtrain)
#' @export
lgb.Dataset <- function(data,
                        params = list(),
                        reference = NULL,
                        colnames = NULL,
                        categorical_feature = NULL,
                        free_raw_data = TRUE,
                        info = list(),
                        ...) {
  # Thin public wrapper around the internal R6 class. The predictor and
  # used_indices slots are internal-only, so they are fixed to NULL here.
  Dataset$new(data, params, reference, colnames, categorical_feature,
              NULL, free_raw_data, NULL, info, ...)
}
|
||||
# internal helper method: TRUE iff x is an lgb.Dataset R6 object
lgb.is.Dataset <- function(x) {
  if (lgb.check.r6.class(x, "lgb.Dataset")) TRUE else FALSE
}
|
||||
#' Construct a validation data
#'
#' Construct a validation data according to training data
#'
#' @param dataset \code{lgb.Dataset} object, training data
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param info a list of information of the lgb.Dataset object
#' @param ... other information to pass to \code{info}.
#' @return constructed dataset
#' @examples
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' data(agaricus.test, package='lightgbm')
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' @export
lgb.Dataset.create.valid <- function(dataset, data, info = list(), ...) {
  # The validation set must share bin mappers with the training set,
  # which create_valid arranges by using `dataset` as reference.
  if (!lgb.is.Dataset(dataset)) {
    stop("lgb.Dataset.create.valid: input data should be lgb.Dataset object")
  }
  dataset$create_valid(data, info, ...)
}
|
||||
#' Construct Dataset explicitly
#'
#' Construct Dataset explicitly
#'
#' @param dataset Object of class \code{lgb.Dataset}
#' @examples
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' lgb.Dataset.construct(dtrain)
#' @export
lgb.Dataset.construct <- function(dataset) {
  # Forces creation of the native dataset handle (normally done lazily).
  if (!lgb.is.Dataset(dataset)) {
    stop("lgb.Dataset.construct: input data should be lgb.Dataset object")
  }
  dataset$construct()
}
|
||||
#' Dimensions of lgb.Dataset
#'
#' Dimensions of lgb.Dataset
#'
#' Returns a vector of numbers of rows and of columns in an \code{lgb.Dataset}.
#' @param x Object of class \code{lgb.Dataset}
#' @param ... other parameters
#' @return a vector of numbers of rows and of columns
#'
#' @details
#' Note: since \code{nrow} and \code{ncol} internally use \code{dim}, they can also
#' be directly used with an \code{lgb.Dataset} object.
#'
#' @examples
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#'
#' stopifnot(nrow(dtrain) == nrow(train$data))
#' stopifnot(ncol(dtrain) == ncol(train$data))
#' stopifnot(all(dim(dtrain) == dim(train$data)))
#'
#' @rdname dim
#' @export
dim.lgb.Dataset <- function(x, ...) {
  # S3 dispatch target for dim()/nrow()/ncol() on lgb.Dataset objects.
  if (!lgb.is.Dataset(x)) {
    stop("dim.lgb.Dataset: input data should be lgb.Dataset object")
  }
  x$dim()
}
|
||||
#' Handling of column names of \code{lgb.Dataset}
#'
#' Handling of column names of \code{lgb.Dataset}
#'
#' Only column names are supported for \code{lgb.Dataset}, thus setting of
#' row names would have no effect and returned row names would be NULL.
#'
#' @param x object of class \code{lgb.Dataset}
#' @param value a list of two elements: the first one is ignored
#' and the second one is column names
#'
#' @details
#' Generic \code{dimnames} methods are used by \code{colnames}.
#' Since row names are irrelevant, it is recommended to use \code{colnames} directly.
#'
#' @examples
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' lgb.Dataset.construct(dtrain)
#' dimnames(dtrain)
#' colnames(dtrain)
#' colnames(dtrain) <- make.names(1:ncol(train$data))
#' print(dtrain, verbose=TRUE)
#'
#' @rdname dimnames.lgb.Dataset
#' @export
dimnames.lgb.Dataset <- function(x) {
  # Row names are never tracked, so the first element is always NULL.
  if (!lgb.is.Dataset(x)) {
    stop("dimnames.lgb.Dataset: input data should be lgb.Dataset object")
  }
  list(NULL, x$get_colnames())
}
|
||||
#' @rdname dimnames.lgb.Dataset
#' @export
`dimnames<-.lgb.Dataset` <- function(x, value) {
  # 'value' must be list(rownames, colnames); rownames are unsupported
  # and must be NULL.
  if (!is.list(value) || length(value) != 2L) {
    stop("invalid 'dimnames' given: must be a list of two elements")
  }
  if (!is.null(value[[1L]])) {
    stop("lgb.Dataset does not have rownames")
  }
  cnames <- value[[2]]
  if (is.null(cnames)) {
    x$set_colnames(NULL)
    return(x)
  }
  # length must match the feature count before assigning
  if (ncol(x) != length(cnames)) {
    stop("can't assign ",
         length(cnames),
         " colnames to a ",
         ncol(x),
         " column lgb.Dataset")
  }
  x$set_colnames(cnames)
  return(x)
}
|
||||
#' Slice a dataset
#'
#' Get a new Dataset containing the specified rows of
#' the original lgb.Dataset object
#'
#' @param dataset Object of class "lgb.Dataset"
#' @param idxset a integer vector of indices of rows needed
#' @param ... other parameters (currently not used)
#' @return constructed sub dataset
#'
#' @examples
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#'
#' dsub <- slice(dtrain, 1:42)
#' labels1 <- getinfo(dsub, 'label')
#'
#' @export
slice <- function(dataset, ...) {
  # S3 generic; see slice.lgb.Dataset for the lgb.Dataset method.
  UseMethod("slice")
}
|
||||
#' @rdname slice
#' @export
slice.lgb.Dataset <- function(dataset, idxset, ...) {
  # Returns a lazily-constructed subset Dataset referencing `dataset`.
  if (!lgb.is.Dataset(dataset)) {
    stop("slice.lgb.Dataset: input data should be lgb.Dataset object")
  }
  dataset$slice(idxset, ...)
}
|
||||
|
||||
#' Get information of an lgb.Dataset object
#'
#' Get information of an lgb.Dataset object
#'
#' @param dataset Object of class \code{lgb.Dataset}
#' @param name the name of the information field to get (see details)
#' @param ... other parameters
#' @return info data
#'
#' @details
#' The \code{name} field can be one of the following:
#'
#' \itemize{
#'     \item \code{label}: label lightgbm learn from ;
#'     \item \code{weight}: to do a weight rescale ;
#'     \item \code{group}: group size
#'     \item \code{init_score}: initial score is the base prediction lightgbm will boost from ;
#' }
#'
#' @examples
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' lgb.Dataset.construct(dtrain)
#' labels <- getinfo(dtrain, 'label')
#' setinfo(dtrain, 'label', 1-labels)
#'
#' labels2 <- getinfo(dtrain, 'label')
#' stopifnot(all(labels2 == 1-labels))
#' @export
getinfo <- function(dataset, ...) {
  # S3 generic; see getinfo.lgb.Dataset for the lgb.Dataset method.
  UseMethod("getinfo")
}
|
||||
#' @rdname getinfo
#' @export
getinfo.lgb.Dataset <- function(dataset, name, ...) {
  # Delegates to the R6 object's getinfo after a type check.
  if (!lgb.is.Dataset(dataset)) {
    stop("getinfo.lgb.Dataset: input data should be lgb.Dataset object")
  }
  dataset$getinfo(name)
}
|
||||
#' Set information of an lgb.Dataset object
#'
#' Set information of an lgb.Dataset object
#'
#' @param dataset Object of class "lgb.Dataset"
#' @param name the name of the field to get
#' @param info the specific field of information to set
#' @param ... other parameters
#' @return passed object
#'
#' @details
#' The \code{name} field can be one of the following:
#'
#' \itemize{
#'     \item \code{label}: label lightgbm learn from ;
#'     \item \code{weight}: to do a weight rescale ;
#'     \item \code{init_score}: initial score is the base prediction lightgbm will boost from ;
#'     \item \code{group}.
#' }
#'
#' @examples
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' lgb.Dataset.construct(dtrain)
#' labels <- getinfo(dtrain, 'label')
#' setinfo(dtrain, 'label', 1-labels)
#' labels2 <- getinfo(dtrain, 'label')
#' stopifnot(all.equal(labels2, 1-labels))
#' @export
setinfo <- function(dataset, ...) {
  # S3 generic; see setinfo.lgb.Dataset for the lgb.Dataset method.
  UseMethod("setinfo")
}
|
||||
#' @rdname setinfo
#' @export
setinfo.lgb.Dataset <- function(dataset, name, info, ...) {
  # Delegates to the R6 object's setinfo after a type check.
  if (!lgb.is.Dataset(dataset)) {
    stop("setinfo.lgb.Dataset: input data should be lgb.Dataset object")
  }
  dataset$setinfo(name, info)
}
|
||||
#' set categorical feature of \code{lgb.Dataset}
#'
#' set categorical feature of \code{lgb.Dataset}
#'
#' @param dataset object of class \code{lgb.Dataset}
#' @param categorical_feature categorical features
#' @return passed dataset
#' @examples
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' lgb.Dataset.save(dtrain, 'lgb.Dataset.data')
#' dtrain <- lgb.Dataset('lgb.Dataset.data')
#' lgb.Dataset.set.categorical(dtrain, 1:2)
#' @rdname lgb.Dataset.set.categorical
#' @export
lgb.Dataset.set.categorical <- function(dataset, categorical_feature) {
  # categorical_feature may be feature names or one-based column indices;
  # requires the dataset's raw data to still be available.
  if (!lgb.is.Dataset(dataset)) {
    stop("lgb.Dataset.set.categorical: input data should be lgb.Dataset object")
  }
  dataset$set_categorical_feature(categorical_feature)
}
|
||||
#' set reference of \code{lgb.Dataset}
#'
#' set reference of \code{lgb.Dataset}.
#' If you want to use validation data, you should set its reference to training data
#'
#' @param dataset object of class \code{lgb.Dataset}
#' @param reference object of class \code{lgb.Dataset}
#' @return passed dataset
#' @examples
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' data(agaricus.test, package='lightgbm')
#' test <- agaricus.test
#' dtest <- lgb.Dataset(test$data, label=test$label)
#' lgb.Dataset.set.reference(dtest, dtrain)
#' @rdname lgb.Dataset.set.reference
#' @export
lgb.Dataset.set.reference <- function(dataset, reference) {
  # The reference dataset supplies the bin mappers, so validation data
  # is discretized consistently with the training data.
  if (!lgb.is.Dataset(dataset)) {
    stop("lgb.Dataset.set.reference: input data should be lgb.Dataset object")
  }
  dataset$set_reference(reference)
}
|
||||
#' save \code{lgb.Dataset} to binary file
#'
#' save \code{lgb.Dataset} to binary file
#'
#' @param dataset object of class \code{lgb.Dataset}
#' @param fname object filename of output file
#' @return passed dataset
#' @examples
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' lgb.Dataset.save(dtrain, "data.bin")
#' @rdname lgb.Dataset.save
#' @export
lgb.Dataset.save <- function(dataset, fname) {
  # error messages previously said "lgb.Dataset.set:"; fixed to name this function
  if (!lgb.is.Dataset(dataset)) {
    stop("lgb.Dataset.save: input data should be lgb.Dataset object")
  }
  if (!is.character(fname)) {
    stop("lgb.Dataset.save: filename should be character type")
  }
  return(dataset$save_binary(fname))
}
@ -0,0 +1,104 @@
|
|||
# Predictor: internal R6 class wrapping a LightGBM booster handle used for
# prediction only. Built either from a saved model file (handle is owned and
# freed on finalize) or from an existing lgb.Booster.handle (borrowed).
Predictor <- R6Class(
  "lgb.Predictor",
  cloneable=FALSE,
  public = list(
    # Destructor: free the booster handle, but only if this object created it.
    finalize = function() {
      # && (scalar short-circuit) instead of vectorized & for length-1 flags
      if (private$need_free_handle && !lgb.is.null.handle(private$handle)) {
        print("free booster handle")
        lgb.call("LGBM_BoosterFree_R", ret=NULL, private$handle)
        private$handle <- NULL
      }
    },
    # modelfile: either a character path to a saved model file, or an
    # existing lgb.Booster.handle (in which case it is borrowed, not owned).
    initialize = function(modelfile) {
      handle <- lgb.new.handle()
      if (typeof(modelfile) == "character") {
        handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R", ret=handle, lgb.c_str(modelfile))
        private$need_free_handle = TRUE
      } else if (class(modelfile) == "lgb.Booster.handle") {
        handle <- modelfile
        private$need_free_handle = FALSE
      } else {
        stop("lgb.Predictor: modelfile must be either character filename, or lgb.Booster.handle")
      }
      class(handle) <- "lgb.Booster.handle"
      private$handle <- handle
    },
    # Number of boosting iterations in the underlying model.
    current_iter = function() {
      cur_iter <- as.integer(0)
      return(lgb.call("LGBM_BoosterGetCurrentIteration_R", ret=cur_iter, private$handle))
    },
    # Predict on new data.
    #   data          matrix, dgCMatrix, or a filename
    #   num_iteration number of iterations to use (NULL = all)
    #   rawscore      return raw (untransformed) scores
    #   predleaf      return leaf indices instead of scores
    #   header        data file has a header line (file input only)
    #   reshape       reshape to a matrix when there are multiple
    #                 predictions per row (e.g. multiclass)
    predict = function(data,
                       num_iteration = NULL, rawscore = FALSE, predleaf = FALSE, header = FALSE,
                       reshape = FALSE) {
      if (is.null(num_iteration)) {
        # -1 tells the C API to use all iterations
        num_iteration <- -1
      }
      num_row <- 0
      if (typeof(data) == "character") {
        # Predict from file: the C API writes predictions to a temp file
        # which we read back; clean it up even if reading fails.
        tmp_filename <- tempfile(pattern = "lightgbm_")
        on.exit(if (file.exists(tmp_filename)) file.remove(tmp_filename), add = TRUE)
        lgb.call("LGBM_BoosterPredictForFile_R", ret=NULL, private$handle, data, as.integer(header),
                 as.integer(rawscore),
                 as.integer(predleaf),
                 as.integer(num_iteration),
                 lgb.c_str(tmp_filename))
        # fix: the field-separator argument of read.delim is 'sep', not 'seq'
        preds <- read.delim(tmp_filename, header=FALSE, sep="\t")
        num_row <- nrow(preds)
        preds <- as.vector(t(preds))
      } else {
        num_row <- nrow(data)
        # ask the C API how many prediction values will be produced
        npred <- as.integer(0)
        npred <- lgb.call("LGBM_BoosterCalcNumPredict_R", ret=npred,
                          private$handle,
                          as.integer(num_row),
                          as.integer(rawscore),
                          as.integer(predleaf),
                          as.integer(num_iteration))
        # allocate space for prediction
        preds <- rep(0.0, npred)
        if (is.matrix(data)) {
          preds <- lgb.call("LGBM_BoosterPredictForMat_R", ret=preds,
                            private$handle,
                            data,
                            as.integer(nrow(data)),
                            as.integer(ncol(data)),
                            as.integer(rawscore),
                            as.integer(predleaf),
                            as.integer(num_iteration))
        } else if (class(data) == "dgCMatrix") {
          preds <- lgb.call("LGBM_BoosterPredictForCSC_R", ret=preds,
                            private$handle,
                            data@p,
                            data@i,
                            data@x,
                            length(data@p),
                            length(data@x),
                            nrow(data),
                            as.integer(rawscore),
                            as.integer(predleaf),
                            as.integer(num_iteration))
        } else {
          stop(paste("predict: does not support to predict from ",
                     typeof(data)))
        }
      }
      # sanity check: predictions must divide evenly into rows
      if (length(preds) %% num_row != 0) {
        stop("predict: prediction length ", length(preds)," is not multiple of nrows(data) ", num_row)
      }
      # one column per prediction per row (e.g. per class) when reshaping
      npred_per_case <- length(preds) / num_row
      if (reshape && npred_per_case > 1) {
        preds <- matrix(preds, ncol = npred_per_case)
      }
      return(preds)
    }
  ),
  private = list(
    handle = NULL,            # native lgb.Booster.handle
    need_free_handle = FALSE  # TRUE only when this object created the handle
  )
)
|
@ -0,0 +1,177 @@
|
|||
#' Main training logic for LightGBM
|
||||
#'
|
||||
#' Main training logic for LightGBM
|
||||
#'
|
||||
#' @param params List of parameters
|
||||
#' @param data a \code{lgb.Dataset} object, used for training
|
||||
#' @param nrounds number of training rounds
|
||||
#' @param valids a list of \code{lgb.Dataset} object, used for validation
|
||||
#' @param obj objective function, can be character or custom objective function
|
||||
#' @param eval evaluation function, can be (list of) character or custom eval function
|
||||
#' @param verbose verbosity for output
|
||||
#' if verbose > 0 , also will record iteration message to booster$record_evals
|
||||
#' @param eval_freq evaluation output frequency
|
||||
#' @param init_model path of model file of \code{lgb.Booster} object, will continue train from this model
|
||||
#' @param colnames feature names, if not null, will use this to overwrite the names in dataset
|
||||
#' @param categorical_feature list of str or int
|
||||
#' type int represents index,
|
||||
#' type str represents feature names
|
||||
#' @param early_stopping_rounds int
|
||||
#' Activates early stopping.
|
||||
#' Requires at least one validation data and one metric
|
||||
#' If there's more than one, will check all of them
|
||||
#' Returns the model with (best_iter + early_stopping_rounds)
|
||||
#' If early stopping occurs, the model will have 'best_iter' field
|
||||
#' @param callbacks list of callback functions
|
||||
#' List of callback functions that are applied at each iteration.
|
||||
#' @param ... other parameters, see parameters.md for more informations
|
||||
#' @return a trained booster model \code{lgb.Booster}.
|
||||
#' @examples
|
||||
#' library(lightgbm)
|
||||
#' data(agaricus.train, package='lightgbm')
|
||||
#' train <- agaricus.train
|
||||
#' dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
#' data(agaricus.test, package='lightgbm')
|
||||
#' test <- agaricus.test
|
||||
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
|
||||
#' params <- list(objective="regression", metric="l2")
|
||||
#' valids <- list(test=dtest)
|
||||
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
|
||||
#'
|
||||
#' @rdname lgb.train
|
||||
#' @export
|
||||
lgb.train <- function(params=list(), data, nrounds=10,
                      valids=list(),
                      obj=NULL, eval=NULL,
                      verbose=1, eval_freq=1L,
                      init_model=NULL,
                      colnames=NULL,
                      categorical_feature=NULL,
                      early_stopping_rounds=NULL,
                      callbacks=list(), ...) {
  # Merge dot-arguments into the parameter list and resolve the
  # objective / evaluation metric settings.
  extra_params <- list(...)
  params <- append(params, extra_params)
  params$verbose <- verbose
  params <- lgb.check.obj(params, obj)
  params <- lgb.check.eval(params, eval)

  # A custom (closure) objective runs on the R side: tell the core
  # library to boost from raw gradients ("NONE") and keep the closure.
  fobj <- NULL
  feval <- NULL
  if (typeof(params$objective) == "closure") {
    fobj <- params$objective
    params$objective <- "NONE"
  }
  if (typeof(eval) == "closure") {
    feval <- eval
  }
  lgb.check.params(params)

  # Continued training: build a predictor from a model file or an
  # existing booster, and start counting iterations after it.
  predictor <- NULL
  if (is.character(init_model)) {
    predictor <- Predictor$new(init_model)
  } else if (lgb.is.Booster(init_model)) {
    predictor <- init_model$to_predictor()
  }
  begin_iteration <- 1
  if (!is.null(predictor)) {
    begin_iteration <- predictor$current_iter() + 1
  }
  end_iteration <- begin_iteration + nrounds - 1

  # Validate the training data and the validation sets.
  if (!lgb.is.Dataset(data)) {
    stop("lgb.train: data only accepts lgb.Dataset object")
  }
  if (length(valids) > 0) {
    if (typeof(valids) != "list" ||
        !all(sapply(valids, lgb.is.Dataset)))
      stop("valids must be a list of lgb.Dataset elements")
    valid_names <- names(valids)
    if (is.null(valid_names) || any(valid_names == ""))
      stop("each element of the valids must have a name tag")
  }

  data$update_params(params)
  data$.__enclos_env__$private$set_predictor(predictor)
  if (!is.null(colnames)) {
    data$set_colnames(colnames)
  }
  data$set_categorical_feature(categorical_feature)

  # Separate genuine validation sets from the training set itself;
  # a validation entry identical to the training data is only tracked
  # by name and evaluated via eval_train().
  train_in_valids <- FALSE
  train_data_name <- "train"
  eval_only_sets <- list()
  for (valid_name in names(valids)) {
    valid_set <- valids[[valid_name]]
    if (identical(data, valid_set)) {
      train_in_valids <- TRUE
      train_data_name <- valid_name
    } else {
      valid_set$update_params(params)
      valid_set$set_reference(data)
      eval_only_sets[[valid_name]] <- valid_set
    }
  }

  # Assemble the callback list.
  if (eval_freq > 0) {
    callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
  }
  if (verbose > 0 && length(valids) > 0) {
    callbacks <- add.cb(callbacks, cb.record.evaluation())
  }
  if (!is.null(early_stopping_rounds) && early_stopping_rounds > 0) {
    callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds, verbose=verbose))
  }
  cb <- categorize.callbacks(callbacks)

  # Construct the booster and register the validation sets.
  booster <- Booster$new(params=params, train_set=data)
  if (train_in_valids) {
    booster$set_train_data_name(train_data_name)
  }
  for (valid_name in names(eval_only_sets)) {
    booster$add_valid(eval_only_sets[[valid_name]], valid_name)
  }

  # Environment shared with the callbacks.
  env <- CB_ENV$new()
  env$model <- booster
  env$begin_iteration <- begin_iteration
  env$end_iteration <- end_iteration

  # Boosting loop.
  for (i in begin_iteration:end_iteration) {
    env$iteration <- i
    env$eval_list <- list()
    for (f in cb$pre_iter) f(env)

    # Grow one more tree / run one boosting iteration.
    booster$update(fobj=fobj)

    # Collect evaluation results for this iteration.
    eval_results <- list()
    if (length(valids) > 0) {
      if (train_in_valids) {
        eval_results <- append(eval_results, booster$eval_train(feval=feval))
      }
      eval_results <- append(eval_results, booster$eval_valid(feval=feval))
    }
    env$eval_list <- eval_results

    for (f in cb$post_iter) f(env)

    # A callback (cb.early.stop) may request termination.
    if (env$met_early_stop) break
  }

  return(booster)
}
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
# Simple interface for training an lightgbm model.
|
||||
# Its documentation is combined with lgb.train.
|
||||
#
|
||||
#' @rdname lgb.train
|
||||
#' @export
|
||||
lightgbm <- function(data, label = NULL, weight = NULL,
                     params = list(), nrounds=10,
                     verbose = 1, eval_freq = 1L,
                     early_stopping_rounds = NULL,
                     save_name = "lightgbm.model",
                     init_model = NULL, callbacks = list(), ...) {
  # Wrap the raw input into an lgb.Dataset.
  dtrain <- lgb.Dataset(data, label=label, weight=weight)

  # When verbose, also evaluate on the training data itself.
  valids <- list()
  if (verbose > 0) {
    valids$train <- dtrain
  }

  booster <- lgb.train(params, dtrain, nrounds, valids,
                       verbose = verbose, eval_freq = eval_freq,
                       early_stopping_rounds = early_stopping_rounds,
                       init_model = init_model, callbacks = callbacks, ...)
  # Persist the fitted model before returning it.
  booster$save_model(save_name)
  return(booster)
}
|
||||
|
||||
#' Training part from Mushroom Data Set
|
||||
#'
|
||||
#' This data set is originally from the Mushroom data set,
|
||||
#' UCI Machine Learning Repository.
|
||||
#'
|
||||
#' This data set includes the following fields:
|
||||
#'
|
||||
#' \itemize{
|
||||
#' \item \code{label} the label for each record
|
||||
#' \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
|
||||
#' }
|
||||
#'
|
||||
#' @references
|
||||
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
|
||||
#'
|
||||
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
|
||||
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
|
||||
#' School of Information and Computer Science.
|
||||
#'
|
||||
#' @docType data
|
||||
#' @keywords datasets
|
||||
#' @name agaricus.train
|
||||
#' @usage data(agaricus.train)
|
||||
#' @format A list containing a label vector, and a dgCMatrix object with 6513
|
||||
#' rows and 127 variables
|
||||
NULL
|
||||
|
||||
#' Test part from Mushroom Data Set
|
||||
#'
|
||||
#' This data set is originally from the Mushroom data set,
|
||||
#' UCI Machine Learning Repository.
|
||||
#'
|
||||
#' This data set includes the following fields:
|
||||
#'
|
||||
#' \itemize{
|
||||
#' \item \code{label} the label for each record
|
||||
#' \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
|
||||
#' }
|
||||
#'
|
||||
#' @references
|
||||
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
|
||||
#'
|
||||
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
|
||||
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
|
||||
#' School of Information and Computer Science.
|
||||
#'
|
||||
#' @docType data
|
||||
#' @keywords datasets
|
||||
#' @name agaricus.test
|
||||
#' @usage data(agaricus.test)
|
||||
#' @format A list containing a label vector, and a dgCMatrix object with 1611
|
||||
#' rows and 126 variables
|
||||
NULL
|
||||
|
||||
# Various imports
|
||||
#' @importFrom R6 R6Class
|
||||
#' @useDynLib lightgbm
|
||||
NULL
|
|
@ -0,0 +1,157 @@
|
|||
lgb.new.handle <- function() {
  # Handles are stored as doubles so a 64-bit address fits even on
  # 32-bit builds of R; 0.0 is the "empty" handle value.
  0.0
}
|
||||
lgb.is.null.handle <- function(x) {
  # A handle is "null" when it is NULL or holds the zero address.
  if (is.null(x)) return(TRUE)
  if (x == 0) return(TRUE)
  FALSE
}
|
||||
|
||||
lgb.encode.char <- function(arr, len) {
  # Decode the first `len` bytes of a raw vector into a character string.
  if (typeof(arr) != "raw") {
    stop("lgb.encode.char: only can encode from raw type")
  }
  rawToChar(arr[1:len])
}
|
||||
|
||||
# Invoke a registered C routine of the lightgbm shared library.
#
# Args:
#   fun_name: name of the C entry point to call via .Call.
#   ret: pre-allocated object passed to the C side to be filled with the
#        result, or NULL when the routine produces no result.
#   ...: additional arguments forwarded to the C routine, in order,
#        ahead of `ret` and `call_state`.
#
# Returns `ret` (as filled by the C side) on success; stops with the
# library's last error message on failure.
# NOTE(review): this relies on the C side writing into its arguments
# in place (ret, call_state, act_len) — confirm the registered
# routines follow this calling convention.
lgb.call <- function(fun_name, ret, ...) {
  call_state <- as.integer(0)
  if (!is.null(ret)) {
    call_state <-
      .Call(fun_name, ..., ret, call_state , PACKAGE = "lightgbm")
  } else {
    call_state <- .Call(fun_name, ..., call_state , PACKAGE = "lightgbm")
  }
  if (call_state != as.integer(0)) {
    # Non-zero state signals an error: fetch the message from the library.
    buf_len <- as.integer(200)
    act_len <- as.integer(0)
    err_msg <- raw(buf_len)
    err_msg <-
      .Call("LGBM_GetLastError_R", buf_len, act_len, err_msg, PACKAGE = "lightgbm")
    if (act_len > buf_len) {
      # Message did not fit; retry with the reported required length.
      buf_len <- act_len
      err_msg <- raw(buf_len)
      err_msg <-
        .Call("LGBM_GetLastError_R",
              buf_len,
              act_len,
              err_msg,
              PACKAGE = "lightgbm")
    }
    stop(paste0("api error: ", lgb.encode.char(err_msg, act_len)))
  }
  return(ret)
}
|
||||
|
||||
|
||||
# Call a C routine that returns a string, using a grow-on-demand raw
# buffer, and decode the result into an R character string.
#
# Args:
#   fun_name: name of the C entry point (forwarded to lgb.call).
#   ...: arguments forwarded to the routine ahead of the buffer sizes.
#
# NOTE(review): assumes the C side updates `act_len` in place with the
# actual string length — confirm against the routine's implementation.
lgb.call.return.str <- function(fun_name, ...) {
  # Start with a 1 MiB buffer; most strings fit on the first try.
  buf_len <- as.integer(1024 * 1024)
  act_len <- as.integer(0)
  buf <- raw(buf_len)
  buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len)
  if (act_len > buf_len) {
    # First buffer too small; retry with the required length.
    buf_len <- act_len
    buf <- raw(buf_len)
    buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len)
  }
  return(lgb.encode.char(buf, act_len))
}
|
||||
|
||||
lgb.params2str <- function(params, ...) {
  # Serialize a parameter list (plus dot-arguments) into the
  # "key1=v1 key2=v2" string form the C library expects, returned as a
  # NUL-terminated raw vector via lgb.c_str.
  if (typeof(params) != "list")
    stop("params must be a list")
  # The C library uses underscores, so normalize dotted names.
  names(params) <- gsub("\\.", "_", names(params))
  dot_params <- list(...)
  names(dot_params) <- gsub("\\.", "_", names(dot_params))
  # Refuse ambiguous duplicates between the list and the call.
  if (length(intersect(names(params),
                       names(dot_params))) > 0)
    stop(
      "Same parameters in 'params' and in the call are not allowed. Please check your 'params' list."
    )
  params <- c(params, dot_params)
  pairs <- list()
  for (pname in names(params)) {
    # Multi-valued parameters are comma-joined first.
    joined_val <- paste0(params[[pname]], collapse = ",")
    if (nchar(joined_val) > 0) {
      pairs <- c(pairs, paste0(c(pname, joined_val), collapse = "="))
    }
  }
  if (length(pairs) == 0) {
    return(lgb.c_str(""))
  }
  return(lgb.c_str(paste0(pairs, collapse = " ")))
}
|
||||
|
||||
lgb.c_str <- function(x) {
  # Encode `x` as a NUL-terminated raw vector for C interop.
  c(charToRaw(as.character(x)), as.raw(0))
}
|
||||
|
||||
lgb.check.r6.class <- function(object, name) {
  # TRUE iff `object` is an R6 instance of class `name`.
  classes <- class(object)
  return(("R6" %in% classes) && (name %in% classes))
}
|
||||
|
||||
lgb.check.params <- function(params) {
  # TODO: validate parameter names/values; currently a pass-through.
  return(params)
}
|
||||
|
||||
lgb.check.obj <- function(params, obj) {
  # `obj`, when given, overrides any objective already in params.
  if (!is.null(obj)) {
    params$objective <- obj
  }
  if (is.character(params$objective)) {
    # Built-in objectives supported by the core library.
    builtin_objectives <- c("regression", "binary", "multiclass", "lambdarank")
    if (!(params$objective %in% builtin_objectives)) {
      stop("lgb.check.obj: objective name error should be (regression, binary, multiclass, lambdarank)")
    }
  } else if (typeof(params$objective) != "closure") {
    stop("lgb.check.obj: objective should be character or function")
  }
  return(params)
}
|
||||
|
||||
# Normalize the evaluation metric settings in `params`.
#
# Args:
#   params: parameter list; params$metric may be absent, a character
#           vector, or a list.
#   eval: user-supplied evaluation spec — character, list, closure, or NULL.
#
# Character/list `eval` values are appended to params$metric. When no
# metric is configured and `eval` is not a custom (closure) function,
# a default metric matching the built-in objective is chosen.
# Returns the updated params list.
lgb.check.eval <- function(params, eval) {
  # Guarantee params$metric exists so append() below always works.
  if (is.null(params$metric)) {
    params$metric <- list()
  }
  if (!is.null(eval)) {
    # Closures are handled by the caller as feval; only append names.
    if (is.character(eval) || is.list(eval)) {
      params$metric <- append(params$metric, eval)
    }
  }
  if (typeof(eval) != "closure") {
    # FIX: original used scalar `|` with a redundant is.null() test;
    # metric is always non-NULL here, so only the length check matters.
    if (length(params$metric) == 0) {
      # Choose a default metric for the built-in objectives; unknown
      # objectives keep the empty metric list unchanged.
      if (is.character(params$objective)) {
        params$metric <- switch(params$objective,
                                regression = "l2",
                                binary     = "binary_logloss",
                                multiclass = "multi_logloss",
                                lambdarank = "ndcg",
                                params$metric)
      }
    }
  }
  return(params)
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
LightGBM R Package
|
||||
==================
|
||||
|
||||
Installation
|
||||
------------
|
||||
```
|
||||
cd R-package
|
||||
R CMD INSTALL --build .
|
||||
```
|
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -0,0 +1,32 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lightgbm.R
|
||||
\docType{data}
|
||||
\name{agaricus.test}
|
||||
\alias{agaricus.test}
|
||||
\title{Test part from Mushroom Data Set}
|
||||
\format{A list containing a label vector, and a dgCMatrix object with 1611
|
||||
rows and 126 variables}
|
||||
\usage{
|
||||
data(agaricus.test)
|
||||
}
|
||||
\description{
|
||||
This data set is originally from the Mushroom data set,
|
||||
UCI Machine Learning Repository.
|
||||
}
|
||||
\details{
|
||||
This data set includes the following fields:
|
||||
|
||||
\itemize{
|
||||
\item \code{label} the label for each record
|
||||
\item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
|
||||
}
|
||||
}
|
||||
\references{
|
||||
https://archive.ics.uci.edu/ml/datasets/Mushroom
|
||||
|
||||
Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
|
||||
[http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
|
||||
School of Information and Computer Science.
|
||||
}
|
||||
\keyword{datasets}
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lightgbm.R
|
||||
\docType{data}
|
||||
\name{agaricus.train}
|
||||
\alias{agaricus.train}
|
||||
\title{Training part from Mushroom Data Set}
|
||||
\format{A list containing a label vector, and a dgCMatrix object with 6513
|
||||
rows and 127 variables}
|
||||
\usage{
|
||||
data(agaricus.train)
|
||||
}
|
||||
\description{
|
||||
This data set is originally from the Mushroom data set,
|
||||
UCI Machine Learning Repository.
|
||||
}
|
||||
\details{
|
||||
This data set includes the following fields:
|
||||
|
||||
\itemize{
|
||||
\item \code{label} the label for each record
|
||||
\item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
|
||||
}
|
||||
}
|
||||
\references{
|
||||
https://archive.ics.uci.edu/ml/datasets/Mushroom
|
||||
|
||||
Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
|
||||
[http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
|
||||
School of Information and Computer Science.
|
||||
}
|
||||
\keyword{datasets}
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Dataset.R
|
||||
\name{dim.lgb.Dataset}
|
||||
\alias{dim.lgb.Dataset}
|
||||
\title{Dimensions of lgb.Dataset}
|
||||
\usage{
|
||||
\method{dim}{lgb.Dataset}(x, ...)
|
||||
}
|
||||
\arguments{
|
||||
\item{x}{Object of class \code{lgb.Dataset}}
|
||||
|
||||
\item{...}{other parameters}
|
||||
}
|
||||
\value{
|
||||
a vector of numbers of rows and of columns
|
||||
}
|
||||
\description{
|
||||
Dimensions of lgb.Dataset
|
||||
}
|
||||
\details{
|
||||
Returns a vector of numbers of rows and of columns in an \code{lgb.Dataset}.
|
||||
|
||||
|
||||
Note: since \code{nrow} and \code{ncol} internally use \code{dim}, they can also
|
||||
be directly used with an \code{lgb.Dataset} object.
|
||||
}
|
||||
\examples{
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
|
||||
stopifnot(nrow(dtrain) == nrow(train$data))
|
||||
stopifnot(ncol(dtrain) == ncol(train$data))
|
||||
stopifnot(all(dim(dtrain) == dim(train$data)))
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Dataset.R
|
||||
\name{dimnames.lgb.Dataset}
|
||||
\alias{dimnames.lgb.Dataset}
|
||||
\alias{dimnames<-.lgb.Dataset}
|
||||
\title{Handling of column names of \code{lgb.Dataset}}
|
||||
\usage{
|
||||
\method{dimnames}{lgb.Dataset}(x)
|
||||
|
||||
\method{dimnames}{lgb.Dataset}(x) <- value
|
||||
}
|
||||
\arguments{
|
||||
\item{x}{object of class \code{lgb.Dataset}}
|
||||
|
||||
\item{value}{a list of two elements: the first one is ignored
|
||||
and the second one is column names}
|
||||
}
|
||||
\description{
|
||||
Handling of column names of \code{lgb.Dataset}
|
||||
}
|
||||
\details{
|
||||
Only column names are supported for \code{lgb.Dataset}, thus setting of
|
||||
row names would have no effect and returned row names would be NULL.
|
||||
|
||||
|
||||
Generic \code{dimnames} methods are used by \code{colnames}.
|
||||
Since row names are irrelevant, it is recommended to use \code{colnames} directly.
|
||||
}
|
||||
\examples{
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
lgb.Dataset.construct(dtrain)
|
||||
dimnames(dtrain)
|
||||
colnames(dtrain)
|
||||
colnames(dtrain) <- make.names(1:ncol(train$data))
|
||||
print(dtrain, verbose=TRUE)
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Dataset.R
|
||||
\name{getinfo}
|
||||
\alias{getinfo}
|
||||
\alias{getinfo.lgb.Dataset}
|
||||
\title{Get information of an lgb.Dataset object}
|
||||
\usage{
|
||||
getinfo(dataset, ...)
|
||||
|
||||
\method{getinfo}{lgb.Dataset}(dataset, name, ...)
|
||||
}
|
||||
\arguments{
|
||||
\item{dataset}{Object of class \code{lgb.Dataset}}
|
||||
|
||||
\item{...}{other parameters}
|
||||
|
||||
\item{name}{the name of the information field to get (see details)}
|
||||
}
|
||||
\value{
|
||||
info data
|
||||
}
|
||||
\description{
|
||||
Get information of an lgb.Dataset object
|
||||
}
|
||||
\details{
|
||||
The \code{name} field can be one of the following:
|
||||
|
||||
\itemize{
|
||||
\item \code{label}: label lightgbm learn from ;
|
||||
\item \code{weight}: to do a weight rescale ;
|
||||
\item \code{group}: group size
|
||||
\item \code{init_score}: initial score is the base prediction lightgbm will boost from ;
|
||||
}
|
||||
}
|
||||
\examples{
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
lgb.Dataset.construct(dtrain)
|
||||
labels <- getinfo(dtrain, 'label')
|
||||
setinfo(dtrain, 'label', 1-labels)
|
||||
|
||||
labels2 <- getinfo(dtrain, 'label')
|
||||
stopifnot(all(labels2 == 1-labels))
|
||||
}
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Dataset.R
|
||||
\name{lgb.Dataset}
|
||||
\alias{lgb.Dataset}
|
||||
\title{Construct lgb.Dataset object}
|
||||
\usage{
|
||||
lgb.Dataset(data, params = list(), reference = NULL, colnames = NULL,
|
||||
categorical_feature = NULL, free_raw_data = TRUE, info = list(), ...)
|
||||
}
|
||||
\arguments{
|
||||
\item{data}{a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename}
|
||||
|
||||
\item{params}{a list of parameters}
|
||||
|
||||
\item{reference}{reference dataset}
|
||||
|
||||
\item{colnames}{names of columns}
|
||||
|
||||
\item{categorical_feature}{categorical features}
|
||||
|
||||
\item{free_raw_data}{TRUE for need to free raw data after construct}
|
||||
|
||||
\item{info}{a list of information of the lgb.Dataset object}
|
||||
|
||||
\item{...}{other information to pass to \code{info} or parameters pass to \code{params}}
|
||||
}
|
||||
\value{
|
||||
constructed dataset
|
||||
}
|
||||
\description{
|
||||
Construct lgb.Dataset object
|
||||
}
|
||||
\details{
|
||||
Construct lgb.Dataset object from dense matrix, sparse matrix
|
||||
or local file (that was created previously by saving an \code{lgb.Dataset}).
|
||||
}
|
||||
\examples{
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
lgb.Dataset.save(dtrain, 'lgb.Dataset.data')
|
||||
dtrain <- lgb.Dataset('lgb.Dataset.data')
|
||||
lgb.Dataset.construct(dtrain)
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Dataset.R
|
||||
\name{lgb.Dataset.construct}
|
||||
\alias{lgb.Dataset.construct}
|
||||
\title{Construct Dataset explicit}
|
||||
\usage{
|
||||
lgb.Dataset.construct(dataset)
|
||||
}
|
||||
\arguments{
|
||||
\item{dataset}{Object of class \code{lgb.Dataset}}
|
||||
}
|
||||
\description{
|
||||
Construct Dataset explicit
|
||||
}
|
||||
\examples{
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
lgb.Dataset.construct(dtrain)
|
||||
}
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Dataset.R
|
||||
\name{lgb.Dataset.create.valid}
|
||||
\alias{lgb.Dataset.create.valid}
|
||||
\title{Construct a validation data}
|
||||
\usage{
|
||||
lgb.Dataset.create.valid(dataset, data, info = list(), ...)
|
||||
}
|
||||
\arguments{
|
||||
\item{dataset}{\code{lgb.Dataset} object, training data}
|
||||
|
||||
\item{data}{a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename}
|
||||
|
||||
\item{info}{a list of information of the lgb.Dataset object}
|
||||
|
||||
\item{...}{other information to pass to \code{info}.}
|
||||
}
|
||||
\value{
|
||||
constructed dataset
|
||||
}
|
||||
\description{
|
||||
Construct a validation data according to training data
|
||||
}
|
||||
\examples{
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
data(agaricus.test, package='lightgbm')
|
||||
test <- agaricus.test
|
||||
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
|
||||
}
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Dataset.R
|
||||
\name{lgb.Dataset.save}
|
||||
\alias{lgb.Dataset.save}
|
||||
\title{save \code{lgb.Dataset} to binary file}
|
||||
\usage{
|
||||
lgb.Dataset.save(dataset, fname)
|
||||
}
|
||||
\arguments{
|
||||
\item{dataset}{object of class \code{lgb.Dataset}}
|
||||
|
||||
\item{fname}{object filename of output file}
|
||||
}
|
||||
\value{
|
||||
passed dataset
|
||||
}
|
||||
\description{
|
||||
save \code{lgb.Dataset} to binary file
|
||||
}
|
||||
\examples{
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
lgb.Dataset.save(dtrain, "data.bin")
|
||||
}
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Dataset.R
|
||||
\name{lgb.Dataset.set.categorical}
|
||||
\alias{lgb.Dataset.set.categorical}
|
||||
\title{set categorical feature of \code{lgb.Dataset}}
|
||||
\usage{
|
||||
lgb.Dataset.set.categorical(dataset, categorical_feature)
|
||||
}
|
||||
\arguments{
|
||||
\item{dataset}{object of class \code{lgb.Dataset}}
|
||||
|
||||
\item{categorical_feature}{categorical features}
|
||||
}
|
||||
\value{
|
||||
passed dataset
|
||||
}
|
||||
\description{
|
||||
set categorical feature of \code{lgb.Dataset}
|
||||
}
|
||||
\examples{
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
lgb.Dataset.save(dtrain, 'lgb.Dataset.data')
|
||||
dtrain <- lgb.Dataset('lgb.Dataset.data')
|
||||
lgb.Dataset.set.categorical(dtrain, 1:2)
|
||||
}
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Dataset.R
|
||||
\name{lgb.Dataset.set.reference}
|
||||
\alias{lgb.Dataset.set.reference}
|
||||
\title{set reference of \code{lgb.Dataset}}
|
||||
\usage{
|
||||
lgb.Dataset.set.reference(dataset, reference)
|
||||
}
|
||||
\arguments{
|
||||
\item{dataset}{object of class \code{lgb.Dataset}}
|
||||
|
||||
\item{reference}{object of class \code{lgb.Dataset}}
|
||||
}
|
||||
\value{
|
||||
passed dataset
|
||||
}
|
||||
\description{
|
||||
set reference of \code{lgb.Dataset}.
|
||||
If you want to use validation data, you should set its reference to training data
|
||||
}
|
||||
\examples{
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
data(agaricus.test, package='lightgbm')
|
||||
test <- agaricus.test
|
||||
dtest <- lgb.Dataset(test$data, label=test$label)
|
||||
lgb.Dataset.set.reference(dtest, dtrain)
|
||||
}
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Booster.R
|
||||
\name{lgb.dump}
|
||||
\alias{lgb.dump}
|
||||
\title{Dump LightGBM model to json}
|
||||
\usage{
|
||||
lgb.dump(booster, num_iteration = NULL)
|
||||
}
|
||||
\arguments{
|
||||
\item{booster}{Object of class \code{lgb.Booster}}
|
||||
|
||||
\item{num_iteration}{number of iteration want to predict with, NULL or <= 0 means use best iteration}
|
||||
}
|
||||
\value{
|
||||
json format of model
|
||||
}
|
||||
\description{
|
||||
Dump LightGBM model to json
|
||||
}
|
||||
\examples{
|
||||
library(lightgbm)
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
data(agaricus.test, package='lightgbm')
|
||||
test <- agaricus.test
|
||||
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
|
||||
params <- list(objective="regression", metric="l2")
|
||||
valids <- list(test=dtest)
|
||||
model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
|
||||
json_model <- lgb.dump(model)
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Booster.R
|
||||
\name{lgb.get.eval.result}
|
||||
\alias{lgb.get.eval.result}
|
||||
\title{Get record evaluation result from booster}
|
||||
\usage{
|
||||
lgb.get.eval.result(booster, data_name, eval_name, iters = NULL,
|
||||
is_err = FALSE)
|
||||
}
|
||||
\arguments{
|
||||
\item{booster}{Object of class \code{lgb.Booster}}
|
||||
|
||||
\item{data_name}{name of dataset}
|
||||
|
||||
\item{eval_name}{name of evaluation}
|
||||
|
||||
\item{iters}{iterations, NULL will return all}
|
||||
|
||||
\item{is_err}{TRUE will return evaluation error instead}
|
||||
}
|
||||
\value{
|
||||
vector of evaluation result
|
||||
}
|
||||
\description{
|
||||
Get record evaluation result from booster
|
||||
}
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Booster.R
|
||||
\name{lgb.load}
|
||||
\alias{lgb.load}
|
||||
\title{Load LightGBM model}
|
||||
\usage{
|
||||
lgb.load(filename)
|
||||
}
|
||||
\arguments{
|
||||
\item{filename}{path of model file}
|
||||
}
|
||||
\value{
|
||||
booster
|
||||
}
|
||||
\description{
|
||||
Load LightGBM model from saved model file
|
||||
}
|
||||
\examples{
|
||||
library(lightgbm)
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
data(agaricus.test, package='lightgbm')
|
||||
test <- agaricus.test
|
||||
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
|
||||
params <- list(objective="regression", metric="l2")
|
||||
valids <- list(test=dtest)
|
||||
model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
|
||||
lgb.save(model, "model.txt")
|
||||
load_booster <- lgb.load("model.txt")
|
||||
}
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Booster.R
|
||||
\name{lgb.save}
|
||||
\alias{lgb.save}
|
||||
\title{Save LightGBM model}
|
||||
\usage{
|
||||
lgb.save(booster, filename, num_iteration = NULL)
|
||||
}
|
||||
\arguments{
|
||||
\item{booster}{Object of class \code{lgb.Booster}}
|
||||
|
||||
\item{filename}{saved filename}
|
||||
|
||||
\item{num_iteration}{number of iteration want to predict with, NULL or <= 0 means use best iteration}
|
||||
}
|
||||
\value{
|
||||
booster
|
||||
}
|
||||
\description{
|
||||
Save LightGBM model
|
||||
}
|
||||
\examples{
|
||||
library(lightgbm)
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
data(agaricus.test, package='lightgbm')
|
||||
test <- agaricus.test
|
||||
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
|
||||
params <- list(objective="regression", metric="l2")
|
||||
valids <- list(test=dtest)
|
||||
model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
|
||||
lgb.save(model, "model.txt")
|
||||
}
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.train.R, R/lightgbm.R
|
||||
\name{lgb.train}
|
||||
\alias{lgb.train}
|
||||
\alias{lightgbm}
|
||||
\title{Main training logic for LightGBM}
|
||||
\usage{
|
||||
lgb.train(params = list(), data, nrounds = 10, valids = list(),
|
||||
obj = NULL, eval = NULL, verbose = 1, eval_freq = 1L,
|
||||
init_model = NULL, colnames = NULL, categorical_feature = NULL,
|
||||
early_stopping_rounds = NULL, callbacks = list(), ...)
|
||||
|
||||
lightgbm(data, label = NULL, weight = NULL, params = list(),
|
||||
nrounds = 10, verbose = 1, eval_freq = 1L,
|
||||
early_stopping_rounds = NULL, save_name = "lightgbm.model",
|
||||
init_model = NULL, callbacks = list(), ...)
|
||||
}
|
||||
\arguments{
|
||||
\item{params}{List of parameters}
|
||||
|
||||
\item{data}{a \code{lgb.Dataset} object, used for training}
|
||||
|
||||
\item{nrounds}{number of training rounds}
|
||||
|
||||
\item{valids}{a list of \code{lgb.Dataset} object, used for validation}
|
||||
|
||||
\item{obj}{objective function, can be character or custom objective function}
|
||||
|
||||
\item{eval}{evaluation function, can be (list of) character or custom eval function}
|
||||
|
||||
\item{verbose}{verbosity for output
|
||||
if verbose > 0 , also will record iteration message to booster$record_evals}
|
||||
|
||||
\item{eval_freq}{evaluation output frequency}
|
||||
|
||||
\item{init_model}{path of model file of \code{lgb.Booster} object, will continue train from this model}
|
||||
|
||||
\item{colnames}{feature names, if not null, will use this to overwrite the names in dataset}
|
||||
|
||||
\item{categorical_feature}{list of str or int
|
||||
type int represents index,
|
||||
type str represents feature names}
|
||||
|
||||
\item{early_stopping_rounds}{int
|
||||
Activates early stopping.
|
||||
Requires at least one validation data and one metric
|
||||
If there's more than one, will check all of them
|
||||
Returns the model with (best_iter + early_stopping_rounds)
|
||||
If early stopping occurs, the model will have 'best_iter' field}
|
||||
|
||||
\item{callbacks}{list of callback functions
|
||||
List of callback functions that are applied at each iteration.}
|
||||
|
||||
\item{...}{other parameters, see parameters.md for more informations}
|
||||
}
|
||||
\value{
|
||||
a trained booster model \code{lgb.Booster}.
|
||||
}
|
||||
\description{
|
||||
Main training logic for LightGBM
|
||||
}
|
||||
\examples{
|
||||
library(lightgbm)
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
data(agaricus.test, package='lightgbm')
|
||||
test <- agaricus.test
|
||||
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
|
||||
params <- list(objective="regression", metric="l2")
|
||||
valids <- list(test=dtest)
|
||||
model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Booster.R
|
||||
\name{predict.lgb.Booster}
|
||||
\alias{predict.lgb.Booster}
|
||||
\title{Predict method for LightGBM model}
|
||||
\usage{
|
||||
\method{predict}{lgb.Booster}(object, data, num_iteration = NULL,
|
||||
rawscore = FALSE, predleaf = FALSE, header = FALSE, reshape = FALSE)
|
||||
}
|
||||
\arguments{
|
||||
\item{object}{Object of class \code{lgb.Booster}}
|
||||
|
||||
\item{data}{a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename}
|
||||
|
||||
\item{num_iteration}{number of iteration want to predict with, NULL or <= 0 means use best iteration}
|
||||
|
||||
\item{rawscore}{whether the prediction should be returned in the form of original untransformed
|
||||
sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE} for
|
||||
logistic regression would result in predictions for log-odds instead of probabilities.}
|
||||
|
||||
\item{predleaf}{whether predict leaf index instead.}
|
||||
|
||||
\item{header}{only used for prediction for text file. True if text file has header}
|
||||
|
||||
\item{reshape}{whether to reshape the vector of predictions to a matrix form when there are several
|
||||
prediction outputs per case.}
|
||||
}
|
||||
\value{
|
||||
For regression or binary classification, it returns a vector of length \code{nrows(data)}.
|
||||
For multiclass classification, either a \code{num_class * nrows(data)} vector or
|
||||
a \code{(nrows(data), num_class)} dimension matrix is returned, depending on
|
||||
the \code{reshape} value.
|
||||
|
||||
When \code{predleaf = TRUE}, the output is a matrix object with the
|
||||
number of columns corresponding to the number of trees.
|
||||
}
|
||||
\description{
|
||||
Predicted values based on class \code{lgb.Booster}
|
||||
}
|
||||
\examples{
|
||||
library(lightgbm)
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
data(agaricus.test, package='lightgbm')
|
||||
test <- agaricus.test
|
||||
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
|
||||
params <- list(objective="regression", metric="l2")
|
||||
valids <- list(test=dtest)
|
||||
model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
|
||||
preds <- predict(model, test$data)
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Dataset.R
|
||||
\name{setinfo}
|
||||
\alias{setinfo}
|
||||
\alias{setinfo.lgb.Dataset}
|
||||
\title{Set information of an lgb.Dataset object}
|
||||
\usage{
|
||||
setinfo(dataset, ...)
|
||||
|
||||
\method{setinfo}{lgb.Dataset}(dataset, name, info, ...)
|
||||
}
|
||||
\arguments{
|
||||
\item{dataset}{Object of class "lgb.Dataset"}
|
||||
|
||||
\item{...}{other parameters}
|
||||
|
||||
\item{name}{the name of the field to get}
|
||||
|
||||
\item{info}{the specific field of information to set}
|
||||
}
|
||||
\value{
|
||||
passed object
|
||||
}
|
||||
\description{
|
||||
Set information of an lgb.Dataset object
|
||||
}
|
||||
\details{
|
||||
The \code{name} field can be one of the following:
|
||||
|
||||
\itemize{
|
||||
\item \code{label}: label lightgbm learn from ;
|
||||
\item \code{weight}: to do a weight rescale ;
|
||||
\item \code{init_score}: initial score is the base prediction lightgbm will boost from ;
|
||||
\item \code{group}.
|
||||
}
|
||||
}
|
||||
\examples{
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
lgb.Dataset.construct(dtrain)
|
||||
labels <- getinfo(dtrain, 'label')
|
||||
setinfo(dtrain, 'label', 1-labels)
|
||||
labels2 <- getinfo(dtrain, 'label')
|
||||
stopifnot(all.equal(labels2, 1-labels))
|
||||
}
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
% Generated by roxygen2: do not edit by hand
|
||||
% Please edit documentation in R/lgb.Dataset.R
|
||||
\name{slice}
|
||||
\alias{slice}
|
||||
\alias{slice.lgb.Dataset}
|
||||
\title{Slice a dataset}
|
||||
\usage{
|
||||
slice(dataset, ...)
|
||||
|
||||
\method{slice}{lgb.Dataset}(dataset, idxset, ...)
|
||||
}
|
||||
\arguments{
|
||||
\item{dataset}{Object of class "lgb.Dataset"}
|
||||
|
||||
\item{...}{other parameters (currently not used)}
|
||||
|
||||
\item{idxset}{an integer vector of indices of rows needed}
|
||||
}
|
||||
\value{
|
||||
constructed sub dataset
|
||||
}
|
||||
\description{
|
||||
Get a new Dataset containing the specified rows of
|
||||
original lgb.Dataset object
|
||||
}
|
||||
\examples{
|
||||
data(agaricus.train, package='lightgbm')
|
||||
train <- agaricus.train
|
||||
dtrain <- lgb.Dataset(train$data, label=train$label)
|
||||
|
||||
dsub <- slice(dtrain, 1:42)
|
||||
labels1 <- getinfo(dsub, 'label')
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
# package root
PKGROOT=../../

ENABLE_STD_THREAD=1
CXX_STD = CXX11

LGBM_RFLAGS = -DUSE_SOCKET

PKG_CPPFLAGS = -I$(PKGROOT)/include $(LGBM_RFLAGS)
# all sources are C++, so use the C++ OpenMP flags (SHLIB_OPENMP_CXXFLAGS),
# not the C ones; see "Writing R Extensions", OpenMP support
PKG_CXXFLAGS = $(SHLIB_OPENMP_CXXFLAGS) $(SHLIB_PTHREAD_FLAGS) -std=c++11
PKG_LIBS = $(SHLIB_OPENMP_CXXFLAGS) $(SHLIB_PTHREAD_FLAGS)
OBJECTS = ./lightgbm-all.o ./lightgbm_R.o
|
|
@ -0,0 +1,12 @@
|
|||
# package root
PKGROOT=../../

ENABLE_STD_THREAD=1
CXX_STD = CXX11

LGBM_RFLAGS = -DUSE_SOCKET

PKG_CPPFLAGS = -I$(PKGROOT)/include $(LGBM_RFLAGS)
# all sources are C++, so use the C++ OpenMP flags (SHLIB_OPENMP_CXXFLAGS),
# not the C ones; see "Writing R Extensions", OpenMP support
PKG_CXXFLAGS = $(SHLIB_OPENMP_CXXFLAGS) $(SHLIB_PTHREAD_FLAGS) -std=c++11
# Windows socket libraries needed by the USE_SOCKET network code
PKG_LIBS = $(SHLIB_OPENMP_CXXFLAGS) $(SHLIB_PTHREAD_FLAGS) -lws2_32 -liphlpapi
OBJECTS = ./lightgbm-all.o ./lightgbm_R.o
|
|
@ -0,0 +1,151 @@
|
|||
/*
 * A simple wrapper for accessing data in R objects.
 * Due to license issues (GPLv2), we cannot include R's header files,
 * so this small mirror of R's internal object layout is used instead.
 * NOTE: if R ever changes its object definitions, this file must be
 * updated to match.
 */
#ifndef R_OBJECT_HELPER_H_
#define R_OBJECT_HELPER_H_

#include <cstdint>

// number of bits used for the SEXP type tag (mirrors R internals)
#define TYPE_BITS 5

// bit-field header shared by all R objects; layout must match R exactly
struct sxpinfo_struct {
  unsigned int type : TYPE_BITS;
  unsigned int obj : 1;
  unsigned int named : 2;
  unsigned int gp : 16;
  unsigned int mark : 1;
  unsigned int debug : 1;
  unsigned int trace : 1;
  unsigned int spare : 1;
  unsigned int gcgen : 1;
  unsigned int gccls : 3;
};

struct primsxp_struct {
  int offset;
};

struct symsxp_struct {
  struct SEXPREC *pname;
  struct SEXPREC *value;
  struct SEXPREC *internal;
};

struct listsxp_struct {
  struct SEXPREC *carval;
  struct SEXPREC *cdrval;
  struct SEXPREC *tagval;
};

struct envsxp_struct {
  struct SEXPREC *frame;
  struct SEXPREC *enclos;
  struct SEXPREC *hashtab;
};

struct closxp_struct {
  struct SEXPREC *formals;
  struct SEXPREC *body;
  struct SEXPREC *env;
};

struct promsxp_struct {
  struct SEXPREC *value;
  struct SEXPREC *expr;
  struct SEXPREC *env;
};

// mirror of R's generic object record (SEXPREC)
typedef struct SEXPREC {
  struct sxpinfo_struct sxpinfo;
  struct SEXPREC* attrib;
  struct SEXPREC* gengc_next_node, *gengc_prev_node;
  union {
    struct primsxp_struct primsxp;
    struct symsxp_struct symsxp;
    struct listsxp_struct listsxp;
    struct envsxp_struct envsxp;
    struct closxp_struct closxp;
    struct promsxp_struct promsxp;
  } u;
} SEXPREC, *SEXP;

// length header of an R vector object
struct vecsxp_struct {
  int length;
  int truelength;
};

// mirror of R's vector object header (VECTOR_SEXPREC)
typedef struct VECTOR_SEXPREC {
  struct sxpinfo_struct sxpinfo;
  struct SEXPREC* attrib;
  struct SEXPREC* gengc_next_node, *gengc_prev_node;
  struct vecsxp_struct vecsxp;
} VECTOR_SEXPREC, *VECSEXP;

typedef union { VECTOR_SEXPREC s; double align; } SEXPREC_ALIGN;

// the vector payload starts immediately after the aligned header
#define DATAPTR(x) (((SEXPREC_ALIGN *) (x)) + 1)

#define R_CHAR_PTR(x) ((char *) DATAPTR(x))

#define R_INT_PTR(x) ((int *) DATAPTR(x))

#define R_REAL_PTR(x) ((double *) DATAPTR(x))

#define R_AS_INT(x) (*((int *) DATAPTR(x)))

// NILSXP (R NULL) has type tag 0
#define R_IS_NULL(x) ((*(SEXP)(x)).sxpinfo.type == 0)


// 64bit pointer
#if INTPTR_MAX == INT64_MAX

#define R_ADDR(x) ((int64_t *) DATAPTR(x))

// store a raw pointer into the payload of an R vector
// (a null pointer is stored as 0, so no special-case branch is needed)
inline void R_SET_PTR(SEXP x, void* ptr) {
  R_ADDR(x)[0] = (int64_t)(ptr);
}

// read back a pointer stored by R_SET_PTR; returns nullptr for an
// R NULL object or a stored null pointer
inline void* R_GET_PTR(SEXP x) {
  if (R_IS_NULL(x)) {
    return nullptr;
  }
  return (void *)(R_ADDR(x)[0]);
}

#else

#define R_ADDR(x) ((int32_t *) DATAPTR(x))

// store a raw pointer into the payload of an R vector
// (a null pointer is stored as 0, so no special-case branch is needed)
inline void R_SET_PTR(SEXP x, void* ptr) {
  R_ADDR(x)[0] = (int32_t)(ptr);
}

// read back a pointer stored by R_SET_PTR; returns nullptr for an
// R NULL object or a stored null pointer
inline void* R_GET_PTR(SEXP x) {
  if (R_IS_NULL(x)) {
    return nullptr;
  }
  return (void *)(R_ADDR(x)[0]);
}

#endif

#endif  // R_OBJECT_HELPER_H_
|
|
@ -0,0 +1,37 @@
|
|||
// application
|
||||
#include "../../src/application/application.cpp"
|
||||
|
||||
// boosting
|
||||
#include "../../src/boosting/boosting.cpp"
|
||||
#include "../../src/boosting/gbdt.cpp"
|
||||
|
||||
// io
|
||||
#include "../../src/io/bin.cpp"
|
||||
#include "../../src/io/config.cpp"
|
||||
#include "../../src/io/dataset.cpp"
|
||||
#include "../../src/io/dataset_loader.cpp"
|
||||
#include "../../src/io/metadata.cpp"
|
||||
#include "../../src/io/parser.cpp"
|
||||
#include "../../src/io/tree.cpp"
|
||||
|
||||
// metric
|
||||
#include "../../src/metric/dcg_calculator.cpp"
|
||||
#include "../../src/metric/metric.cpp"
|
||||
|
||||
// network
|
||||
#include "../../src/network/linker_topo.cpp"
|
||||
#include "../../src/network/linkers_socket.cpp"
|
||||
#include "../../src/network/network.cpp"
|
||||
|
||||
// objective
|
||||
#include "../../src/objective/objective_function.cpp"
|
||||
|
||||
// treelearner
|
||||
#include "../../src/treelearner/data_parallel_tree_learner.cpp"
|
||||
#include "../../src/treelearner/feature_parallel_tree_learner.cpp"
|
||||
#include "../../src/treelearner/serial_tree_learner.cpp"
|
||||
#include "../../src/treelearner/tree_learner.cpp"
|
||||
#include "../../src/treelearner/voting_parallel_tree_learner.cpp"
|
||||
|
||||
// c_api
|
||||
#include "../../src/c_api.cpp"
|
|
@ -0,0 +1,591 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <cstring>
|
||||
#include <cstdio>
|
||||
#include <sstream>
|
||||
#include <omp.h>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#include <LightGBM/utils/text_reader.h>
|
||||
#include <LightGBM/utils/common.h>
|
||||
|
||||
#include "./lightgbm_R.h"
|
||||
|
||||
#define COL_MAJOR (0)
|
||||
|
||||
#define R_API_BEGIN() \
|
||||
try {
|
||||
|
||||
#define R_API_END() } \
|
||||
catch(std::exception& ex) { R_INT_PTR(call_state)[0] = -1; LGBM_SetLastError(ex.what()); return call_state;} \
|
||||
catch(std::string& ex) { R_INT_PTR(call_state)[0] = -1; LGBM_SetLastError(ex.c_str()); return call_state; } \
|
||||
catch(...) { R_INT_PTR(call_state)[0] = -1; LGBM_SetLastError("unknown exception"); return call_state;} \
|
||||
return call_state;
|
||||
|
||||
#define CHECK_CALL(x) \
|
||||
if ((x) != 0) { \
|
||||
R_INT_PTR(call_state)[0] = -1; \
|
||||
return call_state; \
|
||||
}
|
||||
|
||||
using namespace LightGBM;
|
||||
|
||||
// Copy the C string `src` into the R character buffer `dest`.
// The source length is always written to `actual_len`; if the
// caller-provided buffer (`buf_len`) is too small, nothing is
// copied so the caller can retry with a larger buffer.
SEXP EncodeChar(SEXP dest, const char* src, SEXP buf_len, SEXP actual_len) {
  const int src_len = static_cast<int>(std::strlen(src));
  R_INT_PTR(actual_len)[0] = src_len;
  if (R_AS_INT(buf_len) >= src_len) {
    std::memcpy(R_CHAR_PTR(dest), src, src_len);
  }
  return dest;
}
||||
|
||||
// Expose the last LightGBM error message to R via the EncodeChar protocol.
SEXP LGBM_GetLastError_R(SEXP buf_len, SEXP actual_len, SEXP err_msg) {
  const char* last_error = LGBM_GetLastError();
  return EncodeChar(err_msg, last_error, buf_len, actual_len);
}
|
||||
|
||||
SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
|
||||
SEXP parameters,
|
||||
SEXP reference,
|
||||
SEXP out,
|
||||
SEXP call_state) {
|
||||
|
||||
R_API_BEGIN();
|
||||
DatasetHandle handle;
|
||||
CHECK_CALL(LGBM_DatasetCreateFromFile(R_CHAR_PTR(filename), R_CHAR_PTR(parameters),
|
||||
R_GET_PTR(reference), &handle));
|
||||
R_SET_PTR(out, handle);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
|
||||
SEXP indices,
|
||||
SEXP data,
|
||||
SEXP num_indptr,
|
||||
SEXP nelem,
|
||||
SEXP num_row,
|
||||
SEXP parameters,
|
||||
SEXP reference,
|
||||
SEXP out,
|
||||
SEXP call_state) {
|
||||
R_API_BEGIN();
|
||||
const int* p_indptr = R_INT_PTR(indptr);
|
||||
const int* p_indices = R_INT_PTR(indices);
|
||||
const double* p_data = R_REAL_PTR(data);
|
||||
|
||||
int64_t nindptr = static_cast<int64_t>(R_AS_INT(num_indptr));
|
||||
int64_t ndata = static_cast<int64_t>(R_AS_INT(nelem));
|
||||
int64_t nrow = static_cast<int64_t>(R_AS_INT(num_row));
|
||||
DatasetHandle handle;
|
||||
CHECK_CALL(LGBM_DatasetCreateFromCSC(p_indptr, C_API_DTYPE_INT32, p_indices,
|
||||
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
|
||||
nrow, R_CHAR_PTR(parameters), R_GET_PTR(reference), &handle));
|
||||
R_SET_PTR(out, handle);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
|
||||
SEXP num_row,
|
||||
SEXP num_col,
|
||||
SEXP parameters,
|
||||
SEXP reference,
|
||||
SEXP out,
|
||||
SEXP call_state) {
|
||||
|
||||
R_API_BEGIN();
|
||||
int32_t nrow = static_cast<int32_t>(R_AS_INT(num_row));
|
||||
int32_t ncol = static_cast<int32_t>(R_AS_INT(num_col));
|
||||
double* p_mat = R_REAL_PTR(data);
|
||||
DatasetHandle handle;
|
||||
CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
|
||||
R_CHAR_PTR(parameters), R_GET_PTR(reference), &handle));
|
||||
R_SET_PTR(out, handle);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetGetSubset_R(SEXP handle,
|
||||
SEXP used_row_indices,
|
||||
SEXP len_used_row_indices,
|
||||
SEXP parameters,
|
||||
SEXP out,
|
||||
SEXP call_state) {
|
||||
|
||||
R_API_BEGIN();
|
||||
int len = R_AS_INT(len_used_row_indices);
|
||||
std::vector<int> idxvec(len);
|
||||
// convert from one-based to zero-based index
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (int i = 0; i < len; ++i) {
|
||||
idxvec[i] = R_INT_PTR(used_row_indices)[i] - 1;
|
||||
}
|
||||
DatasetHandle res;
|
||||
CHECK_CALL(LGBM_DatasetGetSubset(R_GET_PTR(handle),
|
||||
idxvec.data(), len, R_CHAR_PTR(parameters),
|
||||
&res));
|
||||
R_SET_PTR(out, res);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
// Set feature names on a Dataset; `feature_names` is a single
// tab-separated string assembled on the R side.
SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
  SEXP feature_names,
  SEXP call_state) {
  R_API_BEGIN();
  const auto names = Common::Split(R_CHAR_PTR(feature_names), "\t");
  const int num_names = static_cast<int>(names.size());
  std::vector<const char*> name_ptrs;
  name_ptrs.reserve(num_names);
  for (const auto& name : names) {
    name_ptrs.push_back(name.c_str());
  }
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_GET_PTR(handle),
    name_ptrs.data(), num_names));
  R_API_END();
}
|
||||
|
||||
SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle,
|
||||
SEXP buf_len,
|
||||
SEXP actual_len,
|
||||
SEXP feature_names,
|
||||
SEXP call_state) {
|
||||
|
||||
R_API_BEGIN();
|
||||
int len = 0;
|
||||
CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &len));
|
||||
std::vector<std::vector<char>> names(len);
|
||||
std::vector<char*> ptr_names(len);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
names[i].resize(256);
|
||||
ptr_names[i] = names[i].data();
|
||||
}
|
||||
int out_len;
|
||||
CHECK_CALL(LGBM_DatasetGetFeatureNames(R_GET_PTR(handle),
|
||||
ptr_names.data(), &out_len));
|
||||
CHECK(len == out_len);
|
||||
auto merge_str = Common::Join<char*>(ptr_names, "\t");
|
||||
EncodeChar(feature_names, merge_str.c_str(), buf_len, actual_len);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
|
||||
SEXP filename,
|
||||
SEXP call_state) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_DatasetSaveBinary(R_GET_PTR(handle),
|
||||
R_CHAR_PTR(filename)));
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
// Free the Dataset behind `handle` and clear the stored pointer,
// so repeated calls (e.g. from R finalizers) are safe no-ops.
SEXP LGBM_DatasetFree_R(SEXP handle,
  SEXP call_state) {
  R_API_BEGIN();
  void* ptr = R_GET_PTR(handle);
  if (ptr != nullptr) {
    CHECK_CALL(LGBM_DatasetFree(ptr));
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}
|
||||
|
||||
SEXP LGBM_DatasetSetField_R(SEXP handle,
|
||||
SEXP field_name,
|
||||
SEXP field_data,
|
||||
SEXP num_element,
|
||||
SEXP call_state) {
|
||||
R_API_BEGIN();
|
||||
int len = static_cast<int>(R_AS_INT(num_element));
|
||||
const char* name = R_CHAR_PTR(field_name);
|
||||
if (!strcmp("group", name) || !strcmp("query", name)) {
|
||||
std::vector<int32_t> vec(len);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (int i = 0; i < len; ++i) {
|
||||
vec[i] = static_cast<int32_t>(R_INT_PTR(field_data)[i]);
|
||||
}
|
||||
CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, vec.data(), len, C_API_DTYPE_INT32));
|
||||
} else {
|
||||
std::vector<float> vec(len);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (int i = 0; i < len; ++i) {
|
||||
vec[i] = static_cast<float>(R_REAL_PTR(field_data)[i]);
|
||||
}
|
||||
CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
|
||||
}
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetGetField_R(SEXP handle,
|
||||
SEXP field_name,
|
||||
SEXP field_data,
|
||||
SEXP call_state) {
|
||||
|
||||
R_API_BEGIN();
|
||||
const char* name = R_CHAR_PTR(field_name);
|
||||
int out_len = 0;
|
||||
int out_type = 0;
|
||||
const void* res;
|
||||
CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &out_len, &res, &out_type));
|
||||
|
||||
if (!strcmp("group", name) || !strcmp("query", name)) {
|
||||
auto p_data = reinterpret_cast<const int32_t*>(res);
|
||||
// convert from boundaries to size
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (int i = 0; i < out_len - 1; ++i) {
|
||||
R_INT_PTR(field_data)[i] = p_data[i + 1] - p_data[i];
|
||||
}
|
||||
} else {
|
||||
auto p_data = reinterpret_cast<const float*>(res);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (int i = 0; i < out_len; ++i) {
|
||||
R_REAL_PTR(field_data)[i] = p_data[i];
|
||||
}
|
||||
}
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
|
||||
SEXP field_name,
|
||||
SEXP out,
|
||||
SEXP call_state) {
|
||||
|
||||
R_API_BEGIN();
|
||||
const char* name = R_CHAR_PTR(field_name);
|
||||
int out_len = 0;
|
||||
int out_type = 0;
|
||||
const void* res;
|
||||
CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &out_len, &res, &out_type));
|
||||
if (!strcmp("group", name) || !strcmp("query", name)) {
|
||||
out_len -= 1;
|
||||
}
|
||||
R_INT_PTR(out)[0] = static_cast<int>(out_len);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out,
|
||||
SEXP call_state) {
|
||||
int nrow;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_DatasetGetNumData(R_GET_PTR(handle), &nrow));
|
||||
R_INT_PTR(out)[0] = static_cast<int>(nrow);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
|
||||
SEXP out,
|
||||
SEXP call_state) {
|
||||
int nfeature;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &nfeature));
|
||||
R_INT_PTR(out)[0] = static_cast<int>(nfeature);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
// --- start Booster interfaces
|
||||
|
||||
// Free the Booster behind `handle` and clear the stored pointer,
// so repeated calls (e.g. from R finalizers) are safe no-ops.
SEXP LGBM_BoosterFree_R(SEXP handle,
  SEXP call_state) {
  R_API_BEGIN();
  void* ptr = R_GET_PTR(handle);
  if (ptr != nullptr) {
    CHECK_CALL(LGBM_BoosterFree(ptr));
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}
|
||||
|
||||
SEXP LGBM_BoosterCreate_R(SEXP train_data,
|
||||
SEXP parameters,
|
||||
SEXP out,
|
||||
SEXP call_state) {
|
||||
R_API_BEGIN();
|
||||
BoosterHandle handle;
|
||||
CHECK_CALL(LGBM_BoosterCreate(R_GET_PTR(train_data), R_CHAR_PTR(parameters), &handle));
|
||||
R_SET_PTR(out, handle);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename,
|
||||
SEXP out,
|
||||
SEXP call_state) {
|
||||
|
||||
R_API_BEGIN();
|
||||
int out_num_iterations = 0;
|
||||
BoosterHandle handle;
|
||||
CHECK_CALL(LGBM_BoosterCreateFromModelfile(R_CHAR_PTR(filename), &out_num_iterations, &handle));
|
||||
R_SET_PTR(out, handle);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterMerge_R(SEXP handle,
|
||||
SEXP other_handle,
|
||||
SEXP call_state) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterMerge(R_GET_PTR(handle), R_GET_PTR(other_handle)));
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterAddValidData_R(SEXP handle,
|
||||
SEXP valid_data,
|
||||
SEXP call_state) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterAddValidData(R_GET_PTR(handle), R_GET_PTR(valid_data)));
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
|
||||
SEXP train_data,
|
||||
SEXP call_state) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterResetTrainingData(R_GET_PTR(handle), R_GET_PTR(train_data)));
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterResetParameter_R(SEXP handle,
|
||||
SEXP parameters,
|
||||
SEXP call_state) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterResetParameter(R_GET_PTR(handle), R_CHAR_PTR(parameters)));
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
|
||||
SEXP out,
|
||||
SEXP call_state) {
|
||||
int num_class;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterGetNumClasses(R_GET_PTR(handle), &num_class));
|
||||
R_INT_PTR(out)[0] = static_cast<int>(num_class);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle,
|
||||
SEXP call_state) {
|
||||
int is_finished = 0;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterUpdateOneIter(R_GET_PTR(handle), &is_finished));
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
|
||||
SEXP grad,
|
||||
SEXP hess,
|
||||
SEXP len,
|
||||
SEXP call_state) {
|
||||
int is_finished = 0;
|
||||
R_API_BEGIN();
|
||||
int int_len = R_AS_INT(len);
|
||||
std::vector<float> tgrad(int_len), thess(int_len);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (int j = 0; j < int_len; ++j) {
|
||||
tgrad[j] = static_cast<float>(R_REAL_PTR(grad)[j]);
|
||||
thess[j] = static_cast<float>(R_REAL_PTR(hess)[j]);
|
||||
}
|
||||
CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_GET_PTR(handle), tgrad.data(), thess.data(), &is_finished));
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle,
|
||||
SEXP call_state) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterRollbackOneIter(R_GET_PTR(handle)));
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
|
||||
SEXP out,
|
||||
SEXP call_state) {
|
||||
|
||||
int out_iteration;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_GET_PTR(handle), &out_iteration));
|
||||
R_INT_PTR(out)[0] = static_cast<int>(out_iteration);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterGetEvalNames_R(SEXP handle,
|
||||
SEXP buf_len,
|
||||
SEXP actual_len,
|
||||
SEXP eval_names,
|
||||
SEXP call_state) {
|
||||
|
||||
R_API_BEGIN();
|
||||
int len;
|
||||
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
|
||||
std::vector<std::vector<char>> names(len);
|
||||
std::vector<char*> ptr_names(len);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
names[i].resize(128);
|
||||
ptr_names[i] = names[i].data();
|
||||
}
|
||||
int out_len;
|
||||
CHECK_CALL(LGBM_BoosterGetEvalNames(R_GET_PTR(handle), &out_len, ptr_names.data()));
|
||||
CHECK(out_len == len);
|
||||
auto merge_names = Common::Join<char*>(ptr_names, "\t");
|
||||
EncodeChar(eval_names, merge_names.c_str(), buf_len, actual_len);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterGetEval_R(SEXP handle,
|
||||
SEXP data_idx,
|
||||
SEXP out_result,
|
||||
SEXP call_state) {
|
||||
R_API_BEGIN();
|
||||
int len;
|
||||
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
|
||||
double* ptr_ret = R_REAL_PTR(out_result);
|
||||
int out_len;
|
||||
CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), R_AS_INT(data_idx), &out_len, ptr_ret));
|
||||
CHECK(out_len == len);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
|
||||
SEXP data_idx,
|
||||
SEXP out,
|
||||
SEXP call_state) {
|
||||
R_API_BEGIN();
|
||||
int64_t len;
|
||||
CHECK_CALL(LGBM_BoosterGetNumPredict(R_GET_PTR(handle), R_AS_INT(data_idx), &len));
|
||||
R_INT_PTR(out)[0] = static_cast<int>(len);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterGetPredict_R(SEXP handle,
|
||||
SEXP data_idx,
|
||||
SEXP out_result,
|
||||
SEXP call_state) {
|
||||
R_API_BEGIN();
|
||||
double* ptr_ret = R_REAL_PTR(out_result);
|
||||
int64_t out_len;
|
||||
CHECK_CALL(LGBM_BoosterGetPredict(R_GET_PTR(handle), R_AS_INT(data_idx), &out_len, ptr_ret));
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
// Map the R-side raw-score / leaf-index flags to a C API prediction type.
// Leaf-index takes precedence over raw-score (same as the flag order in
// the original assignment sequence).
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx) {
  if (R_AS_INT(is_leafidx)) {
    return C_API_PREDICT_LEAF_INDEX;
  }
  if (R_AS_INT(is_rawscore)) {
    return C_API_PREDICT_RAW_SCORE;
  }
  return C_API_PREDICT_NORMAL;
}
|
||||
|
||||
SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
|
||||
SEXP data_filename,
|
||||
SEXP data_has_header,
|
||||
SEXP is_rawscore,
|
||||
SEXP is_leafidx,
|
||||
SEXP num_iteration,
|
||||
SEXP result_filename,
|
||||
SEXP call_state) {
|
||||
R_API_BEGIN();
|
||||
int pred_type = GetPredictType(is_rawscore, is_leafidx);
|
||||
CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), R_CHAR_PTR(data_filename),
|
||||
R_AS_INT(data_has_header), pred_type, R_AS_INT(num_iteration),
|
||||
R_CHAR_PTR(result_filename)));
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
|
||||
SEXP num_row,
|
||||
SEXP is_rawscore,
|
||||
SEXP is_leafidx,
|
||||
SEXP num_iteration,
|
||||
SEXP out_len,
|
||||
SEXP call_state) {
|
||||
R_API_BEGIN();
|
||||
int pred_type = GetPredictType(is_rawscore, is_leafidx);
|
||||
int64_t len = 0;
|
||||
CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), R_AS_INT(num_row),
|
||||
pred_type, R_AS_INT(num_iteration), &len));
|
||||
R_INT_PTR(out_len)[0] = static_cast<int>(len);
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
|
||||
SEXP indptr,
|
||||
SEXP indices,
|
||||
SEXP data,
|
||||
SEXP num_indptr,
|
||||
SEXP nelem,
|
||||
SEXP num_row,
|
||||
SEXP is_rawscore,
|
||||
SEXP is_leafidx,
|
||||
SEXP num_iteration,
|
||||
SEXP out_result,
|
||||
SEXP call_state) {
|
||||
|
||||
R_API_BEGIN();
|
||||
int pred_type = GetPredictType(is_rawscore, is_leafidx);
|
||||
|
||||
const int* p_indptr = R_INT_PTR(indptr);
|
||||
const int* p_indices = R_INT_PTR(indices);
|
||||
const double* p_data = R_REAL_PTR(data);
|
||||
|
||||
int64_t nindptr = R_AS_INT(num_indptr);
|
||||
int64_t ndata = R_AS_INT(nelem);
|
||||
int64_t nrow = R_AS_INT(num_row);
|
||||
double* ptr_ret = R_REAL_PTR(out_result);
|
||||
int64_t out_len;
|
||||
CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle),
|
||||
p_indptr, C_API_DTYPE_INT32, p_indices,
|
||||
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
|
||||
nrow, pred_type, R_AS_INT(num_iteration), &out_len, ptr_ret));
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
// Predict on a dense R matrix. R matrices are column-major, hence COL_MAJOR.
// out_result must be pre-allocated by the caller.
SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
  SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP num_iteration,
  SEXP out_result,
  SEXP call_state) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx);
  const int32_t n_rows = R_AS_INT(num_row);
  const int32_t n_cols = R_AS_INT(num_col);
  // Views into R-owned memory; nothing is copied on this side.
  double* mat_values = R_REAL_PTR(data);
  double* result_ptr = R_REAL_PTR(out_result);
  int64_t result_len;
  CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle),
    mat_values, C_API_DTYPE_FLOAT64, n_rows, n_cols, COL_MAJOR,
    predict_type, R_AS_INT(num_iteration), &result_len, result_ptr));
  R_API_END();
}
|
||||
|
||||
// Save the booster's model to a text file; num_iteration <= 0 saves all iterations.
SEXP LGBM_BoosterSaveModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP filename,
  SEXP call_state) {
  R_API_BEGIN();
  const char* out_path = R_CHAR_PTR(filename);
  CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), R_AS_INT(num_iteration), out_path));
  R_API_END();
}
|
||||
|
||||
// Dump the model as a JSON string into out_str.
// If the caller's buffer_len is too small, nothing is encoded; instead the
// required length is reported through actual_len so the caller can retry
// with a larger buffer.
SEXP LGBM_BoosterDumpModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP buffer_len,
  SEXP actual_len,
  SEXP out_str,
  SEXP call_state) {
  R_API_BEGIN();
  int out_len = 0;
  // Temporary buffer the C API writes the JSON dump into.
  std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
  CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
  // Encode only when the whole dump fit in the buffer. (The previous code
  // also called EncodeChar unconditionally before this check, which copied
  // a truncated dump into out_str when the buffer was too small.)
  if (out_len < R_AS_INT(buffer_len)) {
    EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
  } else {
    R_INT_PTR(actual_len)[0] = static_cast<int>(out_len);
  }
  R_API_END();
}
|
|
@ -0,0 +1,486 @@
|
|||
#ifndef LIGHTGBM_R_H_
#define LIGHTGBM_R_H_

#include <LightGBM/utils/log.h>
#include <cstdint>
#include <LightGBM/c_api.h>

#include "R_object_helper.h"

/*!
* \brief get string message of the last error
*  Note: all functions in this file set an error flag through their trailing
*  call_state argument: 0 when they succeed and -1 when an error occurred.
* \param buf_len size of the pre-allocated message buffer
* \param actual_len actual length of the message written
* \param err_msg error information
* \return error information
*/
DllExport SEXP LGBM_GetLastError_R(SEXP buf_len, SEXP actual_len, SEXP err_msg);

// --- start Dataset interface

/*!
* \brief load data set from file like the command-line LightGBM does
* \param filename the name of the file
* \param parameters additional parameters
* \param reference used to align bin mapper with other dataset, nullptr means not used
* \param out created dataset
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
  SEXP parameters,
  SEXP reference,
  SEXP out,
  SEXP call_state);

/*!
* \brief create a dataset from CSC format
* \param indptr pointer to column headers
* \param indices feature indices
* \param data feature values
* \param nindptr number of cols in the matrix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_row number of rows
* \param parameters additional parameters
* \param reference used to align bin mapper with other dataset, nullptr means not used
* \param out created dataset
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP nindptr,
  SEXP nelem,
  SEXP num_row,
  SEXP parameters,
  SEXP reference,
  SEXP out,
  SEXP call_state);


/*!
* \brief create dataset from dense matrix
* \param data matrix data (column-major, as stored by R)
* \param nrow number of rows
* \param ncol number of columns
* \param parameters additional parameters
* \param reference used to align bin mapper with other dataset, nullptr means not used
* \param out created dataset
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
  SEXP nrow,
  SEXP ncol,
  SEXP parameters,
  SEXP reference,
  SEXP out,
  SEXP call_state);

/*!
* \brief Create subset of a data
* \param handle handle of full dataset
* \param used_row_indices indices used in subset
* \param len_used_row_indices length of indices used in subset
* \param parameters additional parameters
* \param out created dataset
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_DatasetGetSubset_R(SEXP handle,
  SEXP used_row_indices,
  SEXP len_used_row_indices,
  SEXP parameters,
  SEXP out,
  SEXP call_state);

/*!
* \brief save feature names to Dataset
* \param handle handle
* \param feature_names feature names
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
  SEXP feature_names,
  SEXP call_state);

/*!
* \brief get feature names from Dataset
* \param handle handle
* \param buf_len size of the pre-allocated name buffer
* \param actual_len actual length of the names written
* \param feature_names feature names
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle,
  SEXP buf_len,
  SEXP actual_len,
  SEXP feature_names,
  SEXP call_state);

/*!
* \brief save dataset to binary file
* \param handle an instance of dataset
* \param filename file name
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
  SEXP filename,
  SEXP call_state);

/*!
* \brief free dataset
* \param handle an instance of dataset
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_DatasetFree_R(SEXP handle,
  SEXP call_state);

/*!
* \brief set vector to a content in info
*  Note: group and group_id only work for C_API_DTYPE_INT32;
*  label and weight only work for C_API_DTYPE_FLOAT32
* \param handle an instance of dataset
* \param field_name field name, can be label, weight, group, group_id
* \param field_data pointer to vector
* \param num_element number of elements in field_data
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_DatasetSetField_R(SEXP handle,
  SEXP field_name,
  SEXP field_data,
  SEXP num_element,
  SEXP call_state);

/*!
* \brief get size of info vector from dataset
* \param handle an instance of dataset
* \param field_name field name
* \param out size of info vector from dataset
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
  SEXP field_name,
  SEXP out,
  SEXP call_state);

/*!
* \brief get info vector from dataset
* \param handle an instance of dataset
* \param field_name field name
* \param field_data pointer to vector
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_DatasetGetField_R(SEXP handle,
  SEXP field_name,
  SEXP field_data,
  SEXP call_state);

/*!
* \brief get number of data.
* \param handle the handle to the dataset
* \param out the address to hold number of data
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_DatasetGetNumData_R(SEXP handle,
  SEXP out,
  SEXP call_state);

/*!
* \brief get number of features
* \param handle the handle to the dataset
* \param out the output of number of features
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
  SEXP out,
  SEXP call_state);

// --- start Booster interfaces

/*!
* \brief create a new boosting learner
* \param train_data training data set
* \param parameters format: 'key1=value1 key2=value2'
* \param out handle of created Booster
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterCreate_R(SEXP train_data,
  SEXP parameters,
  SEXP out,
  SEXP call_state);

/*!
* \brief free obj in handle
* \param handle handle to be freed
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterFree_R(SEXP handle,
  SEXP call_state);

/*!
* \brief load an existing boosting from model file
* \param filename filename of model
* \param out handle of created Booster
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename,
  SEXP out,
  SEXP call_state);

/*!
* \brief Merge model in two boosters to first handle
* \param handle handle that the other handle will be merged into
* \param other_handle booster to merge from
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterMerge_R(SEXP handle,
  SEXP other_handle,
  SEXP call_state);

/*!
* \brief Add new validation data to booster
* \param handle handle
* \param valid_data validation data set
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterAddValidData_R(SEXP handle,
  SEXP valid_data,
  SEXP call_state);

/*!
* \brief Reset training data for booster
* \param handle handle
* \param train_data training data set
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
  SEXP train_data,
  SEXP call_state);

/*!
* \brief Reset config for current booster
* \param handle handle
* \param parameters format: 'key1=value1 key2=value2'
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterResetParameter_R(SEXP handle,
  SEXP parameters,
  SEXP call_state);

/*!
* \brief Get number of classes
* \param handle handle
* \param out number of classes
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
  SEXP out,
  SEXP call_state);

/*!
* \brief update the model in one round
* \param handle handle
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle,
  SEXP call_state);

/*!
* \brief update the model by directly specifying gradient and second-order gradient;
*  this can be used to support customized loss functions
* \param handle handle
* \param grad gradient statistics
* \param hess second order gradient statistics
* \param len length of grad/hess
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
  SEXP grad,
  SEXP hess,
  SEXP len,
  SEXP call_state);

/*!
* \brief Rollback one iteration
* \param handle handle
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle,
  SEXP call_state);

/*!
* \brief Get iteration of current boosting rounds
* \param out iteration of boosting rounds
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
  SEXP out,
  SEXP call_state);

/*!
* \brief Get names of eval metrics
* \param buf_len size of the pre-allocated name buffer
* \param actual_len actual length of the names written
* \param eval_names eval names
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterGetEvalNames_R(SEXP handle,
  SEXP buf_len,
  SEXP actual_len,
  SEXP eval_names,
  SEXP call_state);

/*!
* \brief get evaluation for training data and validation data
* \param handle handle
* \param data_idx 0:training data, 1: 1st valid data, 2:2nd valid data ...
* \param out_result float array that holds the result
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterGetEval_R(SEXP handle,
  SEXP data_idx,
  SEXP out_result,
  SEXP call_state);

/*!
* \brief Get number of predictions for training data and validation data
* \param handle handle
* \param data_idx 0:training data, 1: 1st valid data, 2:2nd valid data ...
* \param out size of predict
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
  SEXP data_idx,
  SEXP out,
  SEXP call_state);

/*!
* \brief Get prediction for training data and validation data;
*  this can be used to support customized eval functions
* \param handle handle
* \param data_idx 0:training data, 1: 1st valid data, 2:2nd valid data ...
* \param out_result used to store predict result, should pre-allocate memory
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterGetPredict_R(SEXP handle,
  SEXP data_idx,
  SEXP out_result,
  SEXP call_state);

/*!
* \brief make prediction for file
* \param handle handle
* \param data_filename filename of data file
* \param data_has_header data file has header or not
* \param is_rawscore whether to predict raw (untransformed) scores
* \param is_leafidx whether to predict leaf indices
* \param num_iteration number of iterations for prediction, <= 0 means no limit
* \param result_filename file the prediction result is written to
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
  SEXP data_filename,
  SEXP data_has_header,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP num_iteration,
  SEXP result_filename,
  SEXP call_state);

/*!
* \brief Get number of predictions
* \param handle handle
* \param num_row number of rows to predict for
* \param is_rawscore whether to predict raw (untransformed) scores
* \param is_leafidx whether to predict leaf indices
* \param num_iteration number of iterations for prediction, <= 0 means no limit
* \param out_len length of prediction
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP num_iteration,
  SEXP out_len,
  SEXP call_state);

/*!
* \brief make prediction for a new data set
*  Note: should pre-allocate memory for out_result;
*  for normal and raw score, its length is equal to num_class * num_data;
*  for leaf index, its length is equal to num_class * num_data * num_iteration
* \param handle handle
* \param indptr pointer to column headers
* \param indices feature indices
* \param data feature values
* \param nindptr number of cols in the matrix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_row number of rows
* \param is_rawscore whether to predict raw (untransformed) scores
* \param is_leafidx whether to predict leaf indices
* \param num_iteration number of iterations for prediction, <= 0 means no limit
* \param out_result prediction result
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP nindptr,
  SEXP nelem,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP num_iteration,
  SEXP out_result,
  SEXP call_state);

/*!
* \brief make prediction for a new data set
*  Note: should pre-allocate memory for out_result;
*  for normal and raw score, its length is equal to num_class * num_data;
*  for leaf index, its length is equal to num_class * num_data * num_iteration
* \param handle handle
* \param data pointer to the data space
* \param nrow number of rows
* \param ncol number of columns
* \param is_rawscore whether to predict raw (untransformed) scores
* \param is_leafidx whether to predict leaf indices
* \param num_iteration number of iterations for prediction, <= 0 means no limit
* \param out_result prediction result
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
  SEXP data,
  SEXP nrow,
  SEXP ncol,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP num_iteration,
  SEXP out_result,
  SEXP call_state);

/*!
* \brief save model into file
* \param handle handle
* \param num_iteration <= 0 means save all
* \param filename file name
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterSaveModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP filename,
  SEXP call_state);

/*!
* \brief dump model to json
* \param handle handle
* \param num_iteration <= 0 means dump all
* \param buffer_len size of the pre-allocated output buffer
* \param actual_len actual length of the dump written
* \param out_str json format string of model
* \return 0 when succeed, -1 when failure happens
*/
DllExport SEXP LGBM_BoosterDumpModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP buffer_len,
  SEXP actual_len,
  SEXP out_str,
  SEXP call_state);

#endif // LIGHTGBM_R_H_
|
|
@ -0,0 +1,4 @@
|
|||
# Entry point used by R CMD check: run the full testthat suite
# found under tests/testthat for the lightgbm package.
library(testthat)
library(lightgbm)

test_check("lightgbm")
|
|
@ -0,0 +1,77 @@
|
|||
# End-to-end tests for the high-level lightgbm() training interface,
# predict(), evaluation-record retrieval, and training continuation.
require(lightgbm)

context("basic functions")

# Mushroom (agaricus) classification data shipped with the package.
data(agaricus.train, package='lightgbm')
data(agaricus.test, package='lightgbm')
train <- agaricus.train
test <- agaricus.test

# Platform flag for any Windows-specific expectations.
windows_flag = grepl('Windows', Sys.info()[['sysname']])

test_that("train and predict binary classification", {
  nrounds = 10
  bst <- lightgbm(data = train$data, label = train$label, num_leaves = 5,
                  nrounds = nrounds, objective = "binary", metric="binary_error")
  expect_false(is.null(bst$record_evals))
  record_results <- lgb.get.eval.result(bst, "train", "binary_error")
  expect_lt(min(record_results), 0.02)

  # One probability per test row (the agaricus test set has 1611 rows).
  pred <- predict(bst, test$data)
  expect_length(pred, 1611)

  # Predict using only the first iteration (6513 training rows).
  pred1 <- predict(bst, train$data, num_iteration = 1)
  expect_length(pred1, 6513)
  err_pred1 <- sum((pred1 > 0.5) != train$label)/length(train$label)
  err_log <- record_results[1]
  # The recomputed error must agree with the logged first-iteration error.
  expect_lt(abs(err_pred1 - err_log), 10e-6)
})


test_that("train and predict softmax", {
  # 0-based class labels, as expected by the multiclass objective.
  lb <- as.numeric(iris$Species) - 1

  bst <- lightgbm(data = as.matrix(iris[, -5]), label = lb,
                  num_leaves = 4, learning_rate = 0.1, nrounds = 20, min_data=20, min_hess=20,
                  objective = "multiclass", metric="multi_error", num_class=3)

  expect_false(is.null(bst$record_evals))
  record_results <- lgb.get.eval.result(bst, "train", "multi_error")
  expect_lt(min(record_results), 0.03)

  # One probability per class per row.
  pred <- predict(bst, as.matrix(iris[, -5]))
  expect_length(pred, nrow(iris) * 3)

})


test_that("use of multiple eval metrics works", {
  bst <- lightgbm(data = train$data, label = train$label, num_leaves = 4,
                  learning_rate=1, nrounds = 10, objective = "binary",
                  metric = list("binary_error","auc","binary_logloss") )
  expect_false(is.null(bst$record_evals))
})


test_that("training continuation works", {
  dtrain <- lgb.Dataset(train$data, label = train$label, free_raw_data=FALSE)
  watchlist = list(train=dtrain)
  param <- list(objective = "binary", metric="binary_logloss", num_leaves = 5, learning_rate = 1)

  # for the reference, use 10 iterations at once:
  bst <- lgb.train(param, dtrain, nrounds = 10, watchlist)
  err_bst <- lgb.get.eval.result(bst, "train", "binary_logloss", 10)
  # first 5 iterations:
  bst1 <- lgb.train(param, dtrain, nrounds = 5, watchlist)
  # test continuing from a model in file
  lgb.save(bst1, "lightgbm.model")
  # continue for 5 more, seeding from the in-memory booster:
  bst2 <- lgb.train(param, dtrain, nrounds = 5, watchlist, init_model = bst1)
  err_bst2 <- lgb.get.eval.result(bst2, "train", "binary_logloss", 10)
  expect_lt(abs(err_bst - err_bst2), 0.01)

  # same continuation, but loading the intermediate model from disk:
  bst2 <- lgb.train(param, dtrain, nrounds = 5, watchlist, init_model = "lightgbm.model")
  err_bst2 <- lgb.get.eval.result(bst2, "train", "binary_logloss", 10)
  expect_lt(abs(err_bst - err_bst2), 0.01)
})
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
context('Test models with custom objective')

require(lightgbm)

data(agaricus.train, package='lightgbm')
data(agaricus.test, package='lightgbm')
dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label)
dtest <- lgb.Dataset(agaricus.test$data, label = agaricus.test$label)
watchlist <- list(eval = dtest, train = dtrain)

# Custom objective: logistic regression. Returns the gradient and hessian
# of the log-loss with respect to the raw predictions.
logregobj <- function(preds, dtrain) {
  labels <- getinfo(dtrain, "label")
  preds <- 1 / (1 + exp(-preds))
  grad <- preds - labels
  hess <- preds * (1 - preds)
  return(list(grad = grad, hess = hess))
}

# Custom evaluation: classification error on raw scores (threshold at 0).
# Lower is better, hence higher_better = FALSE.
evalerror <- function(preds, dtrain) {
  labels <- getinfo(dtrain, "label")
  err <- as.numeric(sum(labels != (preds > 0))) / length(labels)
  return(list(name = "error", value = err, higher_better=FALSE))
}

param <- list(num_leaves=8, learning_rate=1,
              objective=logregobj, metric="auc")
num_round <- 10

test_that("custom objective works", {
  bst <- lgb.train(param, dtrain, num_round, watchlist, eval=evalerror)
  expect_false(is.null(bst$record_evals))
})
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
# Tests for the lgb.Dataset class: construction from sparse/dense input,
# binary save/load round-trips, info getters/setters, slicing, and colnames.
require(lightgbm)
require(Matrix)

context("testing lgb.Dataset functionality")

data(agaricus.test, package='lightgbm')
test_data <- agaricus.test$data[1:100,]
test_label <- agaricus.test$label[1:100]

test_that("lgb.Dataset: basic construction, saving, loading", {
  # from sparse matrix
  dtest1 <- lgb.Dataset(test_data, label=test_label)
  # from dense matrix
  dtest2 <- lgb.Dataset(as.matrix(test_data), label=test_label)
  expect_equal(getinfo(dtest1, 'label'), getinfo(dtest2, 'label'))

  # save to a local file
  tmp_file <- tempfile('lgb.Dataset_')
  lgb.Dataset.save(dtest1, tmp_file)
  # read from a local file
  dtest3 <- lgb.Dataset(tmp_file)
  lgb.Dataset.construct(dtest3)
  unlink(tmp_file)
  expect_equal(getinfo(dtest1, 'label'), getinfo(dtest3, 'label'))
})

test_that("lgb.Dataset: getinfo & setinfo", {
  dtest <- lgb.Dataset(test_data)
  setinfo(dtest, 'label', test_label)
  labels <- getinfo(dtest, 'label')
  expect_equal(test_label, getinfo(dtest, 'label'))

  # fields that were never set come back empty
  expect_true(length(getinfo(dtest, 'weight')) == 0)
  expect_true(length(getinfo(dtest, 'init_score')) == 0)

  # any other label should error
  expect_error(setinfo(dtest, 'asdf', test_label))
})

test_that("lgb.Dataset: slice, dim", {
  dtest <- lgb.Dataset(test_data, label=test_label)
  lgb.Dataset.construct(dtest)
  expect_equal(dim(dtest), dim(test_data))
  # slicing keeps all columns and only the requested rows
  dsub1 <- slice(dtest, 1:42)
  lgb.Dataset.construct(dsub1)
  expect_equal(nrow(dsub1), 42)
  expect_equal(ncol(dsub1), ncol(test_data))
})

test_that("lgb.Dataset: colnames", {
  dtest <- lgb.Dataset(test_data, label=test_label)
  expect_equal(colnames(dtest), colnames(test_data))
  lgb.Dataset.construct(dtest)
  expect_equal(colnames(dtest), colnames(test_data))
  # replacement must supply one name per column
  expect_error( colnames(dtest) <- 'asdf')
  new_names <- make.names(1:ncol(test_data))
  expect_silent(colnames(dtest) <- new_names)
  expect_equal(colnames(dtest), new_names)
})

test_that("lgb.Dataset: nrow is correct for a very sparse matrix", {
  nr <- 1000
  x <- rsparsematrix(nr, 100, density=0.0005)
  # we want it very sparse, so that last rows are empty
  expect_lt(max(x@i), nr)
  dtest <- lgb.Dataset(x)
  expect_equal(dim(dtest), dim(x))
})
|
|
@ -559,7 +559,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
|
|||
// get feature names
|
||||
line = Common::FindFromLines(lines, "feature_names=");
|
||||
if (line.size() > 0) {
|
||||
feature_names_ = Common::Split(Common::Split(line.c_str(), '=')[1].c_str(), ' ');
|
||||
feature_names_ = Common::Split(line.substr(std::strlen("feature_names=")).c_str(), " ");
|
||||
if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
|
||||
Log::Fatal("Wrong size of feature_names");
|
||||
return;
|
||||
|
|
|
@ -8,6 +8,10 @@
|
|||
namespace LightGBM {
|
||||
|
||||
// Default-construct an empty metadata container: no weights, init scores,
// data rows, or queries until Init() populates it from a data file.
Metadata::Metadata() {
  num_weights_ = 0;
  num_init_score_ = 0;
  num_data_ = 0;
  num_queries_ = 0;
}
|
||||
|
||||
void Metadata::Init(const char * data_filename) {
|
||||
|
|
|
@ -116,7 +116,7 @@ public:
|
|||
}
|
||||
|
||||
// Metric identifier used in parameters and evaluation output.
// The "binary_" prefix disambiguates it from log-loss metrics of other
// objectives (the flattened diff left both the old "logloss" return and
// this one in place, making the second return unreachable).
inline static const char* Name() {
  return "binary_logloss";
}
|
||||
};
|
||||
/*!
|
||||
|
@ -135,7 +135,7 @@ public:
|
|||
}
|
||||
|
||||
// Metric identifier used in parameters and evaluation output.
// The "binary_" prefix disambiguates it from error metrics of other
// objectives (the flattened diff left both the old "error" return and
// this one in place, making the second return unreachable).
inline static const char* Name() {
  return "binary_error";
}
|
||||
};
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@ class TestEngine(unittest.TestCase):
|
|||
}
|
||||
evals_result, ret = test_template(params, X_y, log_loss)
|
||||
self.assertLess(ret, 0.15)
|
||||
self.assertAlmostEqual(min(evals_result['eval']['logloss']), ret, places=5)
|
||||
self.assertAlmostEqual(min(evals_result['eval']['binary_logloss']), ret, places=5)
|
||||
|
||||
def test_regreesion(self):
|
||||
evals_result, ret = test_template()
|
||||
|
|
Загрузка…
Ссылка в новой задаче