зеркало из https://github.com/microsoft/LightGBM.git
[R-package] Add sparse feature contribution predictions (#5108)
* add predcontrib for sparse inputs * register newly-added function * comments * correct wrong types in test * forcibly take transpose function from Matrix * keep row names, test comparison to dense inputs * workaround for passing test while PR for row names is not merged * Update R-package/R/lgb.Predictor.R Co-authored-by: James Lamb <jaylamb20@gmail.com> * Update R-package/R/lgb.Predictor.R Co-authored-by: James Lamb <jaylamb20@gmail.com> * Update R-package/R/lgb.Predictor.R Co-authored-by: James Lamb <jaylamb20@gmail.com> * proper handling of integer overflow * add test for CSR contrib row names * add more tests for predict(<sparse>, predcontrib=TRUE) * make linter happy * linter * linter * check error messages for bad input shapes * fix regex * hard-coded number of columns in regex for tests Co-authored-by: James Lamb <jaylamb20@gmail.com>
This commit is contained in:
Родитель
688f73d14a
Коммит
6f92d47aad
|
@ -37,6 +37,10 @@ export(saveRDS.lgb.Booster)
|
||||||
export(set_field)
|
export(set_field)
|
||||||
export(slice)
|
export(slice)
|
||||||
import(methods)
|
import(methods)
|
||||||
|
importClassesFrom(Matrix,dgCMatrix)
|
||||||
|
importClassesFrom(Matrix,dgRMatrix)
|
||||||
|
importClassesFrom(Matrix,dsparseMatrix)
|
||||||
|
importClassesFrom(Matrix,dsparseVector)
|
||||||
importFrom(Matrix,Matrix)
|
importFrom(Matrix,Matrix)
|
||||||
importFrom(R6,R6Class)
|
importFrom(R6,R6Class)
|
||||||
importFrom(data.table,":=")
|
importFrom(data.table,":=")
|
||||||
|
@ -51,6 +55,7 @@ importFrom(graphics,barplot)
|
||||||
importFrom(graphics,par)
|
importFrom(graphics,par)
|
||||||
importFrom(jsonlite,fromJSON)
|
importFrom(jsonlite,fromJSON)
|
||||||
importFrom(methods,is)
|
importFrom(methods,is)
|
||||||
|
importFrom(methods,new)
|
||||||
importFrom(parallel,detectCores)
|
importFrom(parallel,detectCores)
|
||||||
importFrom(stats,quantile)
|
importFrom(stats,quantile)
|
||||||
importFrom(utils,modifyList)
|
importFrom(utils,modifyList)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
#' @importFrom methods is
|
#' @importFrom methods is new
|
||||||
|
#' @importClassesFrom Matrix dsparseMatrix dsparseVector dgCMatrix dgRMatrix
|
||||||
#' @importFrom R6 R6Class
|
#' @importFrom R6 R6Class
|
||||||
#' @importFrom utils read.delim
|
#' @importFrom utils read.delim
|
||||||
Predictor <- R6::R6Class(
|
Predictor <- R6::R6Class(
|
||||||
|
@ -126,6 +127,113 @@ Predictor <- R6::R6Class(
|
||||||
num_row <- nrow(preds)
|
num_row <- nrow(preds)
|
||||||
preds <- as.vector(t(preds))
|
preds <- as.vector(t(preds))
|
||||||
|
|
||||||
|
} else if (predcontrib && inherits(data, c("dsparseMatrix", "dsparseVector"))) {
|
||||||
|
|
||||||
|
ncols <- .Call(LGBM_BoosterGetNumFeature_R, private$handle)
|
||||||
|
ncols_out <- integer(1L)
|
||||||
|
.Call(LGBM_BoosterGetNumClasses_R, private$handle, ncols_out)
|
||||||
|
ncols_out <- (ncols + 1L) * max(ncols_out, 1L)
|
||||||
|
if (is.na(ncols_out)) {
|
||||||
|
ncols_out <- as.numeric(ncols + 1L) * as.numeric(max(ncols_out, 1L))
|
||||||
|
}
|
||||||
|
if (!inherits(data, "dsparseVector") && ncols_out > .Machine$integer.max) {
|
||||||
|
stop("Resulting matrix of feature contributions is too large for R to handle.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inherits(data, "dsparseVector")) {
|
||||||
|
|
||||||
|
if (length(data) > ncols) {
|
||||||
|
stop(sprintf("Model was fitted to data with %d columns, input data has %.0f columns."
|
||||||
|
, ncols
|
||||||
|
, length(data)))
|
||||||
|
}
|
||||||
|
res <- .Call(
|
||||||
|
LGBM_BoosterPredictSparseOutput_R
|
||||||
|
, private$handle
|
||||||
|
, c(0L, as.integer(length(data@x)))
|
||||||
|
, data@i - 1L
|
||||||
|
, data@x
|
||||||
|
, TRUE
|
||||||
|
, 1L
|
||||||
|
, ncols
|
||||||
|
, start_iteration
|
||||||
|
, num_iteration
|
||||||
|
, private$params
|
||||||
|
)
|
||||||
|
out <- methods::new("dsparseVector")
|
||||||
|
out@i <- res$indices + 1L
|
||||||
|
out@x <- res$data
|
||||||
|
out@length <- ncols_out
|
||||||
|
return(out)
|
||||||
|
|
||||||
|
} else if (inherits(data, "dgRMatrix")) {
|
||||||
|
|
||||||
|
if (ncol(data) > ncols) {
|
||||||
|
stop(sprintf("Model was fitted to data with %d columns, input data has %.0f columns."
|
||||||
|
, ncols
|
||||||
|
, ncol(data)))
|
||||||
|
}
|
||||||
|
res <- .Call(
|
||||||
|
LGBM_BoosterPredictSparseOutput_R
|
||||||
|
, private$handle
|
||||||
|
, data@p
|
||||||
|
, data@j
|
||||||
|
, data@x
|
||||||
|
, TRUE
|
||||||
|
, nrow(data)
|
||||||
|
, ncols
|
||||||
|
, start_iteration
|
||||||
|
, num_iteration
|
||||||
|
, private$params
|
||||||
|
)
|
||||||
|
out <- methods::new("dgRMatrix")
|
||||||
|
out@p <- res$indptr
|
||||||
|
out@j <- res$indices
|
||||||
|
out@x <- res$data
|
||||||
|
out@Dim <- as.integer(c(nrow(data), ncols_out))
|
||||||
|
|
||||||
|
} else if (inherits(data, "dgCMatrix")) {
|
||||||
|
|
||||||
|
if (ncol(data) != ncols) {
|
||||||
|
stop(sprintf("Model was fitted to data with %d columns, input data has %.0f columns."
|
||||||
|
, ncols
|
||||||
|
, ncol(data)))
|
||||||
|
}
|
||||||
|
res <- .Call(
|
||||||
|
LGBM_BoosterPredictSparseOutput_R
|
||||||
|
, private$handle
|
||||||
|
, data@p
|
||||||
|
, data@i
|
||||||
|
, data@x
|
||||||
|
, FALSE
|
||||||
|
, nrow(data)
|
||||||
|
, ncols
|
||||||
|
, start_iteration
|
||||||
|
, num_iteration
|
||||||
|
, private$params
|
||||||
|
)
|
||||||
|
out <- methods::new("dgCMatrix")
|
||||||
|
out@p <- res$indptr
|
||||||
|
out@i <- res$indices
|
||||||
|
out@x <- res$data
|
||||||
|
out@Dim <- as.integer(c(nrow(data), length(res$indptr) - 1L))
|
||||||
|
|
||||||
|
} else {
|
||||||
|
|
||||||
|
stop(sprintf("Predictions on sparse inputs are only allowed for '%s', '%s', '%s' - got: %s"
|
||||||
|
, "dsparseVector"
|
||||||
|
, "dgRMatrix"
|
||||||
|
, "dgCMatrix"
|
||||||
|
, paste(class(data)
|
||||||
|
, collapse = ", ")))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if (NROW(row.names(data))) {
|
||||||
|
out@Dimnames[[1L]] <- row.names(data)
|
||||||
|
}
|
||||||
|
return(out)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
# Not a file, we need to predict from R object
|
# Not a file, we need to predict from R object
|
||||||
|
|
|
@ -65,6 +65,14 @@ SEXP wrapped_R_raw(void *len) {
|
||||||
return Rf_allocVector(RAWSXP, *(reinterpret_cast<R_xlen_t*>(len)));
|
return Rf_allocVector(RAWSXP, *(reinterpret_cast<R_xlen_t*>(len)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SEXP wrapped_R_int(void *len) {
|
||||||
|
return Rf_allocVector(INTSXP, *(reinterpret_cast<R_xlen_t*>(len)));
|
||||||
|
}
|
||||||
|
|
||||||
|
SEXP wrapped_R_real(void *len) {
|
||||||
|
return Rf_allocVector(REALSXP, *(reinterpret_cast<R_xlen_t*>(len)));
|
||||||
|
}
|
||||||
|
|
||||||
SEXP wrapped_Rf_mkChar(void *txt) {
|
SEXP wrapped_Rf_mkChar(void *txt) {
|
||||||
return Rf_mkChar(reinterpret_cast<char*>(txt));
|
return Rf_mkChar(reinterpret_cast<char*>(txt));
|
||||||
}
|
}
|
||||||
|
@ -84,6 +92,14 @@ SEXP safe_R_raw(R_xlen_t len, SEXP *cont_token) {
|
||||||
return R_UnwindProtect(wrapped_R_raw, reinterpret_cast<void*>(&len), throw_R_memerr, cont_token, *cont_token);
|
return R_UnwindProtect(wrapped_R_raw, reinterpret_cast<void*>(&len), throw_R_memerr, cont_token, *cont_token);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SEXP safe_R_int(R_xlen_t len, SEXP *cont_token) {
|
||||||
|
return R_UnwindProtect(wrapped_R_int, reinterpret_cast<void*>(&len), throw_R_memerr, cont_token, *cont_token);
|
||||||
|
}
|
||||||
|
|
||||||
|
SEXP safe_R_real(R_xlen_t len, SEXP *cont_token) {
|
||||||
|
return R_UnwindProtect(wrapped_R_real, reinterpret_cast<void*>(&len), throw_R_memerr, cont_token, *cont_token);
|
||||||
|
}
|
||||||
|
|
||||||
SEXP safe_R_mkChar(char *txt, SEXP *cont_token) {
|
SEXP safe_R_mkChar(char *txt, SEXP *cont_token) {
|
||||||
return R_UnwindProtect(wrapped_Rf_mkChar, reinterpret_cast<void*>(txt), throw_R_memerr, cont_token, *cont_token);
|
return R_UnwindProtect(wrapped_Rf_mkChar, reinterpret_cast<void*>(txt), throw_R_memerr, cont_token, *cont_token);
|
||||||
}
|
}
|
||||||
|
@ -851,6 +867,76 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
|
||||||
R_API_END();
|
R_API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct SparseOutputPointers {
|
||||||
|
void* indptr;
|
||||||
|
int32_t* indices;
|
||||||
|
void* data;
|
||||||
|
int indptr_type;
|
||||||
|
int data_type;
|
||||||
|
SparseOutputPointers(void* indptr, int32_t* indices, void* data)
|
||||||
|
: indptr(indptr), indices(indices), data(data) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
void delete_SparseOutputPointers(SparseOutputPointers *ptr) {
|
||||||
|
LGBM_BoosterFreePredictSparse(ptr->indptr, ptr->indices, ptr->data, C_API_DTYPE_INT32, C_API_DTYPE_FLOAT64);
|
||||||
|
delete ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
SEXP LGBM_BoosterPredictSparseOutput_R(SEXP handle,
|
||||||
|
SEXP indptr,
|
||||||
|
SEXP indices,
|
||||||
|
SEXP data,
|
||||||
|
SEXP is_csr,
|
||||||
|
SEXP nrows,
|
||||||
|
SEXP ncols,
|
||||||
|
SEXP start_iteration,
|
||||||
|
SEXP num_iteration,
|
||||||
|
SEXP parameter) {
|
||||||
|
SEXP cont_token = PROTECT(R_MakeUnwindCont());
|
||||||
|
R_API_BEGIN();
|
||||||
|
_AssertBoosterHandleNotNull(handle);
|
||||||
|
const char* out_names[] = {"indptr", "indices", "data", ""};
|
||||||
|
SEXP out = PROTECT(Rf_mkNamed(VECSXP, out_names));
|
||||||
|
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
|
||||||
|
|
||||||
|
int64_t out_len[2];
|
||||||
|
void *out_indptr;
|
||||||
|
int32_t *out_indices;
|
||||||
|
void *out_data;
|
||||||
|
|
||||||
|
CHECK_CALL(LGBM_BoosterPredictSparseOutput(R_ExternalPtrAddr(handle),
|
||||||
|
INTEGER(indptr), C_API_DTYPE_INT32, INTEGER(indices),
|
||||||
|
REAL(data), C_API_DTYPE_FLOAT64,
|
||||||
|
Rf_xlength(indptr), Rf_xlength(data),
|
||||||
|
Rf_asLogical(is_csr)? Rf_asInteger(ncols) : Rf_asInteger(nrows),
|
||||||
|
C_API_PREDICT_CONTRIB, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration),
|
||||||
|
parameter_ptr,
|
||||||
|
Rf_asLogical(is_csr)? C_API_MATRIX_TYPE_CSR : C_API_MATRIX_TYPE_CSC,
|
||||||
|
out_len, &out_indptr, &out_indices, &out_data));
|
||||||
|
|
||||||
|
std::unique_ptr<SparseOutputPointers, decltype(&delete_SparseOutputPointers)> pointers_struct = {
|
||||||
|
new SparseOutputPointers(
|
||||||
|
out_indptr,
|
||||||
|
out_indices,
|
||||||
|
out_data),
|
||||||
|
&delete_SparseOutputPointers
|
||||||
|
};
|
||||||
|
|
||||||
|
SEXP out_indptr_R = safe_R_int(out_len[1], &cont_token);
|
||||||
|
SET_VECTOR_ELT(out, 0, out_indptr_R);
|
||||||
|
SEXP out_indices_R = safe_R_int(out_len[0], &cont_token);
|
||||||
|
SET_VECTOR_ELT(out, 1, out_indices_R);
|
||||||
|
SEXP out_data_R = safe_R_real(out_len[0], &cont_token);
|
||||||
|
SET_VECTOR_ELT(out, 2, out_data_R);
|
||||||
|
std::memcpy(INTEGER(out_indptr_R), out_indptr, out_len[1]*sizeof(int));
|
||||||
|
std::memcpy(INTEGER(out_indices_R), out_indices, out_len[0]*sizeof(int));
|
||||||
|
std::memcpy(REAL(out_data_R), out_data, out_len[0]*sizeof(double));
|
||||||
|
|
||||||
|
UNPROTECT(3);
|
||||||
|
return out;
|
||||||
|
R_API_END();
|
||||||
|
}
|
||||||
|
|
||||||
SEXP LGBM_BoosterSaveModel_R(SEXP handle,
|
SEXP LGBM_BoosterSaveModel_R(SEXP handle,
|
||||||
SEXP num_iteration,
|
SEXP num_iteration,
|
||||||
SEXP feature_importance_type,
|
SEXP feature_importance_type,
|
||||||
|
@ -975,6 +1061,7 @@ static const R_CallMethodDef CallEntries[] = {
|
||||||
{"LGBM_BoosterCalcNumPredict_R" , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R , 8},
|
{"LGBM_BoosterCalcNumPredict_R" , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R , 8},
|
||||||
{"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 14},
|
{"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 14},
|
||||||
{"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 11},
|
{"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 11},
|
||||||
|
{"LGBM_BoosterPredictSparseOutput_R", (DL_FUNC) &LGBM_BoosterPredictSparseOutput_R, 10},
|
||||||
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4},
|
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4},
|
||||||
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 3},
|
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 3},
|
||||||
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3},
|
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3},
|
||||||
|
|
|
@ -574,6 +574,35 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMat_R(
|
||||||
SEXP out_result
|
SEXP out_result
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief make feature contribution prediction for a new Dataset
|
||||||
|
* \param handle Booster handle
|
||||||
|
* \param indptr array with the index pointer of the data in CSR or CSC format
|
||||||
|
* \param indices array with the non-zero indices of the data in CSR or CSC format
|
||||||
|
* \param data array with the non-zero values of the data in CSR or CSC format
|
||||||
|
* \param is_csr whether the input data is in CSR format or not (pass FALSE for CSC)
|
||||||
|
* \param nrows number of rows in the data
|
||||||
|
* \param ncols number of columns in the data
|
||||||
|
* \param start_iteration Start index of the iteration to predict
|
||||||
|
* \param num_iteration number of iteration for prediction, <= 0 means no limit
|
||||||
|
* \param parameter additional parameters
|
||||||
|
* \return An R list with entries "indptr", "indices", "data", constituting the
|
||||||
|
* feature contributions in sparse format, in the same storage order as
|
||||||
|
* the input data.
|
||||||
|
*/
|
||||||
|
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictSparseOutput_R(
|
||||||
|
SEXP handle,
|
||||||
|
SEXP indptr,
|
||||||
|
SEXP indices,
|
||||||
|
SEXP data,
|
||||||
|
SEXP is_csr,
|
||||||
|
SEXP nrows,
|
||||||
|
SEXP ncols,
|
||||||
|
SEXP start_iteration,
|
||||||
|
SEXP num_iteration,
|
||||||
|
SEXP parameter
|
||||||
|
);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief save model into file
|
* \brief save model into file
|
||||||
* \param handle Booster handle
|
* \param handle Booster handle
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
library(Matrix)
|
||||||
|
|
||||||
VERBOSITY <- as.integer(
|
VERBOSITY <- as.integer(
|
||||||
Sys.getenv("LIGHTGBM_TEST_VERBOSITY", "-1")
|
Sys.getenv("LIGHTGBM_TEST_VERBOSITY", "-1")
|
||||||
)
|
)
|
||||||
|
@ -116,6 +118,84 @@ test_that("start_iteration works correctly", {
|
||||||
expect_equal(pred_leaf1, pred_leaf2)
|
expect_equal(pred_leaf1, pred_leaf2)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("Feature contributions from sparse inputs produce sparse outputs", {
|
||||||
|
data(mtcars)
|
||||||
|
X <- as.matrix(mtcars[, -1L])
|
||||||
|
y <- as.numeric(mtcars[, 1L])
|
||||||
|
dtrain <- lgb.Dataset(X, label = y, params = list(max_bins = 5L))
|
||||||
|
bst <- lgb.train(
|
||||||
|
data = dtrain
|
||||||
|
, obj = "regression"
|
||||||
|
, nrounds = 5L
|
||||||
|
, verbose = VERBOSITY
|
||||||
|
, params = list(min_data_in_leaf = 5L)
|
||||||
|
)
|
||||||
|
|
||||||
|
pred_dense <- predict(bst, X, predcontrib = TRUE)
|
||||||
|
|
||||||
|
Xcsc <- as(X, "CsparseMatrix")
|
||||||
|
pred_csc <- predict(bst, Xcsc, predcontrib = TRUE)
|
||||||
|
expect_s4_class(pred_csc, "dgCMatrix")
|
||||||
|
expect_equal(unname(pred_dense), unname(as.matrix(pred_csc)))
|
||||||
|
|
||||||
|
Xcsr <- as(X, "RsparseMatrix")
|
||||||
|
pred_csr <- predict(bst, Xcsr, predcontrib = TRUE)
|
||||||
|
expect_s4_class(pred_csr, "dgRMatrix")
|
||||||
|
expect_equal(as(pred_csr, "CsparseMatrix"), pred_csc)
|
||||||
|
|
||||||
|
Xspv <- as(X[1L, , drop = FALSE], "sparseVector")
|
||||||
|
pred_spv <- predict(bst, Xspv, predcontrib = TRUE)
|
||||||
|
expect_s4_class(pred_spv, "dsparseVector")
|
||||||
|
expect_equal(Matrix::t(as(pred_spv, "CsparseMatrix")), unname(pred_csc[1L, , drop = FALSE]))
|
||||||
|
})
|
||||||
|
|
||||||
|
test_that("Sparse feature contribution predictions do not take inputs with wrong number of columns", {
|
||||||
|
data(mtcars)
|
||||||
|
X <- as.matrix(mtcars[, -1L])
|
||||||
|
y <- as.numeric(mtcars[, 1L])
|
||||||
|
dtrain <- lgb.Dataset(X, label = y, params = list(max_bins = 5L))
|
||||||
|
bst <- lgb.train(
|
||||||
|
data = dtrain
|
||||||
|
, obj = "regression"
|
||||||
|
, nrounds = 5L
|
||||||
|
, verbose = VERBOSITY
|
||||||
|
, params = list(min_data_in_leaf = 5L)
|
||||||
|
)
|
||||||
|
|
||||||
|
X_wrong <- X[, c(1L:10L, 1L:10L)]
|
||||||
|
X_wrong <- as(X_wrong, "CsparseMatrix")
|
||||||
|
expect_error(predict(bst, X_wrong, predcontrib = TRUE), regexp = "input data has 20 columns")
|
||||||
|
|
||||||
|
X_wrong <- as(X_wrong, "RsparseMatrix")
|
||||||
|
expect_error(predict(bst, X_wrong, predcontrib = TRUE), regexp = "input data has 20 columns")
|
||||||
|
|
||||||
|
X_wrong <- as(X_wrong, "CsparseMatrix")
|
||||||
|
X_wrong <- X_wrong[, 1L:3L]
|
||||||
|
expect_error(predict(bst, X_wrong, predcontrib = TRUE), regexp = "input data has 3 columns")
|
||||||
|
})
|
||||||
|
|
||||||
|
test_that("Feature contribution predictions do not take non-general CSR or CSC inputs", {
|
||||||
|
set.seed(123L)
|
||||||
|
y <- runif(25L)
|
||||||
|
Dmat <- matrix(runif(625L), nrow = 25L, ncol = 25L)
|
||||||
|
Dmat <- crossprod(Dmat)
|
||||||
|
Dmat <- as(Dmat, "symmetricMatrix")
|
||||||
|
SmatC <- as(Dmat, "sparseMatrix")
|
||||||
|
SmatR <- as(SmatC, "RsparseMatrix")
|
||||||
|
|
||||||
|
dtrain <- lgb.Dataset(as.matrix(Dmat), label = y, params = list(max_bins = 5L))
|
||||||
|
bst <- lgb.train(
|
||||||
|
data = dtrain
|
||||||
|
, obj = "regression"
|
||||||
|
, nrounds = 5L
|
||||||
|
, verbose = VERBOSITY
|
||||||
|
, params = list(min_data_in_leaf = 5L)
|
||||||
|
)
|
||||||
|
|
||||||
|
expect_error(predict(bst, SmatC, predcontrib = TRUE))
|
||||||
|
expect_error(predict(bst, SmatR, predcontrib = TRUE))
|
||||||
|
})
|
||||||
|
|
||||||
test_that("predict() params should override keyword argument for raw-score predictions", {
|
test_that("predict() params should override keyword argument for raw-score predictions", {
|
||||||
data(agaricus.train, package = "lightgbm")
|
data(agaricus.train, package = "lightgbm")
|
||||||
X <- agaricus.train$data
|
X <- agaricus.train$data
|
||||||
|
@ -321,6 +401,8 @@ test_that("predict() params should override keyword argument for feature contrib
|
||||||
.expect_has_row_names(pred, Xcsc)
|
.expect_has_row_names(pred, Xcsc)
|
||||||
pred <- predict(bst, Xcsc, predcontrib = TRUE)
|
pred <- predict(bst, Xcsc, predcontrib = TRUE)
|
||||||
.expect_has_row_names(pred, Xcsc)
|
.expect_has_row_names(pred, Xcsc)
|
||||||
|
pred <- predict(bst, as(Xcsc, "RsparseMatrix"), predcontrib = TRUE)
|
||||||
|
.expect_has_row_names(pred, Xcsc)
|
||||||
|
|
||||||
# sparse matrix without row names
|
# sparse matrix without row names
|
||||||
Xcopy <- Xcsc
|
Xcopy <- Xcsc
|
||||||
|
|
Загрузка…
Ссылка в новой задаче