Mirror of https://github.com/microsoft/LightGBM.git
[python-package] Allow to pass Arrow table for prediction (#6168)
This commit is contained in:
Parent: 6fc80528f1
Commit: 2dfb9a4047
include/LightGBM/c_api.h
@@ -1417,6 +1417,40 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMats(BoosterHandle handle,
                                                  int64_t* out_len,
                                                  double* out_result);
 
+/*!
+ * \brief Make prediction for a new dataset.
+ * \note
+ * You should pre-allocate memory for ``out_result``:
+ *   - for normal and raw score, its length is equal to ``num_class * num_data``;
+ *   - for leaf index, its length is equal to ``num_class * num_data * num_iteration``;
+ *   - for feature contributions, its length is equal to ``num_class * num_data * (num_feature + 1)``.
+ * \param handle Handle of booster
+ * \param n_chunks The number of Arrow arrays passed to this function
+ * \param chunks Pointer to the list of Arrow arrays
+ * \param schema Pointer to the schema of all Arrow arrays
+ * \param predict_type What should be predicted
+ *   - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
+ *   - ``C_API_PREDICT_RAW_SCORE``: raw score;
+ *   - ``C_API_PREDICT_LEAF_INDEX``: leaf index;
+ *   - ``C_API_PREDICT_CONTRIB``: feature contributions (SHAP values)
+ * \param start_iteration Start index of the iteration to predict
+ * \param num_iteration Number of iteration for prediction, <= 0 means no limit
+ * \param parameter Other parameters for prediction, e.g. early stopping for prediction
+ * \param[out] out_len Length of output result
+ * \param[out] out_result Pointer to array with predictions
+ * \return 0 when succeed, -1 when failure happens
+ */
+LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForArrow(BoosterHandle handle,
+                                                  int64_t n_chunks,
+                                                  const ArrowArray* chunks,
+                                                  const ArrowSchema* schema,
+                                                  int predict_type,
+                                                  int start_iteration,
+                                                  int num_iteration,
+                                                  const char* parameter,
+                                                  int64_t* out_len,
+                                                  double* out_result);
+
 /*!
  * \brief Save model into file.
  * \param handle Handle of booster
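The \note above puts the allocation burden on the caller: ``out_result`` must be sized before the call. A quick sketch of the required lengths (plain Python arithmetic; the num_class/num_data/num_iteration/num_feature values below are hypothetical stand-ins for a real model's dimensions):

    # Required ``out_result`` length for each predict type, per the \note above.
    num_class, num_data, num_iteration, num_feature = 3, 1000, 100, 20

    len_normal_or_raw = num_class * num_data                      # C_API_PREDICT_NORMAL / C_API_PREDICT_RAW_SCORE
    len_leaf_index = num_class * num_data * num_iteration         # C_API_PREDICT_LEAF_INDEX
    len_contrib = num_class * num_data * (num_feature + 1)        # C_API_PREDICT_CONTRIB (SHAP values)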
python-package/lightgbm/basic.py
@@ -115,7 +115,8 @@ _LGBM_PredictDataType = Union[
     np.ndarray,
     pd_DataFrame,
     dt_DataTable,
-    scipy.sparse.spmatrix
+    scipy.sparse.spmatrix,
+    pa_Table,
 ]
 _LGBM_WeightType = Union[
     List[float],
@@ -1069,7 +1070,7 @@ class _InnerPredictor:
 
         Parameters
         ----------
-        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
+        data : str, pathlib.Path, numpy array, pandas DataFrame, pyarrow Table, H2O DataTable's Frame or scipy.sparse
             Data source for prediction.
             If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
         start_iteration : int, optional (default=0)
@@ -1161,6 +1162,13 @@ class _InnerPredictor:
                 num_iteration=num_iteration,
                 predict_type=predict_type
             )
+        elif _is_pyarrow_table(data):
+            preds, nrow = self.__pred_for_pyarrow_table(
+                table=data,
+                start_iteration=start_iteration,
+                num_iteration=num_iteration,
+                predict_type=predict_type
+            )
         elif isinstance(data, list):
             try:
                 data = np.array(data)
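The new branch dispatches on ``_is_pyarrow_table``, a guard that basic.py already carries from the earlier Arrow ingestion work; it is not part of this diff. As an assumption about that existing helper, it is roughly:

    def _is_pyarrow_table(data: Any) -> bool:
        # Sketch of the pre-existing guard, not part of this commit.
        return isinstance(data, pa_Table)

where ``pa_Table`` resolves to a stub class when pyarrow is not installed, so the isinstance check stays safe without the optional dependency.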
@@ -1614,6 +1622,48 @@ class _InnerPredictor:
         if n_preds != out_num_preds.value:
             raise ValueError("Wrong length for predict results")
         return preds, nrow
 
+    def __pred_for_pyarrow_table(
+        self,
+        table: pa_Table,
+        start_iteration: int,
+        num_iteration: int,
+        predict_type: int
+    ) -> Tuple[np.ndarray, int]:
+        """Predict for a PyArrow table."""
+        if not PYARROW_INSTALLED:
+            raise LightGBMError("Cannot predict from Arrow without `pyarrow` installed.")
+
+        # Check that the input is valid: we only handle numbers (for now)
+        if not all(arrow_is_integer(t) or arrow_is_floating(t) for t in table.schema.types):
+            raise ValueError("Arrow table may only have integer or floating point datatypes")
+
+        # Prepare prediction output array
+        n_preds = self.__get_num_preds(
+            start_iteration=start_iteration,
+            num_iteration=num_iteration,
+            nrow=table.num_rows,
+            predict_type=predict_type
+        )
+        preds = np.empty(n_preds, dtype=np.float64)
+        out_num_preds = ctypes.c_int64(0)
+
+        # Export Arrow table to C and run prediction
+        c_array = _export_arrow_to_c(table)
+        _safe_call(_LIB.LGBM_BoosterPredictForArrow(
+            self._handle,
+            ctypes.c_int64(c_array.n_chunks),
+            ctypes.c_void_p(c_array.chunks_ptr),
+            ctypes.c_void_p(c_array.schema_ptr),
+            ctypes.c_int(predict_type),
+            ctypes.c_int(start_iteration),
+            ctypes.c_int(num_iteration),
+            _c_str(self.pred_parameter),
+            ctypes.byref(out_num_preds),
+            preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
+        if n_preds != out_num_preds.value:
+            raise ValueError("Wrong length for predict results")
+        return preds, table.num_rows
+
     def current_iteration(self) -> int:
         """Get the index of the current iteration.
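Seen from the user's side, the new private method means a pyarrow Table can be handed directly to ``Booster.predict``, just like a DataFrame. A minimal end-to-end sketch (synthetic data; the column names are made up):

    import lightgbm as lgb
    import numpy as np
    import pyarrow as pa

    rng = np.random.default_rng(0)
    # Only integer/floating point columns are accepted, per the check above.
    table = pa.Table.from_arrays(
        [pa.array(rng.standard_normal(100)), pa.array(rng.standard_normal(100))],
        names=["x0", "x1"],
    )
    label = rng.integers(0, 2, size=100)

    booster = lgb.train({"objective": "binary"}, lgb.Dataset(table, label=label), num_boost_round=5)
    preds = booster.predict(table)  # routed through __pred_for_pyarrow_table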
@@ -4350,7 +4400,7 @@ class Booster:
 
         Parameters
         ----------
-        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
+        data : str, pathlib.Path, numpy array, pandas DataFrame, pyarrow Table, H2O DataTable's Frame or scipy.sparse
             Data source for prediction.
             If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
         start_iteration : int, optional (default=0)
src/c_api.cpp
@@ -2568,6 +2568,57 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle,
   API_END();
 }
 
+int LGBM_BoosterPredictForArrow(BoosterHandle handle,
+                                int64_t n_chunks,
+                                const ArrowArray* chunks,
+                                const ArrowSchema* schema,
+                                int predict_type,
+                                int start_iteration,
+                                int num_iteration,
+                                const char* parameter,
+                                int64_t* out_len,
+                                double* out_result) {
+  API_BEGIN();
+
+  // Apply the configuration
+  auto param = Config::Str2Map(parameter);
+  Config config;
+  config.Set(param);
+  OMP_SET_NUM_THREADS(config.num_threads);
+
+  // Set up chunked array and iterators for all columns
+  ArrowTable table(n_chunks, chunks, schema);
+  std::vector<ArrowChunkedArray::Iterator<double>> its;
+  its.reserve(table.get_num_columns());
+  for (int64_t j = 0; j < table.get_num_columns(); ++j) {
+    its.emplace_back(table.get_column(j).begin<double>());
+  }
+
+  // Build row function
+  auto num_columns = table.get_num_columns();
+  auto row_fn = [num_columns, &its] (int row_idx) {
+    std::vector<std::pair<int, double>> result;
+    result.reserve(num_columns);
+    for (int64_t j = 0; j < num_columns; ++j) {
+      result.emplace_back(static_cast<int>(j), its[j][row_idx]);
+    }
+    return result;
+  };
+
+  // Run prediction
+  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
+  ref_booster->Predict(start_iteration,
+                       num_iteration,
+                       predict_type,
+                       static_cast<int>(table.get_num_rows()),
+                       static_cast<int>(table.get_num_columns()),
+                       row_fn,
+                       config,
+                       out_result,
+                       out_len);
+  API_END();
+}
+
 int LGBM_BoosterSaveModel(BoosterHandle handle,
                           int start_iteration,
                           int num_iteration,
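The ``row_fn`` lambda is the heart of this function: prediction consumes rows, while Arrow stores columns in chunks, so each row is materialized lazily as (column index, value) pairs from one iterator per column. An illustrative Python rendering of that access pattern (not LightGBM API, just the shape of the logic):

    from typing import List, Tuple
    import pyarrow as pa

    def make_row_fn(table: pa.Table):
        columns = table.columns  # one pa.ChunkedArray per column

        def row_fn(row_idx: int) -> List[Tuple[int, float]]:
            # Gather one row as (column_index, value) pairs, as the C++ lambda does;
            # ChunkedArray indexing resolves the right chunk, like the C++ iterators.
            return [(j, col[row_idx].as_py()) for j, col in enumerate(columns)]

        return row_fn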
tests/python_package_test/test_arrow.py
@@ -1,6 +1,6 @@
 # coding: utf-8
 import filecmp
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import numpy as np
 import pyarrow as pa
@@ -63,19 +63,40 @@ def generate_dummy_arrow_table() -> pa.Table:
     return pa.Table.from_arrays([col1, col2], names=["a", "b"])
 
 
-def generate_random_arrow_table(num_columns: int, num_datapoints: int, seed: int) -> pa.Table:
-    columns = [generate_random_arrow_array(num_datapoints, seed + i) for i in range(num_columns)]
+def generate_random_arrow_table(
+    num_columns: int,
+    num_datapoints: int,
+    seed: int,
+    generate_nulls: bool = True,
+    values: Optional[np.ndarray] = None,
+) -> pa.Table:
+    columns = [
+        generate_random_arrow_array(
+            num_datapoints, seed + i, generate_nulls=generate_nulls, values=values
+        )
+        for i in range(num_columns)
+    ]
     names = [f"col_{i}" for i in range(num_columns)]
     return pa.Table.from_arrays(columns, names=names)
 
 
-def generate_random_arrow_array(num_datapoints: int, seed: int) -> pa.ChunkedArray:
+def generate_random_arrow_array(
+    num_datapoints: int,
+    seed: int,
+    generate_nulls: bool = True,
+    values: Optional[np.ndarray] = None,
+) -> pa.ChunkedArray:
     generator = np.random.default_rng(seed)
-    data = generator.standard_normal(num_datapoints)
+    data = (
+        generator.standard_normal(num_datapoints)
+        if values is None
+        else generator.choice(values, size=num_datapoints, replace=True)
+    )
 
     # Set random nulls
-    indices = generator.choice(len(data), size=num_datapoints // 10)
-    data[indices] = None
+    if generate_nulls:
+        indices = generator.choice(len(data), size=num_datapoints // 10)
+        data[indices] = None
 
     # Split data into <=2 random chunks
     split_points = np.sort(generator.choice(np.arange(1, num_datapoints), 2, replace=False))
@@ -131,8 +152,8 @@ def test_dataset_construct_fuzzy(tmp_path, arrow_table_fn, dataset_params):
 
 def test_dataset_construct_fields_fuzzy():
     arrow_table = generate_random_arrow_table(3, 1000, 42)
-    arrow_labels = generate_random_arrow_array(1000, 42)
-    arrow_weights = generate_random_arrow_array(1000, 42)
+    arrow_labels = generate_random_arrow_array(1000, 42, generate_nulls=False)
+    arrow_weights = generate_random_arrow_array(1000, 42, generate_nulls=False)
     arrow_groups = pa.chunked_array([[300, 400, 50], [250]], type=pa.int32())
 
     arrow_dataset = lgb.Dataset(
@@ -264,9 +285,9 @@ def test_dataset_construct_init_scores_table():
     data = generate_dummy_arrow_table()
     init_scores = pa.Table.from_arrays(
         [
-            generate_random_arrow_array(5, seed=1),
-            generate_random_arrow_array(5, seed=2),
-            generate_random_arrow_array(5, seed=3),
+            generate_random_arrow_array(5, seed=1, generate_nulls=False),
+            generate_random_arrow_array(5, seed=2, generate_nulls=False),
+            generate_random_arrow_array(5, seed=3, generate_nulls=False),
         ],
         names=["a", "b", "c"],
     )
@@ -276,3 +297,91 @@ def test_dataset_construct_init_scores_table():
     actual = dataset.get_init_score()
     expected = init_scores.to_pandas().to_numpy().astype(np.float64)
     np_assert_array_equal(expected, actual, strict=True)
+
+
+# ------------------------------------------ PREDICTION ----------------------------------------- #
+
+
+def assert_equal_predict_arrow_pandas(booster: lgb.Booster, data: pa.Table):
+    p_arrow = booster.predict(data)
+    p_pandas = booster.predict(data.to_pandas())
+    np_assert_array_equal(p_arrow, p_pandas, strict=True)
+
+    p_raw_arrow = booster.predict(data, raw_score=True)
+    p_raw_pandas = booster.predict(data.to_pandas(), raw_score=True)
+    np_assert_array_equal(p_raw_arrow, p_raw_pandas, strict=True)
+
+    p_leaf_arrow = booster.predict(data, pred_leaf=True)
+    p_leaf_pandas = booster.predict(data.to_pandas(), pred_leaf=True)
+    np_assert_array_equal(p_leaf_arrow, p_leaf_pandas, strict=True)
+
+    p_pred_contrib_arrow = booster.predict(data, pred_contrib=True)
+    p_pred_contrib_pandas = booster.predict(data.to_pandas(), pred_contrib=True)
+    np_assert_array_equal(p_pred_contrib_arrow, p_pred_contrib_pandas, strict=True)
+
+    p_first_iter_arrow = booster.predict(data, start_iteration=0, num_iteration=1, raw_score=True)
+    p_first_iter_pandas = booster.predict(
+        data.to_pandas(), start_iteration=0, num_iteration=1, raw_score=True
+    )
+    np_assert_array_equal(p_first_iter_arrow, p_first_iter_pandas, strict=True)
+
+
+def test_predict_regression():
+    data = generate_random_arrow_table(10, 10000, 42)
+    dataset = lgb.Dataset(
+        data,
+        label=generate_random_arrow_array(10000, 43, generate_nulls=False),
+        params=dummy_dataset_params(),
+    )
+    booster = lgb.train(
+        {"objective": "regression", "num_leaves": 7},
+        dataset,
+        num_boost_round=5,
+    )
+    assert_equal_predict_arrow_pandas(booster, data)
+
+
+def test_predict_binary_classification():
+    data = generate_random_arrow_table(10, 10000, 42)
+    dataset = lgb.Dataset(
+        data,
+        label=generate_random_arrow_array(10000, 43, generate_nulls=False, values=np.arange(2)),
+        params=dummy_dataset_params(),
+    )
+    booster = lgb.train(
+        {"objective": "binary", "num_leaves": 7},
+        dataset,
+        num_boost_round=5,
+    )
+    assert_equal_predict_arrow_pandas(booster, data)
+
+
+def test_predict_multiclass_classification():
+    data = generate_random_arrow_table(10, 10000, 42)
+    dataset = lgb.Dataset(
+        data,
+        label=generate_random_arrow_array(10000, 43, generate_nulls=False, values=np.arange(5)),
+        params=dummy_dataset_params(),
+    )
+    booster = lgb.train(
+        {"objective": "multiclass", "num_leaves": 7, "num_class": 5},
+        dataset,
+        num_boost_round=5,
+    )
+    assert_equal_predict_arrow_pandas(booster, data)
+
+
+def test_predict_ranking():
+    data = generate_random_arrow_table(10, 10000, 42)
+    dataset = lgb.Dataset(
+        data,
+        label=generate_random_arrow_array(10000, 43, generate_nulls=False, values=np.arange(4)),
+        group=np.array([1000, 2000, 3000, 4000]),
+        params=dummy_dataset_params(),
+    )
+    booster = lgb.train(
+        {"objective": "lambdarank", "num_leaves": 7},
+        dataset,
+        num_boost_round=5,
+    )
+    assert_equal_predict_arrow_pandas(booster, data)