Mirror of https://github.com/microsoft/LightGBM.git
[python] support Dataset.get_data for Sequence input. (#4472)
* [python] support Dataset.get_data for Sequence input.
* Tweaks according to review comments.
* Apply suggestions from code review

  Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Add test cases.
* fix import order in test_basic.py

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
Parent: 2370961ae0
Commit: 1d21d1ad4c
@@ -641,9 +641,12 @@ class Sequence(abc.ABC):
 
     .. code-block:: python
 
             if isinstance(idx, numbers.Integral):
-                return self.__get_one_line__(idx)
+                return self._get_one_line(idx)
             elif isinstance(idx, slice):
-                return np.stack(self.__get_one_line__(i) for i in range(idx.start, idx.stop))
+                return np.stack([self._get_one_line(i) for i in range(idx.start, idx.stop)])
+            elif isinstance(idx, list):
+                # Only required if using ``Dataset.get_data()``.
+                return np.array([self._get_one_line(i) for i in idx])
             else:
                 raise TypeError(f"Sequence index must be integer or slice, got {type(idx).__name__}")
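The docstring example above now also accepts a list of indices: Dataset.get_data() relies on that branch to pull raw rows back out of a Sequence. Below is a minimal sketch of a conforming subclass backed by an in-memory numpy array; the class name, constructor, and _get_one_line helper are illustrative, not part of this commit.

import numbers

import numpy as np
import lightgbm as lgb


class NumpyArraySequence(lgb.Sequence):
    """Illustrative Sequence subclass wrapping a 2-D numpy array (not part of this commit)."""

    def __init__(self, ndarray, batch_size=64):
        self.ndarray = ndarray
        self.batch_size = batch_size  # batch size LightGBM uses when reading rows during construction

    def _get_one_line(self, idx):
        return self.ndarray[idx]

    def __getitem__(self, idx):
        if isinstance(idx, numbers.Integral):
            return self._get_one_line(idx)
        elif isinstance(idx, slice):
            return np.stack([self._get_one_line(i) for i in range(idx.start, idx.stop)])
        elif isinstance(idx, list):
            # Only required for ``Dataset.get_data()``: a sliced Dataset passes
            # ``used_indices`` as a list to pick raw rows out of the Sequence.
            return np.array([self._get_one_line(i) for i in idx])
        else:
            raise TypeError(f"Sequence index must be integer or slice, got {type(idx).__name__}")

    def __len__(self):
        return len(self.ndarray)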
@@ -1515,7 +1518,8 @@ class Dataset:
         # set feature names
         return self.set_feature_name(feature_name)
 
-    def __yield_row_from(self, seqs: List[Sequence], indices: Iterable[int]):
+    @staticmethod
+    def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]):
         offset = 0
         seq_id = 0
         seq = seqs[seq_id]
@@ -1541,7 +1545,7 @@ class Dataset:
         indices = self._create_sample_indices(total_nrow)
 
         # Select sampled rows, transpose to column order.
-        sampled = np.array([row for row in self.__yield_row_from(seqs, indices)])
+        sampled = np.array([row for row in self._yield_row_from_seqlist(seqs, indices)])
         sampled = sampled.T
 
         filtered = []
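__yield_row_from becomes the static method _yield_row_from_seqlist so it can be reused by get_data() when slicing a list of Sequences (see the hunk further below). The hunk only shows the first three lines of the generator; the following is a rough sketch of the logic those lines set up, assuming the indices are visited in ascending order as produced by _create_sample_indices.

def yield_row_from_seqlist(seqs, indices):
    # Sketch: yield one raw row per global row index, walking an ordered
    # list of Sequence objects whose lengths sum to the total row count.
    offset = 0   # rows covered by the sequences already passed
    seq_id = 0
    seq = seqs[seq_id]
    for row_id in indices:
        # advance to the sequence that contains this global row index
        while row_id >= offset + len(seq):
            offset += len(seq)
            seq_id += 1
            seq = seqs[seq_id]
        yield seq[row_id - offset]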
@@ -2236,7 +2240,7 @@ class Dataset:
 
         Returns
         -------
-        data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, list of numpy arrays or None
+        data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays or None
             Raw data used in the Dataset construction.
         """
         if self.handle is None:
@@ -2250,6 +2254,10 @@ class Dataset:
                         self.data = self.data.iloc[self.used_indices].copy()
                     elif isinstance(self.data, dt_DataTable):
                         self.data = self.data[self.used_indices, :]
+                    elif isinstance(self.data, Sequence):
+                        self.data = self.data[self.used_indices]
+                    elif isinstance(self.data, list) and len(self.data) > 0 and all(isinstance(x, Sequence) for x in self.data):
+                        self.data = np.array([row for row in self._yield_row_from_seqlist(self.data, self.used_indices)])
                     else:
                         _log_warning(f"Cannot subset {type(self.data).__name__} type of raw data.\n"
                                      "Returning original raw data")
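Together these two hunks document and implement Sequence support in get_data(): an unsliced Dataset hands back the raw Sequence (or list of Sequences) unchanged, while a sliced Dataset materializes the selected rows through the list-indexing branch added to __getitem__ above. A usage sketch with the illustrative NumpyArraySequence from earlier (free_raw_data=False keeps the raw data available). The remaining hunks below are the accompanying tests in test_basic.py.

import numpy as np
import lightgbm as lgb

X = np.arange(200, dtype=np.float64).reshape(20, 10)
y = np.arange(20, dtype=np.float64)

seq = NumpyArraySequence(X, batch_size=6)   # illustrative subclass sketched earlier
ds = lgb.Dataset(seq, label=y, free_raw_data=False)
ds.construct()

# Unsliced Dataset: get_data() returns the raw Sequence object as-is.
assert ds.get_data() is seq

# The sliced case selects rows by list indexing, as the new test exercises directly.
rows = seq[[0, 5, 11, 15]]
assert (rows == X[[0, 5, 11, 15]]).all()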
@@ -1,6 +1,7 @@
 # coding: utf-8
 import filecmp
 import numbers
+import types
 from pathlib import Path
 
 import numpy as np
@@ -194,6 +195,31 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq):
     assert filecmp.cmp(valid_npy_bin_fname, valid_seq2_bin_fname)
 
 
+def test_sequence_get_data():
+    nrow = 20
+    ncol = 11
+    data = np.arange(nrow * ncol, dtype=np.float64).reshape((nrow, ncol))
+
+    X = data[:, :-1]
+    Y = data[:, -1]
+
+    seqs = _create_sequence_from_ndarray(X, 2, 6)
+    seq_ds = lgb.Dataset(seqs, label=Y, params=None, free_raw_data=False)
+    seq_ds.construct()
+
+    assert seqs == seq_ds.get_data()
+
+    # This is a hack to add test coverage in get_data.
+    used_indices = [0, 5, 11, 15]
+    ref_data = types.SimpleNamespace()
+    ref_data.data = seqs
+
+    seq_ds.need_slice = True
+    seq_ds.reference = ref_data
+    seq_ds.used_indices = used_indices
+    assert (X[used_indices] == seq_ds.get_data()).all()
+
+
 def test_chunked_dataset():
     X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1,
                                                         random_state=2)
@@ -313,7 +339,7 @@ def test_add_features_from_different_sources():
     n_row = 100
     n_col = 5
     X = np.random.random((n_row, n_col))
-    xxs = [X, sparse.csr_matrix(X), pd.DataFrame(X)]
+    xxs = [X, sparse.csr_matrix(X), pd.DataFrame(X), _create_sequence_from_ndarray(X, 1, 30)]
     names = [f'col_{i}' for i in range(n_col)]
     for x_1 in xxs:
         # test that method works even with free_raw_data=True
@@ -333,6 +359,9 @@ def test_add_features_from_different_sources():
         d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=False).construct()
         res_feature_names = [name for name in names]
         for idx, x_2 in enumerate(xxs, 2):
+            # Dataset.get_data does not support Sequence input.
+            if isinstance(x_1, lgb.Sequence) or isinstance(x_2, lgb.Sequence):
+                continue
             original_type = type(d1.get_data())
             d2 = lgb.Dataset(x_2, feature_name=names, free_raw_data=False).construct()
             d1.add_features_from(d2)