зеркало из https://github.com/microsoft/LightGBM.git
* [R-package] fix protection stack imbalance and unprotected objects issues * [R-package] fix minor linting issues * [ci][R-package] change timeout-minutes in valgrind test * [R-package] remove extra space Co-authored-by: James Lamb <jaylamb20@gmail.com> * [R-package] remove counter for number of protected objects * Update .github/workflows/r_valgrind.yml Co-authored-by: James Lamb <jaylamb20@gmail.com>
This commit is contained in:
Родитель
f62c490474
Коммит
aacb4c8fd9
|
@ -28,15 +28,13 @@
|
|||
#define R_API_BEGIN() \
|
||||
try {
|
||||
#define R_API_END() } \
|
||||
catch(std::exception& ex) { LGBM_SetLastError(ex.what()); return R_NilValue;} \
|
||||
catch(std::string& ex) { LGBM_SetLastError(ex.c_str()); return R_NilValue; } \
|
||||
catch(...) { LGBM_SetLastError("unknown exception"); return R_NilValue;} \
|
||||
return R_NilValue;
|
||||
catch(std::exception& ex) { LGBM_SetLastError(ex.what()); } \
|
||||
catch(std::string& ex) { LGBM_SetLastError(ex.c_str()); } \
|
||||
catch(...) { LGBM_SetLastError("unknown exception"); }
|
||||
|
||||
#define CHECK_CALL(x) \
|
||||
if ((x) != 0) { \
|
||||
Rf_error(LGBM_GetLastError()); \
|
||||
return R_NilValue; \
|
||||
}
|
||||
|
||||
using LightGBM::Common::Split;
|
||||
|
@ -54,19 +52,20 @@ SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
|
|||
SEXP parameters,
|
||||
SEXP reference) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
DatasetHandle handle = nullptr;
|
||||
DatasetHandle ref = nullptr;
|
||||
if (!Rf_isNull(reference)) {
|
||||
ref = R_ExternalPtrAddr(reference);
|
||||
}
|
||||
CHECK_CALL(LGBM_DatasetCreateFromFile(CHAR(Rf_asChar(filename)), CHAR(Rf_asChar(parameters)),
|
||||
ref, &handle));
|
||||
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
|
||||
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_DatasetCreateFromFile(filename_ptr, parameters_ptr, ref, &handle));
|
||||
R_API_END();
|
||||
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
|
||||
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
|
||||
UNPROTECT(1);
|
||||
UNPROTECT(3);
|
||||
return ret;
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
|
||||
|
@ -78,27 +77,27 @@ SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
|
|||
SEXP parameters,
|
||||
SEXP reference) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
const int* p_indptr = INTEGER(indptr);
|
||||
const int* p_indices = INTEGER(indices);
|
||||
const double* p_data = REAL(data);
|
||||
|
||||
int64_t nindptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
|
||||
int64_t ndata = static_cast<int64_t>(Rf_asInteger(nelem));
|
||||
int64_t nrow = static_cast<int64_t>(Rf_asInteger(num_row));
|
||||
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
|
||||
DatasetHandle handle = nullptr;
|
||||
DatasetHandle ref = nullptr;
|
||||
if (!Rf_isNull(reference)) {
|
||||
ref = R_ExternalPtrAddr(reference);
|
||||
}
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_DatasetCreateFromCSC(p_indptr, C_API_DTYPE_INT32, p_indices,
|
||||
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
|
||||
nrow, CHAR(Rf_asChar(parameters)), ref, &handle));
|
||||
nrow, parameters_ptr, ref, &handle));
|
||||
R_API_END();
|
||||
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
|
||||
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
|
||||
UNPROTECT(1);
|
||||
UNPROTECT(2);
|
||||
return ret;
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
|
||||
|
@ -107,22 +106,23 @@ SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
|
|||
SEXP parameters,
|
||||
SEXP reference) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
|
||||
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
|
||||
double* p_mat = REAL(data);
|
||||
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
|
||||
DatasetHandle handle = nullptr;
|
||||
DatasetHandle ref = nullptr;
|
||||
if (!Rf_isNull(reference)) {
|
||||
ref = R_ExternalPtrAddr(reference);
|
||||
}
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
|
||||
CHAR(Rf_asChar(parameters)), ref, &handle));
|
||||
parameters_ptr, ref, &handle));
|
||||
R_API_END();
|
||||
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
|
||||
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
|
||||
UNPROTECT(1);
|
||||
UNPROTECT(2);
|
||||
return ret;
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetGetSubset_R(SEXP handle,
|
||||
|
@ -130,7 +130,6 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
|
|||
SEXP len_used_row_indices,
|
||||
SEXP parameters) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
|
||||
std::vector<int32_t> idxvec(len);
|
||||
// convert from one-based to zero-based index
|
||||
|
@ -138,36 +137,41 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
|
|||
for (int32_t i = 0; i < len; ++i) {
|
||||
idxvec[i] = static_cast<int32_t>(INTEGER(used_row_indices)[i] - 1);
|
||||
}
|
||||
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
|
||||
DatasetHandle res = nullptr;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_DatasetGetSubset(R_ExternalPtrAddr(handle),
|
||||
idxvec.data(), len, CHAR(Rf_asChar(parameters)),
|
||||
idxvec.data(), len, parameters_ptr,
|
||||
&res));
|
||||
R_API_END();
|
||||
ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue));
|
||||
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
|
||||
UNPROTECT(1);
|
||||
UNPROTECT(2);
|
||||
return ret;
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
|
||||
SEXP feature_names) {
|
||||
R_API_BEGIN();
|
||||
auto vec_names = Split(CHAR(Rf_asChar(feature_names)), '\t');
|
||||
auto vec_names = Split(CHAR(PROTECT(Rf_asChar(feature_names))), '\t');
|
||||
std::vector<const char*> vec_sptr;
|
||||
int len = static_cast<int>(vec_names.size());
|
||||
for (int i = 0; i < len; ++i) {
|
||||
vec_sptr.push_back(vec_names[i].c_str());
|
||||
}
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_DatasetSetFeatureNames(R_ExternalPtrAddr(handle),
|
||||
vec_sptr.data(), len));
|
||||
R_API_END();
|
||||
UNPROTECT(1);
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
|
||||
SEXP feature_names;
|
||||
R_API_BEGIN();
|
||||
int len = 0;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len));
|
||||
R_API_END();
|
||||
const size_t reserved_string_size = 256;
|
||||
std::vector<std::vector<char>> names(len);
|
||||
std::vector<char*> ptr_names(len);
|
||||
|
@ -177,12 +181,14 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
|
|||
}
|
||||
int out_len;
|
||||
size_t required_string_size;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(
|
||||
LGBM_DatasetGetFeatureNames(
|
||||
R_ExternalPtrAddr(handle),
|
||||
len, &out_len,
|
||||
reserved_string_size, &required_string_size,
|
||||
ptr_names.data()));
|
||||
R_API_END();
|
||||
// if any feature names were larger than allocated size,
|
||||
// allow for a larger size and try again
|
||||
if (required_string_size > reserved_string_size) {
|
||||
|
@ -190,6 +196,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
|
|||
names[i].resize(required_string_size);
|
||||
ptr_names[i] = names[i].data();
|
||||
}
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(
|
||||
LGBM_DatasetGetFeatureNames(
|
||||
R_ExternalPtrAddr(handle),
|
||||
|
@ -198,6 +205,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
|
|||
required_string_size,
|
||||
&required_string_size,
|
||||
ptr_names.data()));
|
||||
R_API_END();
|
||||
}
|
||||
CHECK_EQ(len, out_len);
|
||||
feature_names = PROTECT(Rf_allocVector(STRSXP, len));
|
||||
|
@ -206,15 +214,17 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
|
|||
}
|
||||
UNPROTECT(1);
|
||||
return feature_names;
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
|
||||
SEXP filename) {
|
||||
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle),
|
||||
CHAR(Rf_asChar(filename))));
|
||||
filename_ptr));
|
||||
R_API_END();
|
||||
UNPROTECT(1);
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetFree_R(SEXP handle) {
|
||||
|
@ -224,15 +234,16 @@ SEXP LGBM_DatasetFree_R(SEXP handle) {
|
|||
R_ClearExternalPtr(handle);
|
||||
}
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetSetField_R(SEXP handle,
|
||||
SEXP field_name,
|
||||
SEXP field_data,
|
||||
SEXP num_element) {
|
||||
R_API_BEGIN();
|
||||
int len = Rf_asInteger(num_element);
|
||||
const char* name = CHAR(Rf_asChar(field_name));
|
||||
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
|
||||
R_API_BEGIN();
|
||||
if (!strcmp("group", name) || !strcmp("query", name)) {
|
||||
std::vector<int32_t> vec(len);
|
||||
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
|
||||
|
@ -251,18 +262,19 @@ SEXP LGBM_DatasetSetField_R(SEXP handle,
|
|||
CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
|
||||
}
|
||||
R_API_END();
|
||||
UNPROTECT(1);
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetGetField_R(SEXP handle,
|
||||
SEXP field_name,
|
||||
SEXP field_data) {
|
||||
R_API_BEGIN();
|
||||
const char* name = CHAR(Rf_asChar(field_name));
|
||||
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
|
||||
int out_len = 0;
|
||||
int out_type = 0;
|
||||
const void* res;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));
|
||||
|
||||
if (!strcmp("group", name) || !strcmp("query", name)) {
|
||||
auto p_data = reinterpret_cast<const int32_t*>(res);
|
||||
// convert from boundaries to size
|
||||
|
@ -284,29 +296,37 @@ SEXP LGBM_DatasetGetField_R(SEXP handle,
|
|||
}
|
||||
}
|
||||
R_API_END();
|
||||
UNPROTECT(1);
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
|
||||
SEXP field_name,
|
||||
SEXP out) {
|
||||
R_API_BEGIN();
|
||||
const char* name = CHAR(Rf_asChar(field_name));
|
||||
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
|
||||
int out_len = 0;
|
||||
int out_type = 0;
|
||||
const void* res;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));
|
||||
if (!strcmp("group", name) || !strcmp("query", name)) {
|
||||
out_len -= 1;
|
||||
}
|
||||
INTEGER(out)[0] = out_len;
|
||||
R_API_END();
|
||||
UNPROTECT(1);
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,
|
||||
SEXP new_params) {
|
||||
const char* old_params_ptr = CHAR(PROTECT(Rf_asChar(old_params)));
|
||||
const char* new_params_ptr = CHAR(PROTECT(Rf_asChar(new_params)));
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_DatasetUpdateParamChecking(CHAR(Rf_asChar(old_params)), CHAR(Rf_asChar(new_params))));
|
||||
CHECK_CALL(LGBM_DatasetUpdateParamChecking(old_params_ptr, new_params_ptr));
|
||||
R_API_END();
|
||||
UNPROTECT(2);
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
|
||||
|
@ -315,6 +335,7 @@ SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
|
|||
CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &nrow));
|
||||
INTEGER(out)[0] = nrow;
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
|
||||
|
@ -324,6 +345,7 @@ SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
|
|||
CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &nfeature));
|
||||
INTEGER(out)[0] = nfeature;
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
// --- start Booster interfaces
|
||||
|
@ -339,45 +361,49 @@ SEXP LGBM_BoosterFree_R(SEXP handle) {
|
|||
R_ClearExternalPtr(handle);
|
||||
}
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterCreate_R(SEXP train_data,
|
||||
SEXP parameters) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
|
||||
BoosterHandle handle = nullptr;
|
||||
CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), CHAR(Rf_asChar(parameters)), &handle));
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), parameters_ptr, &handle));
|
||||
R_API_END();
|
||||
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
|
||||
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
|
||||
UNPROTECT(1);
|
||||
UNPROTECT(2);
|
||||
return ret;
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
int out_num_iterations = 0;
|
||||
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
|
||||
BoosterHandle handle = nullptr;
|
||||
CHECK_CALL(LGBM_BoosterCreateFromModelfile(CHAR(Rf_asChar(filename)), &out_num_iterations, &handle));
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterCreateFromModelfile(filename_ptr, &out_num_iterations, &handle));
|
||||
R_API_END();
|
||||
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
|
||||
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
|
||||
UNPROTECT(1);
|
||||
UNPROTECT(2);
|
||||
return ret;
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
int out_num_iterations = 0;
|
||||
const char* model_str_ptr = CHAR(PROTECT(Rf_asChar(model_str)));
|
||||
BoosterHandle handle = nullptr;
|
||||
CHECK_CALL(LGBM_BoosterLoadModelFromString(CHAR(Rf_asChar(model_str)), &out_num_iterations, &handle));
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterLoadModelFromString(model_str_ptr, &out_num_iterations, &handle));
|
||||
R_API_END();
|
||||
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
|
||||
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
|
||||
UNPROTECT(1);
|
||||
UNPROTECT(2);
|
||||
return ret;
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterMerge_R(SEXP handle,
|
||||
|
@ -385,6 +411,7 @@ SEXP LGBM_BoosterMerge_R(SEXP handle,
|
|||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterMerge(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(other_handle)));
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterAddValidData_R(SEXP handle,
|
||||
|
@ -392,6 +419,7 @@ SEXP LGBM_BoosterAddValidData_R(SEXP handle,
|
|||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterAddValidData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(valid_data)));
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
|
||||
|
@ -399,13 +427,17 @@ SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
|
|||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterResetTrainingData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(train_data)));
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterResetParameter_R(SEXP handle,
|
||||
SEXP parameters) {
|
||||
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), CHAR(Rf_asChar(parameters))));
|
||||
CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), parameters_ptr));
|
||||
R_API_END();
|
||||
UNPROTECT(1);
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
|
||||
|
@ -415,6 +447,7 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
|
|||
CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &num_class));
|
||||
INTEGER(out)[0] = num_class;
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
|
||||
|
@ -422,6 +455,7 @@ SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
|
|||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterUpdateOneIter(R_ExternalPtrAddr(handle), &is_finished));
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
|
||||
|
@ -439,12 +473,14 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
|
|||
}
|
||||
CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_ExternalPtrAddr(handle), tgrad.data(), thess.data(), &is_finished));
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterRollbackOneIter(R_ExternalPtrAddr(handle)));
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
|
||||
|
@ -454,6 +490,7 @@ SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
|
|||
CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &out_iteration));
|
||||
INTEGER(out)[0] = out_iteration;
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
|
||||
|
@ -462,6 +499,7 @@ SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
|
|||
double* ptr_ret = REAL(out_result);
|
||||
CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
|
||||
|
@ -470,14 +508,15 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
|
|||
double* ptr_ret = REAL(out_result);
|
||||
CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
|
||||
SEXP eval_names;
|
||||
R_API_BEGIN();
|
||||
int len;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
|
||||
|
||||
R_API_END();
|
||||
const size_t reserved_string_size = 128;
|
||||
std::vector<std::vector<char>> names(len);
|
||||
std::vector<char*> ptr_names(len);
|
||||
|
@ -488,12 +527,14 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
|
|||
|
||||
int out_len;
|
||||
size_t required_string_size;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(
|
||||
LGBM_BoosterGetEvalNames(
|
||||
R_ExternalPtrAddr(handle),
|
||||
len, &out_len,
|
||||
reserved_string_size, &required_string_size,
|
||||
ptr_names.data()));
|
||||
R_API_END();
|
||||
// if any eval names were larger than allocated size,
|
||||
// allow for a larger size and try again
|
||||
if (required_string_size > reserved_string_size) {
|
||||
|
@ -501,6 +542,7 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
|
|||
names[i].resize(required_string_size);
|
||||
ptr_names[i] = names[i].data();
|
||||
}
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(
|
||||
LGBM_BoosterGetEvalNames(
|
||||
R_ExternalPtrAddr(handle),
|
||||
|
@ -509,6 +551,7 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
|
|||
required_string_size,
|
||||
&required_string_size,
|
||||
ptr_names.data()));
|
||||
R_API_END();
|
||||
}
|
||||
CHECK_EQ(out_len, len);
|
||||
eval_names = PROTECT(Rf_allocVector(STRSXP, len));
|
||||
|
@ -517,7 +560,6 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
|
|||
}
|
||||
UNPROTECT(1);
|
||||
return eval_names;
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterGetEval_R(SEXP handle,
|
||||
|
@ -531,6 +573,7 @@ SEXP LGBM_BoosterGetEval_R(SEXP handle,
|
|||
CHECK_CALL(LGBM_BoosterGetEval(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
|
||||
CHECK_EQ(out_len, len);
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
|
||||
|
@ -541,6 +584,7 @@ SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
|
|||
CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &len));
|
||||
INTEGER(out)[0] = static_cast<int>(len);
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterGetPredict_R(SEXP handle,
|
||||
|
@ -551,6 +595,7 @@ SEXP LGBM_BoosterGetPredict_R(SEXP handle,
|
|||
int64_t out_len;
|
||||
CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
|
||||
|
@ -577,12 +622,17 @@ SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
|
|||
SEXP num_iteration,
|
||||
SEXP parameter,
|
||||
SEXP result_filename) {
|
||||
R_API_BEGIN();
|
||||
const char* data_filename_ptr = CHAR(PROTECT(Rf_asChar(data_filename)));
|
||||
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
|
||||
const char* result_filename_ptr = CHAR(PROTECT(Rf_asChar(result_filename)));
|
||||
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
|
||||
CHECK_CALL(LGBM_BoosterPredictForFile(R_ExternalPtrAddr(handle), CHAR(Rf_asChar(data_filename)),
|
||||
Rf_asInteger(data_has_header), pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)),
|
||||
CHAR(Rf_asChar(result_filename))));
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterPredictForFile(R_ExternalPtrAddr(handle), data_filename_ptr,
|
||||
Rf_asInteger(data_has_header), pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr,
|
||||
result_filename_ptr));
|
||||
R_API_END();
|
||||
UNPROTECT(3);
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
|
||||
|
@ -600,6 +650,7 @@ SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
|
|||
pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len));
|
||||
INTEGER(out_len)[0] = static_cast<int>(len);
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
|
||||
|
@ -616,23 +667,24 @@ SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
|
|||
SEXP num_iteration,
|
||||
SEXP parameter,
|
||||
SEXP out_result) {
|
||||
R_API_BEGIN();
|
||||
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
|
||||
|
||||
const int* p_indptr = INTEGER(indptr);
|
||||
const int32_t* p_indices = reinterpret_cast<const int32_t*>(INTEGER(indices));
|
||||
const double* p_data = REAL(data);
|
||||
|
||||
int64_t nindptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
|
||||
int64_t ndata = static_cast<int64_t>(Rf_asInteger(nelem));
|
||||
int64_t nrow = static_cast<int64_t>(Rf_asInteger(num_row));
|
||||
double* ptr_ret = REAL(out_result);
|
||||
int64_t out_len;
|
||||
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterPredictForCSC(R_ExternalPtrAddr(handle),
|
||||
p_indptr, C_API_DTYPE_INT32, p_indices,
|
||||
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
|
||||
nrow, pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), &out_len, ptr_ret));
|
||||
nrow, pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr, &out_len, ptr_ret));
|
||||
R_API_END();
|
||||
UNPROTECT(1);
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
|
||||
|
@ -646,75 +698,82 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
|
|||
SEXP num_iteration,
|
||||
SEXP parameter,
|
||||
SEXP out_result) {
|
||||
R_API_BEGIN();
|
||||
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
|
||||
|
||||
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
|
||||
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
|
||||
|
||||
const double* p_mat = REAL(data);
|
||||
double* ptr_ret = REAL(out_result);
|
||||
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
|
||||
int64_t out_len;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterPredictForMat(R_ExternalPtrAddr(handle),
|
||||
p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
|
||||
pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), &out_len, ptr_ret));
|
||||
|
||||
pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr, &out_len, ptr_ret));
|
||||
R_API_END();
|
||||
UNPROTECT(1);
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterSaveModel_R(SEXP handle,
|
||||
SEXP num_iteration,
|
||||
SEXP feature_importance_type,
|
||||
SEXP filename) {
|
||||
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), CHAR(Rf_asChar(filename))));
|
||||
CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr));
|
||||
R_API_END();
|
||||
UNPROTECT(1);
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
|
||||
SEXP num_iteration,
|
||||
SEXP feature_importance_type) {
|
||||
SEXP model_str;
|
||||
R_API_BEGIN();
|
||||
int64_t out_len = 0;
|
||||
int64_t buf_len = 1024 * 1024;
|
||||
int num_iter = Rf_asInteger(num_iteration);
|
||||
int importance_type = Rf_asInteger(feature_importance_type);
|
||||
std::vector<char> inner_char_buf(buf_len);
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
|
||||
R_API_END();
|
||||
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again
|
||||
if (out_len > buf_len) {
|
||||
inner_char_buf.resize(out_len);
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
|
||||
R_API_END();
|
||||
}
|
||||
model_str = PROTECT(Rf_allocVector(STRSXP, 1));
|
||||
SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
|
||||
UNPROTECT(1);
|
||||
return model_str;
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP LGBM_BoosterDumpModel_R(SEXP handle,
|
||||
SEXP num_iteration,
|
||||
SEXP feature_importance_type) {
|
||||
SEXP model_str;
|
||||
R_API_BEGIN();
|
||||
int64_t out_len = 0;
|
||||
int64_t buf_len = 1024 * 1024;
|
||||
int num_iter = Rf_asInteger(num_iteration);
|
||||
int importance_type = Rf_asInteger(feature_importance_type);
|
||||
std::vector<char> inner_char_buf(buf_len);
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
|
||||
R_API_END();
|
||||
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again
|
||||
if (out_len > buf_len) {
|
||||
inner_char_buf.resize(out_len);
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
|
||||
R_API_END();
|
||||
}
|
||||
model_str = PROTECT(Rf_allocVector(STRSXP, 1));
|
||||
SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
|
||||
UNPROTECT(1);
|
||||
return model_str;
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
// .Call() calls
|
||||
|
|
Загрузка…
Ссылка в новой задаче