[python] support Dataset.get_data for Sequence input. (#4472)

* [python] support Dataset.get_data for Sequence input.

* Tweaks according to review comments.

* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Add test cases.

* fix import order in test_basic.py

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
This commit is contained in:
Chen Yufei 2021-07-31 04:49:13 +08:00 коммит произвёл GitHub
Родитель 2370961ae0
Коммит 1d21d1ad4c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 43 добавлений и 6 удалений

Просмотреть файл

@ -641,9 +641,12 @@ class Sequence(abc.ABC):
.. code-block:: python
if isinstance(idx, numbers.Integral):
return self.__get_one_line__(idx)
return self._get_one_line(idx)
elif isinstance(idx, slice):
return np.stack(self.__get_one_line__(i) for i in range(idx.start, idx.stop))
return np.stack([self._get_one_line(i) for i in range(idx.start, idx.stop)])
elif isinstance(idx, list):
# Only required if using ``Dataset.get_data()``.
return np.array([self._get_one_line(i) for i in idx])
else:
raise TypeError(f"Sequence index must be integer or slice, got {type(idx).__name__}")
@ -1515,7 +1518,8 @@ class Dataset:
# set feature names
return self.set_feature_name(feature_name)
def __yield_row_from(self, seqs: List[Sequence], indices: Iterable[int]):
@staticmethod
def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]):
offset = 0
seq_id = 0
seq = seqs[seq_id]
@ -1541,7 +1545,7 @@ class Dataset:
indices = self._create_sample_indices(total_nrow)
# Select sampled rows, transpose to column order.
sampled = np.array([row for row in self.__yield_row_from(seqs, indices)])
sampled = np.array([row for row in self._yield_row_from_seqlist(seqs, indices)])
sampled = sampled.T
filtered = []
@ -2236,7 +2240,7 @@ class Dataset:
Returns
-------
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, list of numpy arrays or None
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays or None
Raw data used in the Dataset construction.
"""
if self.handle is None:
@ -2250,6 +2254,10 @@ class Dataset:
self.data = self.data.iloc[self.used_indices].copy()
elif isinstance(self.data, dt_DataTable):
self.data = self.data[self.used_indices, :]
elif isinstance(self.data, Sequence):
self.data = self.data[self.used_indices]
elif isinstance(self.data, list) and len(self.data) > 0 and all(isinstance(x, Sequence) for x in self.data):
self.data = np.array([row for row in self._yield_row_from_seqlist(self.data, self.used_indices)])
else:
_log_warning(f"Cannot subset {type(self.data).__name__} type of raw data.\n"
"Returning original raw data")

Просмотреть файл

@ -1,6 +1,7 @@
# coding: utf-8
import filecmp
import numbers
import types
from pathlib import Path
import numpy as np
@ -194,6 +195,31 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq):
assert filecmp.cmp(valid_npy_bin_fname, valid_seq2_bin_fname)
def test_sequence_get_data():
nrow = 20
ncol = 11
data = np.arange(nrow * ncol, dtype=np.float64).reshape((nrow, ncol))
X = data[:, :-1]
Y = data[:, -1]
seqs = _create_sequence_from_ndarray(X, 2, 6)
seq_ds = lgb.Dataset(seqs, label=Y, params=None, free_raw_data=False)
seq_ds.construct()
assert seqs == seq_ds.get_data()
# This is a hack to add test coverage in get_data.
used_indices = [0, 5, 11, 15]
ref_data = types.SimpleNamespace()
ref_data.data = seqs
seq_ds.need_slice = True
seq_ds.reference = ref_data
seq_ds.used_indices = used_indices
assert (X[used_indices] == seq_ds.get_data()).all()
def test_chunked_dataset():
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1,
random_state=2)
@ -313,7 +339,7 @@ def test_add_features_from_different_sources():
n_row = 100
n_col = 5
X = np.random.random((n_row, n_col))
xxs = [X, sparse.csr_matrix(X), pd.DataFrame(X)]
xxs = [X, sparse.csr_matrix(X), pd.DataFrame(X), _create_sequence_from_ndarray(X, 1, 30)]
names = [f'col_{i}' for i in range(n_col)]
for x_1 in xxs:
# test that method works even with free_raw_data=True
@ -333,6 +359,9 @@ def test_add_features_from_different_sources():
d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=False).construct()
res_feature_names = [name for name in names]
for idx, x_2 in enumerate(xxs, 2):
# Dataset.get_data does not support Sequence input.
if isinstance(x_1, lgb.Sequence) or isinstance(x_2, lgb.Sequence):
continue
original_type = type(d1.get_data())
d2 = lgb.Dataset(x_2, feature_name=names, free_raw_data=False).construct()
d1.add_features_from(d2)