зеркало из https://github.com/microsoft/LightGBM.git
[python-package] Allow to pass Arrow array as groups (#6166)
This commit is contained in:
Родитель
bc6942226e
Коммит
516bde9501
|
@ -558,9 +558,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
|
|||
/*!
|
||||
* \brief Set vector to a content in info.
|
||||
* \note
|
||||
* - \a group converts input datatype into ``int32``;
|
||||
* - \a label and \a weight convert input datatype into ``float32``.
|
||||
* \param handle Handle of dataset
|
||||
* \param field_name Field name, can be \a label, \a weight
|
||||
* \param field_name Field name, can be \a label, \a weight, \a group
|
||||
* \param n_chunks The number of Arrow arrays passed to this function
|
||||
* \param chunks Pointer to the list of Arrow arrays
|
||||
* \param schema Pointer to the schema of all Arrow arrays
|
||||
|
|
|
@ -116,6 +116,7 @@ class Metadata {
|
|||
void SetWeights(const ArrowChunkedArray& array);
|
||||
|
||||
void SetQuery(const data_size_t* query, data_size_t len);
|
||||
void SetQuery(const ArrowChunkedArray& array);
|
||||
|
||||
void SetPosition(const data_size_t* position, data_size_t len);
|
||||
|
||||
|
@ -348,6 +349,9 @@ class Metadata {
|
|||
void InsertInitScores(const double* init_scores, data_size_t start_index, data_size_t len, data_size_t source_size);
|
||||
/*! \brief Insert queries at the given index */
|
||||
void InsertQueries(const data_size_t* queries, data_size_t start_index, data_size_t len);
|
||||
/*! \brief Set queries from pointers to the first element and the end of an iterator. */
|
||||
template <typename It>
|
||||
void SetQueriesFromIterator(It first, It last);
|
||||
/*! \brief Filename of current data */
|
||||
std::string data_filename_;
|
||||
/*! \brief Number of data */
|
||||
|
|
|
@ -70,7 +70,9 @@ _LGBM_GroupType = Union[
|
|||
List[float],
|
||||
List[int],
|
||||
np.ndarray,
|
||||
pd_Series
|
||||
pd_Series,
|
||||
pa_Array,
|
||||
pa_ChunkedArray,
|
||||
]
|
||||
_LGBM_PositionType = Union[
|
||||
np.ndarray,
|
||||
|
@ -1652,7 +1654,7 @@ class Dataset:
|
|||
If this is Dataset for validation, training data should be used as reference.
|
||||
weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
|
||||
Weight for each instance. Weights should be non-negative.
|
||||
group : list, numpy 1-D array, pandas Series or None, optional (default=None)
|
||||
group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
|
||||
Group/query data.
|
||||
Only used in the learning-to-rank task.
|
||||
sum(group) = n_samples.
|
||||
|
@ -2432,7 +2434,7 @@ class Dataset:
|
|||
Label of the data.
|
||||
weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
|
||||
Weight for each instance. Weights should be non-negative.
|
||||
group : list, numpy 1-D array, pandas Series or None, optional (default=None)
|
||||
group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
|
||||
Group/query data.
|
||||
Only used in the learning-to-rank task.
|
||||
sum(group) = n_samples.
|
||||
|
@ -2889,7 +2891,7 @@ class Dataset:
|
|||
|
||||
Parameters
|
||||
----------
|
||||
group : list, numpy 1-D array, pandas Series or None
|
||||
group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None
|
||||
Group/query data.
|
||||
Only used in the learning-to-rank task.
|
||||
sum(group) = n_samples.
|
||||
|
@ -2903,7 +2905,8 @@ class Dataset:
|
|||
"""
|
||||
self.group = group
|
||||
if self._handle is not None and group is not None:
|
||||
group = _list_to_1d_numpy(group, dtype=np.int32, name='group')
|
||||
if not _is_pyarrow_array(group):
|
||||
group = _list_to_1d_numpy(group, dtype=np.int32, name='group')
|
||||
self.set_field('group', group)
|
||||
# original values can be modified at cpp side
|
||||
constructed_group = self.get_field('group')
|
||||
|
@ -4431,7 +4434,7 @@ class Booster:
|
|||
|
||||
.. versionadded:: 4.0.0
|
||||
|
||||
group : list, numpy 1-D array, pandas Series or None, optional (default=None)
|
||||
group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
|
||||
Group/query size for ``data``.
|
||||
Only used in the learning-to-rank task.
|
||||
sum(group) = n_samples.
|
||||
|
|
|
@ -904,6 +904,8 @@ bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray
|
|||
metadata_.SetLabel(ca);
|
||||
} else if (name == std::string("weight") || name == std::string("weights")) {
|
||||
metadata_.SetWeights(ca);
|
||||
} else if (name == std::string("query") || name == std::string("group")) {
|
||||
metadata_.SetQuery(ca);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -507,30 +507,34 @@ void Metadata::InsertWeights(const label_t* weights, data_size_t start_index, da
|
|||
// CUDA is handled after all insertions are complete
|
||||
}
|
||||
|
||||
void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
|
||||
template <typename It>
|
||||
void Metadata::SetQueriesFromIterator(It first, It last) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
// save to nullptr
|
||||
if (query == nullptr || len == 0) {
|
||||
// Clear query boundaries on empty input
|
||||
if (last - first == 0) {
|
||||
query_boundaries_.clear();
|
||||
num_queries_ = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
data_size_t sum = 0;
|
||||
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:sum)
|
||||
for (data_size_t i = 0; i < len; ++i) {
|
||||
sum += query[i];
|
||||
for (data_size_t i = 0; i < last - first; ++i) {
|
||||
sum += first[i];
|
||||
}
|
||||
if (num_data_ != sum) {
|
||||
Log::Fatal("Sum of query counts is not same with #data");
|
||||
Log::Fatal("Sum of query counts (%i) differs from the length of #data (%i)", num_data_, sum);
|
||||
}
|
||||
num_queries_ = len;
|
||||
num_queries_ = last - first;
|
||||
|
||||
query_boundaries_.resize(num_queries_ + 1);
|
||||
query_boundaries_[0] = 0;
|
||||
for (data_size_t i = 0; i < num_queries_; ++i) {
|
||||
query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
|
||||
query_boundaries_[i + 1] = query_boundaries_[i] + first[i];
|
||||
}
|
||||
CalculateQueryWeights();
|
||||
query_load_from_file_ = false;
|
||||
|
||||
#ifdef USE_CUDA
|
||||
if (cuda_metadata_ != nullptr) {
|
||||
if (query_weights_.size() > 0) {
|
||||
|
@ -543,6 +547,14 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
|
|||
#endif // USE_CUDA
|
||||
}
|
||||
|
||||
void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
|
||||
SetQueriesFromIterator(query, query + len);
|
||||
}
|
||||
|
||||
void Metadata::SetQuery(const ArrowChunkedArray& array) {
|
||||
SetQueriesFromIterator(array.begin<data_size_t>(), array.end<data_size_t>());
|
||||
}
|
||||
|
||||
void Metadata::SetPosition(const data_size_t* positions, data_size_t len) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
// save to nullptr
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# coding: utf-8
|
||||
import filecmp
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
@ -15,6 +14,21 @@ from .utils import np_assert_array_equal
|
|||
# UTILITIES #
|
||||
# ----------------------------------------------------------------------------------------------- #
|
||||
|
||||
_INTEGER_TYPES = [
|
||||
pa.int8(),
|
||||
pa.int16(),
|
||||
pa.int32(),
|
||||
pa.int64(),
|
||||
pa.uint8(),
|
||||
pa.uint16(),
|
||||
pa.uint32(),
|
||||
pa.uint64(),
|
||||
]
|
||||
_FLOAT_TYPES = [
|
||||
pa.float32(),
|
||||
pa.float64(),
|
||||
]
|
||||
|
||||
|
||||
def generate_simple_arrow_table() -> pa.Table:
|
||||
columns = [
|
||||
|
@ -85,9 +99,7 @@ def dummy_dataset_params() -> Dict[str, Any]:
|
|||
(lambda: generate_random_arrow_table(100, 10000, 43), {}),
|
||||
],
|
||||
)
|
||||
def test_dataset_construct_fuzzy(
|
||||
tmp_path: Path, arrow_table_fn: Callable[[], pa.Table], dataset_params: Dict[str, Any]
|
||||
):
|
||||
def test_dataset_construct_fuzzy(tmp_path, arrow_table_fn, dataset_params):
|
||||
arrow_table = arrow_table_fn()
|
||||
|
||||
arrow_dataset = lgb.Dataset(arrow_table, params=dataset_params)
|
||||
|
@ -108,17 +120,23 @@ def test_dataset_construct_fields_fuzzy():
|
|||
arrow_table = generate_random_arrow_table(3, 1000, 42)
|
||||
arrow_labels = generate_random_arrow_array(1000, 42)
|
||||
arrow_weights = generate_random_arrow_array(1000, 42)
|
||||
arrow_groups = pa.chunked_array([[300, 400, 50], [250]], type=pa.int32())
|
||||
|
||||
arrow_dataset = lgb.Dataset(arrow_table, label=arrow_labels, weight=arrow_weights)
|
||||
arrow_dataset = lgb.Dataset(
|
||||
arrow_table, label=arrow_labels, weight=arrow_weights, group=arrow_groups
|
||||
)
|
||||
arrow_dataset.construct()
|
||||
|
||||
pandas_dataset = lgb.Dataset(
|
||||
arrow_table.to_pandas(), label=arrow_labels.to_numpy(), weight=arrow_weights.to_numpy()
|
||||
arrow_table.to_pandas(),
|
||||
label=arrow_labels.to_numpy(),
|
||||
weight=arrow_weights.to_numpy(),
|
||||
group=arrow_groups.to_numpy(),
|
||||
)
|
||||
pandas_dataset.construct()
|
||||
|
||||
# Check for equality
|
||||
for field in ("label", "weight"):
|
||||
for field in ("label", "weight", "group"):
|
||||
np_assert_array_equal(
|
||||
arrow_dataset.get_field(field), pandas_dataset.get_field(field), strict=True
|
||||
)
|
||||
|
@ -133,22 +151,8 @@ def test_dataset_construct_fields_fuzzy():
|
|||
["array_type", "label_data"],
|
||||
[(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"arrow_type",
|
||||
[
|
||||
pa.int8(),
|
||||
pa.int16(),
|
||||
pa.int32(),
|
||||
pa.int64(),
|
||||
pa.uint8(),
|
||||
pa.uint16(),
|
||||
pa.uint32(),
|
||||
pa.uint64(),
|
||||
pa.float32(),
|
||||
pa.float64(),
|
||||
],
|
||||
)
|
||||
def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type: Any):
|
||||
@pytest.mark.parametrize("arrow_type", _INTEGER_TYPES + _FLOAT_TYPES)
|
||||
def test_dataset_construct_labels(array_type, label_data, arrow_type):
|
||||
data = generate_dummy_arrow_table()
|
||||
labels = array_type(label_data, type=arrow_type)
|
||||
dataset = lgb.Dataset(data, label=labels, params=dummy_dataset_params())
|
||||
|
@ -175,7 +179,7 @@ def test_dataset_construct_weights_none():
|
|||
[(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]])],
|
||||
)
|
||||
@pytest.mark.parametrize("arrow_type", [pa.float32(), pa.float64()])
|
||||
def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type: Any):
|
||||
def test_dataset_construct_weights(array_type, weight_data, arrow_type):
|
||||
data = generate_dummy_arrow_table()
|
||||
weights = array_type(weight_data, type=arrow_type)
|
||||
dataset = lgb.Dataset(data, weight=weights, params=dummy_dataset_params())
|
||||
|
@ -183,3 +187,26 @@ def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type
|
|||
|
||||
expected = np.array([3, 0.7, 1.5, 0.5, 0.1], dtype=np.float32)
|
||||
np_assert_array_equal(expected, dataset.get_weight(), strict=True)
|
||||
|
||||
|
||||
# -------------------------------------------- GROUPS ------------------------------------------- #
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["array_type", "group_data"],
|
||||
[
|
||||
(pa.array, [2, 3]),
|
||||
(pa.chunked_array, [[2], [3]]),
|
||||
(pa.chunked_array, [[], [2, 3]]),
|
||||
(pa.chunked_array, [[2], [], [3], []]),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("arrow_type", _INTEGER_TYPES)
|
||||
def test_dataset_construct_groups(array_type, group_data, arrow_type):
|
||||
data = generate_dummy_arrow_table()
|
||||
groups = array_type(group_data, type=arrow_type)
|
||||
dataset = lgb.Dataset(data, group=groups, params=dummy_dataset_params())
|
||||
dataset.construct()
|
||||
|
||||
expected = np.array([0, 2, 5], dtype=np.int32)
|
||||
np_assert_array_equal(expected, dataset.get_field("group"), strict=True)
|
||||
|
|
Загрузка…
Ссылка в новой задаче