зеркало из https://github.com/microsoft/LightGBM.git
[R-package] Use ALTREP system to return C++-allocated arrays (#6213)
This commit is contained in:
Родитель
f6c8f5d8a1
Коммит
dee8a18889
|
@ -11,6 +11,7 @@
|
||||||
#include <LightGBM/utils/text_reader.h>
|
#include <LightGBM/utils/text_reader.h>
|
||||||
|
|
||||||
#include <R_ext/Rdynload.h>
|
#include <R_ext/Rdynload.h>
|
||||||
|
#include <R_ext/Altrep.h>
|
||||||
|
|
||||||
#define R_NO_REMAP
|
#define R_NO_REMAP
|
||||||
#define R_USE_C99_IN_CXX
|
#define R_USE_C99_IN_CXX
|
||||||
|
@ -24,6 +25,150 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
R_altrep_class_t lgb_altrepped_char_vec;
|
||||||
|
R_altrep_class_t lgb_altrepped_int_arr;
|
||||||
|
R_altrep_class_t lgb_altrepped_dbl_arr;
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
void delete_cpp_array(SEXP R_ptr) {
|
||||||
|
T *ptr_to_cpp_obj = static_cast<T*>(R_ExternalPtrAddr(R_ptr));
|
||||||
|
delete[] ptr_to_cpp_obj;
|
||||||
|
R_ClearExternalPtr(R_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
void delete_cpp_char_vec(SEXP R_ptr) {
|
||||||
|
std::vector<char> *ptr_to_cpp_obj = static_cast<std::vector<char>*>(R_ExternalPtrAddr(R_ptr));
|
||||||
|
delete ptr_to_cpp_obj;
|
||||||
|
R_ClearExternalPtr(R_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: MSVC has issues with Altrep classes, so they are disabled for it.
|
||||||
|
// See: https://github.com/microsoft/LightGBM/pull/6213#issuecomment-2111025768
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
# define LGB_NO_ALTREP
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef LGB_NO_ALTREP
|
||||||
|
SEXP make_altrepped_raw_vec(void *void_ptr) {
|
||||||
|
std::unique_ptr<std::vector<char>> *ptr_to_cpp_vec = static_cast<std::unique_ptr<std::vector<char>>*>(void_ptr);
|
||||||
|
SEXP R_ptr = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
|
||||||
|
SEXP R_raw = PROTECT(R_new_altrep(lgb_altrepped_char_vec, R_NilValue, R_NilValue));
|
||||||
|
|
||||||
|
R_SetExternalPtrAddr(R_ptr, ptr_to_cpp_vec->get());
|
||||||
|
R_RegisterCFinalizerEx(R_ptr, delete_cpp_char_vec, TRUE);
|
||||||
|
ptr_to_cpp_vec->release();
|
||||||
|
|
||||||
|
R_set_altrep_data1(R_raw, R_ptr);
|
||||||
|
UNPROTECT(2);
|
||||||
|
return R_raw;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
SEXP make_r_raw_vec(void *void_ptr) {
|
||||||
|
std::unique_ptr<std::vector<char>> *ptr_to_cpp_vec = static_cast<std::unique_ptr<std::vector<char>>*>(void_ptr);
|
||||||
|
R_xlen_t len = ptr_to_cpp_vec->get()->size();
|
||||||
|
SEXP out = PROTECT(Rf_allocVector(RAWSXP, len));
|
||||||
|
std::copy(ptr_to_cpp_vec->get()->begin(), ptr_to_cpp_vec->get()->end(), reinterpret_cast<char*>(RAW(out)));
|
||||||
|
UNPROTECT(1);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
#define make_altrepped_raw_vec make_r_raw_vec
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::vector<char>* get_ptr_from_altrepped_raw(SEXP R_raw) {
|
||||||
|
return static_cast<std::vector<char>*>(R_ExternalPtrAddr(R_altrep_data1(R_raw)));
|
||||||
|
}
|
||||||
|
|
||||||
|
R_xlen_t get_altrepped_raw_len(SEXP R_raw) {
|
||||||
|
return get_ptr_from_altrepped_raw(R_raw)->size();
|
||||||
|
}
|
||||||
|
|
||||||
|
const void* get_altrepped_raw_dataptr_or_null(SEXP R_raw) {
|
||||||
|
return get_ptr_from_altrepped_raw(R_raw)->data();
|
||||||
|
}
|
||||||
|
|
||||||
|
void* get_altrepped_raw_dataptr(SEXP R_raw, Rboolean writeable) {
|
||||||
|
return get_ptr_from_altrepped_raw(R_raw)->data();
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef LGB_NO_ALTREP
|
||||||
|
template <class T>
|
||||||
|
R_altrep_class_t get_altrep_class_for_type() {
|
||||||
|
if (std::is_same<T, double>::value) {
|
||||||
|
return lgb_altrepped_dbl_arr;
|
||||||
|
} else {
|
||||||
|
return lgb_altrepped_int_arr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
template <class T>
|
||||||
|
SEXPTYPE get_sexptype_class_for_type() {
|
||||||
|
if (std::is_same<T, double>::value) {
|
||||||
|
return REALSXP;
|
||||||
|
} else {
|
||||||
|
return INTSXP;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
T* get_r_vec_ptr(SEXP x) {
|
||||||
|
if (std::is_same<T, double>::value) {
|
||||||
|
return static_cast<T*>(static_cast<void*>(REAL(x)));
|
||||||
|
} else {
|
||||||
|
return static_cast<T*>(static_cast<void*>(INTEGER(x)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
struct arr_and_len {
|
||||||
|
T *arr;
|
||||||
|
int64_t len;
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifndef LGB_NO_ALTREP
|
||||||
|
template <class T>
|
||||||
|
SEXP make_altrepped_vec_from_arr(void *void_ptr) {
|
||||||
|
T *arr = static_cast<arr_and_len<T>*>(void_ptr)->arr;
|
||||||
|
uint64_t len = static_cast<arr_and_len<T>*>(void_ptr)->len;
|
||||||
|
SEXP R_ptr = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
|
||||||
|
SEXP R_len = PROTECT(Rf_allocVector(REALSXP, 1));
|
||||||
|
SEXP R_vec = PROTECT(R_new_altrep(get_altrep_class_for_type<T>(), R_NilValue, R_NilValue));
|
||||||
|
|
||||||
|
REAL(R_len)[0] = static_cast<double>(len);
|
||||||
|
R_SetExternalPtrAddr(R_ptr, arr);
|
||||||
|
R_RegisterCFinalizerEx(R_ptr, delete_cpp_array<T>, TRUE);
|
||||||
|
|
||||||
|
R_set_altrep_data1(R_vec, R_ptr);
|
||||||
|
R_set_altrep_data2(R_vec, R_len);
|
||||||
|
UNPROTECT(3);
|
||||||
|
return R_vec;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
template <class T>
|
||||||
|
SEXP make_R_vec_from_arr(void *void_ptr) {
|
||||||
|
T *arr = static_cast<arr_and_len<T>*>(void_ptr)->arr;
|
||||||
|
uint64_t len = static_cast<arr_and_len<T>*>(void_ptr)->len;
|
||||||
|
SEXP out = PROTECT(Rf_allocVector(get_sexptype_class_for_type<T>(), len));
|
||||||
|
std::copy(arr, arr + len, get_r_vec_ptr<T>(out));
|
||||||
|
UNPROTECT(1);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
#define make_altrepped_vec_from_arr make_R_vec_from_arr
|
||||||
|
#endif
|
||||||
|
|
||||||
|
R_xlen_t get_altrepped_vec_len(SEXP R_vec) {
|
||||||
|
return static_cast<R_xlen_t>(Rf_asReal(R_altrep_data2(R_vec)));
|
||||||
|
}
|
||||||
|
|
||||||
|
const void* get_altrepped_vec_dataptr_or_null(SEXP R_vec) {
|
||||||
|
return R_ExternalPtrAddr(R_altrep_data1(R_vec));
|
||||||
|
}
|
||||||
|
|
||||||
|
void* get_altrepped_vec_dataptr(SEXP R_vec, Rboolean writeable) {
|
||||||
|
return R_ExternalPtrAddr(R_altrep_data1(R_vec));
|
||||||
|
}
|
||||||
|
|
||||||
#define COL_MAJOR (0)
|
#define COL_MAJOR (0)
|
||||||
|
|
||||||
|
@ -964,8 +1109,6 @@ struct SparseOutputPointers {
|
||||||
void* indptr;
|
void* indptr;
|
||||||
int32_t* indices;
|
int32_t* indices;
|
||||||
void* data;
|
void* data;
|
||||||
int indptr_type;
|
|
||||||
int data_type;
|
|
||||||
SparseOutputPointers(void* indptr, int32_t* indices, void* data)
|
SparseOutputPointers(void* indptr, int32_t* indices, void* data)
|
||||||
: indptr(indptr), indices(indices), data(data) {}
|
: indptr(indptr), indices(indices), data(data) {}
|
||||||
};
|
};
|
||||||
|
@ -1015,15 +1158,26 @@ SEXP LGBM_BoosterPredictSparseOutput_R(SEXP handle,
|
||||||
&delete_SparseOutputPointers
|
&delete_SparseOutputPointers
|
||||||
};
|
};
|
||||||
|
|
||||||
SEXP out_indptr_R = safe_R_int(out_len[1], &cont_token);
|
arr_and_len<int> indptr_str{static_cast<int*>(out_indptr), out_len[1]};
|
||||||
SET_VECTOR_ELT(out, 0, out_indptr_R);
|
SET_VECTOR_ELT(
|
||||||
SEXP out_indices_R = safe_R_int(out_len[0], &cont_token);
|
out, 0,
|
||||||
SET_VECTOR_ELT(out, 1, out_indices_R);
|
R_UnwindProtect(make_altrepped_vec_from_arr<int>,
|
||||||
SEXP out_data_R = safe_R_real(out_len[0], &cont_token);
|
static_cast<void*>(&indptr_str), throw_R_memerr, &cont_token, cont_token));
|
||||||
SET_VECTOR_ELT(out, 2, out_data_R);
|
pointers_struct->indptr = nullptr;
|
||||||
std::memcpy(INTEGER(out_indptr_R), out_indptr, out_len[1]*sizeof(int));
|
|
||||||
std::memcpy(INTEGER(out_indices_R), out_indices, out_len[0]*sizeof(int));
|
arr_and_len<int> indices_str{static_cast<int*>(out_indices), out_len[0]};
|
||||||
std::memcpy(REAL(out_data_R), out_data, out_len[0]*sizeof(double));
|
SET_VECTOR_ELT(
|
||||||
|
out, 1,
|
||||||
|
R_UnwindProtect(make_altrepped_vec_from_arr<int>,
|
||||||
|
static_cast<void*>(&indices_str), throw_R_memerr, &cont_token, cont_token));
|
||||||
|
pointers_struct->indices = nullptr;
|
||||||
|
|
||||||
|
arr_and_len<double> data_str{static_cast<double*>(out_data), out_len[0]};
|
||||||
|
SET_VECTOR_ELT(
|
||||||
|
out, 2,
|
||||||
|
R_UnwindProtect(make_altrepped_vec_from_arr<double>,
|
||||||
|
static_cast<void*>(&data_str), throw_R_memerr, &cont_token, cont_token));
|
||||||
|
pointers_struct->data = nullptr;
|
||||||
|
|
||||||
UNPROTECT(3);
|
UNPROTECT(3);
|
||||||
return out;
|
return out;
|
||||||
|
@ -1104,6 +1258,34 @@ SEXP LGBM_BoosterSaveModel_R(SEXP handle,
|
||||||
R_API_END();
|
R_API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Note: for some reason, MSVC crashes when an error is thrown here
|
||||||
|
// if the buffer variable is defined as 'std::unique_ptr<std::vector<char>>',
|
||||||
|
// but not if it is defined as '<std::vector<char>'.
|
||||||
|
#ifndef _MSC_VER
|
||||||
|
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
|
||||||
|
SEXP num_iteration,
|
||||||
|
SEXP feature_importance_type,
|
||||||
|
SEXP start_iteration) {
|
||||||
|
SEXP cont_token = PROTECT(R_MakeUnwindCont());
|
||||||
|
R_API_BEGIN();
|
||||||
|
_AssertBoosterHandleNotNull(handle);
|
||||||
|
int64_t out_len = 0;
|
||||||
|
int64_t buf_len = 1024 * 1024;
|
||||||
|
int num_iter = Rf_asInteger(num_iteration);
|
||||||
|
int start_iter = Rf_asInteger(start_iteration);
|
||||||
|
int importance_type = Rf_asInteger(feature_importance_type);
|
||||||
|
std::unique_ptr<std::vector<char>> inner_char_buf(new std::vector<char>(buf_len));
|
||||||
|
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, buf_len, &out_len, inner_char_buf->data()));
|
||||||
|
inner_char_buf->resize(out_len);
|
||||||
|
if (out_len > buf_len) {
|
||||||
|
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, out_len, &out_len, inner_char_buf->data()));
|
||||||
|
}
|
||||||
|
SEXP out = R_UnwindProtect(make_altrepped_raw_vec, &inner_char_buf, throw_R_memerr, &cont_token, cont_token);
|
||||||
|
UNPROTECT(1);
|
||||||
|
return out;
|
||||||
|
R_API_END();
|
||||||
|
}
|
||||||
|
#else
|
||||||
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
|
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
|
||||||
SEXP num_iteration,
|
SEXP num_iteration,
|
||||||
SEXP feature_importance_type,
|
SEXP feature_importance_type,
|
||||||
|
@ -1129,6 +1311,7 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
|
||||||
return model_str;
|
return model_str;
|
||||||
R_API_END();
|
R_API_END();
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
SEXP LGBM_BoosterDumpModel_R(SEXP handle,
|
SEXP LGBM_BoosterDumpModel_R(SEXP handle,
|
||||||
SEXP num_iteration,
|
SEXP num_iteration,
|
||||||
|
@ -1281,4 +1464,21 @@ LIGHTGBM_C_EXPORT void R_init_lightgbm(DllInfo *dll);
|
||||||
void R_init_lightgbm(DllInfo *dll) {
|
void R_init_lightgbm(DllInfo *dll) {
|
||||||
R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
|
R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
|
||||||
R_useDynamicSymbols(dll, FALSE);
|
R_useDynamicSymbols(dll, FALSE);
|
||||||
|
|
||||||
|
#ifndef LGB_NO_ALTREP
|
||||||
|
lgb_altrepped_char_vec = R_make_altraw_class("lgb_altrepped_char_vec", "lightgbm", dll);
|
||||||
|
R_set_altrep_Length_method(lgb_altrepped_char_vec, get_altrepped_raw_len);
|
||||||
|
R_set_altvec_Dataptr_method(lgb_altrepped_char_vec, get_altrepped_raw_dataptr);
|
||||||
|
R_set_altvec_Dataptr_or_null_method(lgb_altrepped_char_vec, get_altrepped_raw_dataptr_or_null);
|
||||||
|
|
||||||
|
lgb_altrepped_int_arr = R_make_altinteger_class("lgb_altrepped_int_arr", "lightgbm", dll);
|
||||||
|
R_set_altrep_Length_method(lgb_altrepped_int_arr, get_altrepped_vec_len);
|
||||||
|
R_set_altvec_Dataptr_method(lgb_altrepped_int_arr, get_altrepped_vec_dataptr);
|
||||||
|
R_set_altvec_Dataptr_or_null_method(lgb_altrepped_int_arr, get_altrepped_vec_dataptr_or_null);
|
||||||
|
|
||||||
|
lgb_altrepped_dbl_arr = R_make_altreal_class("lgb_altrepped_dbl_arr", "lightgbm", dll);
|
||||||
|
R_set_altrep_Length_method(lgb_altrepped_dbl_arr, get_altrepped_vec_len);
|
||||||
|
R_set_altvec_Dataptr_method(lgb_altrepped_dbl_arr, get_altrepped_vec_dataptr);
|
||||||
|
R_set_altvec_Dataptr_or_null_method(lgb_altrepped_dbl_arr, get_altrepped_vec_dataptr_or_null);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
Загрузка…
Ссылка в новой задаче