2018-01-31 09:20:41 +03:00
|
|
|
/* lightgbmlib.i */
|
|
|
|
%module lightgbmlib
|
|
|
|
/* Do not wrap the raw C API function; the bindings expose
   LGBM_BoosterSaveModelToStringSWIG (defined in the %inline section) instead. */
%ignore LGBM_BoosterSaveModelToString;
|
2019-03-16 09:29:41 +03:00
|
|
|
/* Do not wrap the raw C API function; the bindings expose
   LGBM_BoosterGetEvalNamesSWIG (defined in the %inline section) instead. */
%ignore LGBM_BoosterGetEvalNames;
|
2018-01-31 09:20:41 +03:00
|
|
|
%{
|
|
|
|
/* Includes the header in the wrapper code */
|
|
|
|
#include "../include/LightGBM/export.h"
|
|
|
|
#include "../include/LightGBM/utils/log.h"
|
2019-03-16 09:29:41 +03:00
|
|
|
#include "../include/LightGBM/utils/common.h"
|
2018-01-31 09:20:41 +03:00
|
|
|
#include "../include/LightGBM/c_api.h"
|
|
|
|
%}
|
|
|
|
|
|
|
|
/* header files */
|
|
|
|
%include "../include/LightGBM/export.h"
|
|
|
|
%include "../include/LightGBM/c_api.h"
|
|
|
|
%include "cpointer.i"
|
|
|
|
%include "carrays.i"
|
|
|
|
|
2019-03-19 02:37:48 +03:00
|
|
|
/* Fill any `JNIEnv *jenv` parameter of a wrapped function from the `jenv`
   of the generated JNI function itself. numinputs=0 removes the parameter
   from the target-language signature. */
%typemap(in, numinputs=0) JNIEnv *jenv %{ $1 = jenv; %}
|
|
|
|
|
2018-01-31 09:20:41 +03:00
|
|
|
%inline %{
|
|
|
|
/*
 * Wrapper around LGBM_BoosterSaveModelToString that allocates the output
 * buffer itself so the caller only has to supply an initial size guess.
 *
 * handle           booster handle, passed through to the C API
 * start_iteration  passed through to the C API
 * num_iteration    passed through to the C API
 * buffer_len       initial size (in chars) of the buffer to try
 * out_len          receives the length reported by the C API
 *
 * Returns a buffer allocated with new[] holding the model string, or
 * nullptr when the C API reports an error (non-zero result). On success
 * the buffer is NOT freed here; the caller takes ownership.
 */
char * LGBM_BoosterSaveModelToStringSWIG(BoosterHandle handle,
                                         int start_iteration,
                                         int num_iteration,
                                         int64_t buffer_len,
                                         int64_t* out_len) {
  char* dst = new char[buffer_len];
  int result = LGBM_BoosterSaveModelToString(handle, start_iteration, num_iteration, buffer_len, out_len, dst);

  // The reported length exceeds what we offered: retry once with a buffer
  // of exactly the reported size.
  if (*out_len > buffer_len) {
    delete [] dst;
    int64_t realloc_len = *out_len;
    dst = new char[realloc_len];
    result = LGBM_BoosterSaveModelToString(handle, start_iteration, num_iteration, realloc_len, out_len, dst);
  }

  if (result != 0) {
    // Nobody else holds this pointer; free it so the error path does not leak.
    delete [] dst;
    return nullptr;
  }
  return dst;
}
|
|
|
|
|
|
|
|
/*
 * Wrapper around LGBM_BoosterGetEvalNames that allocates the array of
 * name buffers on behalf of the caller.
 *
 * handle       booster handle, passed through to the C API
 * eval_counts  number of name slots to allocate; each slot is a
 *              128-char buffer
 *
 * Returns an array of eval_counts buffers (array and buffers allocated
 * with new[]) or nullptr when eval_counts is negative or the C API reports
 * an error. On success nothing is freed here; the caller takes ownership.
 *
 * NOTE(review): the C API call is given no buffer sizes, so this presumably
 * relies on eval_counts being at least the booster's metric count and on
 * every name fitting in 128 chars - confirm against c_api.h.
 */
char ** LGBM_BoosterGetEvalNamesSWIG(BoosterHandle handle,
                                     int eval_counts) {
  if (eval_counts < 0) {
    // new[] with a negative size would throw across the JNI boundary.
    return nullptr;
  }

  // eval_counts is overwritten by the C API below, so remember how many
  // buffers were really allocated for the cleanup path.
  const int allocated = eval_counts;

  char** dst = new char*[allocated];
  for (int i = 0; i < allocated; ++i) {
    dst[i] = new char[128];
  }

  int result = LGBM_BoosterGetEvalNames(handle, &eval_counts, dst);
  if (result != 0) {
    // Release everything we allocated; the caller only sees nullptr.
    for (int i = 0; i < allocated; ++i) {
      delete [] dst[i];
    }
    delete [] dst;
    return nullptr;
  }
  return dst;
}
|
|
|
|
|
|
|
|
/*
 * Predicts for a single dense row held in a Java double[] without copying
 * it through the SWIG array helpers.
 *
 * jenv    JNI environment (supplied by the JNIEnv typemap)
 * data    Java double[] with the row's feature values
 * other   arguments are forwarded unchanged to
 *         LGBM_BoosterPredictForMatSingleRow
 *
 * Returns the result of LGBM_BoosterPredictForMatSingleRow, or -1 if the
 * JVM could not provide access to the array (a Java exception is then
 * pending).
 *
 * NOTE(review): the array is always treated as double*, regardless of
 * data_type - callers presumably must pass the float64 type code; confirm.
 */
int LGBM_BoosterPredictForMatSingle(JNIEnv *jenv,
                                    jdoubleArray data,
                                    BoosterHandle handle,
                                    int data_type,
                                    int ncol,
                                    int is_row_major,
                                    int predict_type,
                                    int num_iteration,
                                    const char* parameter,
                                    int64_t* out_len,
                                    double* out_result) {
  double* data0 = static_cast<double*>(jenv->GetPrimitiveArrayCritical(data, nullptr));
  if (data0 == nullptr) {
    // GetPrimitiveArrayCritical returns NULL on failure; do not hand a null
    // pointer to the native predictor.
    return -1;
  }

  int ret = LGBM_BoosterPredictForMatSingleRow(handle, data0, data_type, ncol, is_row_major, predict_type,
                                               num_iteration, parameter, out_len, out_result);

  // JNI_ABORT: the input was only read, so nothing needs copying back.
  jenv->ReleasePrimitiveArrayCritical(data, data0, JNI_ABORT);

  return ret;
}
|
|
|
|
|
|
|
|
/*
 * Predicts for a single sparse row given as Java int[] indices and
 * double[] values.
 *
 * jenv         JNI environment (supplied by the JNIEnv typemap)
 * indices      Java int[] with the column indices of the non-zero entries
 * values       Java double[] with the matching values
 * numNonZeros  number of entries to use; becomes the end of the one-row
 *              index pointer {0, numNonZeros}
 * other        arguments are forwarded unchanged to
 *              LGBM_BoosterPredictForCSRSingleRow
 *
 * Returns the result of LGBM_BoosterPredictForCSRSingleRow, or -1 if the
 * JVM could not provide access to one of the arrays (a Java exception is
 * then pending).
 *
 * NOTE(review): the index pointer is always int32_t and the values are
 * always double, regardless of indptr_type / data_type - callers presumably
 * must pass the matching type codes; confirm.
 */
int LGBM_BoosterPredictForCSRSingle(JNIEnv *jenv,
                                    jintArray indices,
                                    jdoubleArray values,
                                    int numNonZeros,
                                    BoosterHandle handle,
                                    int indptr_type,
                                    int data_type,
                                    int64_t nelem,
                                    int64_t num_col,
                                    int predict_type,
                                    int num_iteration,
                                    const char* parameter,
                                    int64_t* out_len,
                                    double* out_result) {
  // Alternatives
  // - GetIntArrayElements: performs copy
  // - GetDirectBufferAddress: fails on wrapped array
  // Some words of warning for GetPrimitiveArrayCritical
  // https://stackoverflow.com/questions/23258357/whats-the-trade-off-between-using-getprimitivearraycritical-and-getprimitivety

  int* indices0 = static_cast<int*>(jenv->GetPrimitiveArrayCritical(indices, nullptr));
  if (indices0 == nullptr) {
    return -1;
  }

  double* values0 = static_cast<double*>(jenv->GetPrimitiveArrayCritical(values, nullptr));
  if (values0 == nullptr) {
    // Leave the critical region opened for `indices` before bailing out.
    jenv->ReleasePrimitiveArrayCritical(indices, indices0, JNI_ABORT);
    return -1;
  }

  // Single row: the index pointer has exactly two entries (start, end).
  int32_t ind[2] = { 0, numNonZeros };

  int ret = LGBM_BoosterPredictForCSRSingleRow(handle, ind, indptr_type, indices0, values0, data_type, 2,
                                               nelem, num_col, predict_type, num_iteration, parameter, out_len, out_result);

  // Release in reverse order of acquisition; JNI_ABORT because the inputs
  // were only read.
  jenv->ReleasePrimitiveArrayCritical(values, values0, JNI_ABORT);
  jenv->ReleasePrimitiveArrayCritical(indices, indices0, JNI_ABORT);

  return ret;
}
|
|
|
|
|
|
|
|
#include <vector>
|
|
|
|
#include <functional>
|
|
|
|
|
|
|
|
struct CSRDirect {
|
|
|
|
jintArray indices;
|
|
|
|
jdoubleArray values;
|
|
|
|
int* indices0;
|
|
|
|
double* values0;
|
|
|
|
int size;
|
|
|
|
};
|
|
|
|
|
|
|
|
/*
 * Builds a dataset from a Java array of Spark ML SparseVector objects.
 *
 * All Java arrays are mapped to native memory up front; LightGBM then
 * pulls rows through a std::function callback that only touches that
 * native memory (no JNI calls inside the callback).
 *
 * jenv                 JNI environment (supplied by the JNIEnv typemap)
 * arrayOfSparseVector  Java array of org.apache.spark.ml.linalg.SparseVector
 * num_rows             number of elements of the array to read
 * num_col, parameters, reference, out
 *                      forwarded unchanged to LGBM_DatasetCreateFromCSRFunc
 *
 * Returns the result of LGBM_DatasetCreateFromCSRFunc, or -1 if num_rows is
 * negative or any JNI lookup/call/array access fails (a Java exception may
 * then be pending). Every array acquired here is released before returning,
 * on both the success and the error paths.
 */
int LGBM_DatasetCreateFromCSRSpark(JNIEnv *jenv,
                                   jobjectArray arrayOfSparseVector,
                                   int num_rows,
                                   int64_t num_col,
                                   const char* parameters,
                                   const DatasetHandle reference,
                                   DatasetHandle* out) {
  if (num_rows < 0) {
    // A negative count would turn into a huge size_t in reserve() below.
    return -1;
  }

  jclass sparseVectorClass = jenv->FindClass("org/apache/spark/ml/linalg/SparseVector");
  if (sparseVectorClass == nullptr) {
    return -1;
  }
  jmethodID sparseVectorIndices = jenv->GetMethodID(sparseVectorClass, "indices", "()[I");
  if (sparseVectorIndices == nullptr) {
    return -1;
  }
  jmethodID sparseVectorValues = jenv->GetMethodID(sparseVectorClass, "values", "()[D");
  if (sparseVectorValues == nullptr) {
    return -1;
  }

  std::vector<CSRDirect> jniCache;
  jniCache.reserve(num_rows);

  // Hands every array element buffer acquired so far back to the JVM.
  // JNI_ABORT: the data was only read, nothing needs copying back.
  auto releaseCached = [&]() {
    for (auto& jc : jniCache) {
      // jenv->ReleasePrimitiveArrayCritical(jc.values, jc.values0, JNI_ABORT);
      // jenv->ReleasePrimitiveArrayCritical(jc.indices, jc.indices0, JNI_ABORT);
      if (jc.values0 != nullptr) {
        jenv->ReleaseDoubleArrayElements(jc.values, jc.values0, JNI_ABORT);
      }
      if (jc.indices0 != nullptr) {
        jenv->ReleaseIntArrayElements(jc.indices, jc.indices0, JNI_ABORT);
      }
    }
  };

  // this needs to be done ahead of time as row_func is invoked from multiple threads
  // these threads would have to be registered with the JVM and also unregistered.
  // It is not clear if that can be achieved with OpenMP
  for (int i = 0; i < num_rows; ++i) {
    // get the row
    jobject objSparseVec = jenv->GetObjectArrayElement(arrayOfSparseVector, i);
    if (objSparseVec == nullptr) {
      // null element or index out of bounds
      releaseCached();
      return -1;
    }

    // get the size, indices and values
    auto indices = (jintArray)jenv->CallObjectMethod(objSparseVec, sparseVectorIndices);
    if (jenv->ExceptionCheck() || indices == nullptr) {
      releaseCached();
      return -1;
    }
    auto values = (jdoubleArray)jenv->CallObjectMethod(objSparseVec, sparseVectorValues);
    if (jenv->ExceptionCheck() || values == nullptr) {
      releaseCached();
      return -1;
    }

    // The vector object itself is no longer needed; drop its local reference
    // so it does not pile up once per row. The two array references must stay
    // alive because they are needed to release the element buffers later.
    jenv->DeleteLocalRef(objSparseVec);

    int size = jenv->GetArrayLength(indices);

    // Note: when testing on larger data (e.g. 288k rows per partition and 36mio rows total)
    // using GetPrimitiveArrayCritical resulted in a dead-lock
    // lock arrays
    // int* indices0 = (int*)jenv->GetPrimitiveArrayCritical(indices, 0);
    // double* values0 = (double*)jenv->GetPrimitiveArrayCritical(values, 0);
    // in test-usecase an alternative to GetPrimitiveArrayCritical as it performs copies
    int* indices0 = jenv->GetIntArrayElements(indices, 0);
    // Do not issue a second JNI call if the first one already failed.
    double* values0 = (indices0 != nullptr) ? jenv->GetDoubleArrayElements(values, 0) : nullptr;

    // Record first so that releaseCached() also covers a half-acquired row.
    jniCache.push_back({indices, values, indices0, values0, size});

    if (indices0 == nullptr || values0 == nullptr) {
      releaseCached();
      return -1;
    }
  }

  // type is important here as we want a std::function, rather than a lambda
  std::function<void(int idx, std::vector<std::pair<int, double>>& ret)> row_func = [&](int row_num, std::vector<std::pair<int, double>>& ret) {
    auto& jc = jniCache[row_num];
    ret.clear();  // reset size, but not free()
    ret.reserve(jc.size);  // make sure we have enough allocated

    // copy data
    int* indices0p = jc.indices0;
    double* values0p = jc.values0;
    int* indices0e = indices0p + jc.size;

    for (; indices0p != indices0e; ++indices0p, ++values0p)
      ret.emplace_back(*indices0p, *values0p);
  };

  int ret = LGBM_DatasetCreateFromCSRFunc(&row_func, num_rows, num_col, parameters, reference, out);

  releaseCached();

  return ret;
}
|
|
|
|
%}
|
|
|
|
|
|
|
|
/* Scalar pointer helpers (cpointer.i), one set per element type. */
%pointer_functions(int, intp)
%pointer_functions(long, longp)
%pointer_functions(double, doublep)
%pointer_functions(float, floatp)
%pointer_functions(int64_t, int64_tp)
%pointer_functions(int32_t, int32_tp)

/* Pointer-to-pointer casts (cpointer.i): %pointer_cast(FROM, TO, NAME). */
%pointer_cast(int64_t *, long *, int64_t_to_long_ptr)
%pointer_cast(int64_t *, double *, int64_t_to_double_ptr)
%pointer_cast(int32_t *, int *, int32_t_to_int_ptr)
%pointer_cast(long *, int64_t *, long_to_int64_t_ptr)
%pointer_cast(double *, int64_t *, double_to_int64_t_ptr)
%pointer_cast(double *, void *, double_to_voidp_ptr)
%pointer_cast(int *, int32_t *, int_to_int32_t_ptr)
%pointer_cast(float *, void *, float_to_voidp_ptr)

/* C array helpers (carrays.i), one set per element type. */
%array_functions(double, doubleArray)
%array_functions(float, floatArray)
%array_functions(int, intArray)
%array_functions(long, longArray)
%array_functions(char *, stringArray)
|
2018-01-31 09:20:41 +03:00
|
|
|
|
|
|
|
/* Custom pointer manipulation template */
|
|
|
|
/* %pointer_manipulation(TYPE, NAME)
 *
 * Generates and wraps two helpers:
 *   TYPE *new_NAME()              - allocates one TYPE with `new` (the value is
 *                                   left uninitialized for built-in types)
 *   void  delete_NAME(TYPE *self) - deletes it again; a null pointer is ignored
 */
%define %pointer_manipulation(TYPE,NAME)
%{
  static TYPE *new_##NAME() {
    TYPE* NAME = new TYPE;
    return NAME;
  }

  static void delete_##NAME(TYPE *self) {
    if (self) delete self;
  }
%}

TYPE *new_##NAME();
void delete_##NAME(TYPE *self);

%enddef
|
|
|
|
|
|
|
|
/* %pointer_dereference(TYPE, NAME)
 *
 * Generates and wraps TYPE NAME_value(TYPE *self), which returns a copy of
 * the value `self` points to. `self` is dereferenced without a null check.
 */
%define %pointer_dereference(TYPE,NAME)
%{
  static TYPE NAME##_value(TYPE *self) {
    return *self;
  }
%}

TYPE NAME##_value(TYPE *self);

%enddef
|
|
|
|
|
|
|
|
/* %pointer_handle(TYPE, NAME)
 *
 * Generates and wraps TYPE *NAME_handle(), which allocates one TYPE with
 * `new` and initializes it with a fresh raw allocation of sizeof(int*)
 * bytes cast to TYPE. Neither allocation is freed by the generated code.
 */
%define %pointer_handle(TYPE,NAME)
%{
  static TYPE* NAME##_handle() {
    TYPE* slot = new TYPE;
    *slot = (TYPE)operator new(sizeof(int*));
    return slot;
  }
%}

TYPE *NAME##_handle();

%enddef
|
|
|
|
|
|
|
|
/* Allow allocating and freeing a void* slot (new_voidpp / delete_voidpp) */
%pointer_manipulation(void*, voidpp)

/* Allow dereferencing of void** to void* */
%pointer_dereference(void*, voidpp)

/* Allow retrieving handle to void** */
%pointer_handle(void*, voidpp)
|
|
|
|
|