Mirror of https://github.com/microsoft/LightGBM.git
[python-package] Create Dataset from multiple data files (#4089)
* [python-package] create Dataset from sampled data.
* [python-package] create Dataset from List[Sequence].
1. Use random access for data sampling
2. Support reading data from multiple input files
3. Read data in batches so there is no need to hold all data in memory
* [python-package] example: create Dataset from multiple HDF5 files.
* fix: revert is_class implementation for seq
* fix: unwanted memory view reference for seq
* fix: seq is_class accepts sklearn matrices
* fix: requirements for example
* fix: pycode
* feat: print static code linting stage
* fix: linting: avoid shell str regex conversion
* code style: doc style
* code style: isort
* fix ci dependency: h5py on windows
* [py] remove rm files in test seq
https://github.com/microsoft/LightGBM/pull/4089#discussion_r612929623
* docs(python): init_from_sample summary
https://github.com/microsoft/LightGBM/pull/4089#discussion_r612903389
* remove dataset dump sample data debugging code.
* remove typo fix.
Create separate PR for this.
* fix typo in src/c_api.cpp
Co-authored-by: James Lamb <jaylamb20@gmail.com>
* style(linting): py3 type hint for seq
* test(basic): os.path style path handling
* Revert "feat: print static code linting stage"
This reverts commit 10bd79f7f8.
* feat(python): sequence on validation set
* minor(python): comment
* minor(python): test option hint
* style(python): fix code linting
* style(python): add pydoc for ref_dataset
* doc(python): sequence
Co-authored-by: shiyu1994 <shiyu_k1994@qq.com>
* revert(python): sequence class abc
* chore(python): remove rm_files
* Remove useless static_assert.
* refactor: test_basic test for sequence.
* fix lint complaint.
* remove dataset._dump_text in sequence test.
* Fix reverting typo fix.
* Apply suggestions from code review
Co-authored-by: James Lamb <jaylamb20@gmail.com>
* Fix type hint, code and doc style.
* fix failing test_basic.
* Remove TODO about keep constant in sync with cpp.
* Install h5py only when running python-examples.
* Fix lint complaint.
* Apply suggestions from code review
Co-authored-by: James Lamb <jaylamb20@gmail.com>
* Doc fixes, remove unused params_str in __init_from_seqs.
* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
* Remove unnecessary conda install in windows ci script.
* Keep param as example in dataset_from_multi_hdf5.py
* Add _get_sample_count function to remove code duplication.
* Use batch_size parameter in generate_hdf.
* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
* Fix after applying suggestions.
* Fix test, check idx is instance of numbers.Integral.
* Update python-package/lightgbm/basic.py
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
* Expose Sequence class in Python-API doc.
* Handle Sequence object not having batch_size.
* Fix isort lint complaint.
* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
* Update docstring to mention Sequence as data input.
* Remove get_one_line in test_basic.py
* Make Sequence an abstract class.
* Reduce number of tests for test_sequence.
* Add c_api: LGBM_SampleCount, fix potential bug in LGBMSampleIndices.
* empty commit to trigger ci
* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
* Rename to LGBM_GetSampleCount, change LGBM_SampleIndices out_len to int32_t.
Also rename total_nrow to num_total_row in c_api.h for consistency.
* Doc about Sequence in docs/Python-Intro.rst.
* Fix: basic.py change LGBM_SampleIndices out_len to int32.
* Add create_valid test case with Dataset from Sequence.
* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
* Apply suggestions from code review
Co-authored-by: shiyu1994 <shiyu_k1994@qq.com>
* Remove no longer used DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT.
* Update python-package/lightgbm/basic.py
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
Co-authored-by: Willian Zhang <willian@willian.email>
Co-authored-by: Willian Z <Willian@Willian-Zhang.com>
Co-authored-by: James Lamb <jaylamb20@gmail.com>
Co-authored-by: shiyu1994 <shiyu_k1994@qq.com>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
This commit is contained in:
Parent: f37b0d463f
Commit: c359896e9b
@@ -234,8 +234,8 @@ import matplotlib\
matplotlib.use\(\"Agg\"\)\
' plot_example.py  # prevent interactive window mode
    sed -i'.bak' 's/graph.render(view=True)/graph.render(view=False)/' plot_example.py
    conda install -q -y -n $CONDA_ENV h5py ipywidgets notebook  # requirements for examples
    for f in *.py **/*.py; do python $f || exit -1; done  # run all examples
    cd $BUILD_DIRECTORY/examples/python-guide/notebooks
    conda install -q -y -n $CONDA_ENV ipywidgets notebook
    jupyter nbconvert --ExecutePreprocessor.timeout=180 --to notebook --execute --inplace *.ipynb || exit -1  # run all notebooks
fi
@@ -106,11 +106,11 @@ if (($env:TASK -eq "regular") -or (($env:APPVEYOR -eq "true") -and ($env:TASK -e
    cd $env:BUILD_SOURCESDIRECTORY/examples/python-guide
    @("import matplotlib", "matplotlib.use('Agg')") + (Get-Content "plot_example.py") | Set-Content "plot_example.py"
    (Get-Content "plot_example.py").replace('graph.render(view=True)', 'graph.render(view=False)') | Set-Content "plot_example.py"  # prevent interactive window mode
    conda install -q -y -n $env:CONDA_ENV h5py ipywidgets notebook
    foreach ($file in @(Get-ChildItem *.py)) {
        @("import sys, warnings", "warnings.showwarning = lambda message, category, filename, lineno, file=None, line=None: sys.stdout.write(warnings.formatwarning(message, category, filename, lineno, line))") + (Get-Content $file) | Set-Content $file
        python $file ; Check-Output $?
    }  # run all examples
    cd $env:BUILD_SOURCESDIRECTORY/examples/python-guide/notebooks
    conda install -q -y -n $env:CONDA_ENV ipywidgets notebook
    jupyter nbconvert --ExecutePreprocessor.timeout=180 --to notebook --execute --inplace *.ipynb ; Check-Output $?  # run all notebooks
}
@@ -12,6 +12,7 @@ Data Structure API
    Dataset
    Booster
    CVBooster
    Sequence

Training API
------------
@@ -39,6 +39,8 @@ The LightGBM Python module can load data from:

- LightGBM binary file

- LightGBM ``Sequence`` object(s)

The data is stored in a ``Dataset`` object.

Many of the examples in this page use functionality from ``numpy``. To run the examples, be sure to import ``numpy`` in your session.
@@ -69,6 +71,38 @@ Many of the examples in this page use functionality from ``numpy``. To run the e
    csr = scipy.sparse.csr_matrix((dat, (row, col)))
    train_data = lgb.Dataset(csr)

**Load from Sequence objects:**

We can implement the ``Sequence`` interface to read binary files. The following example shows reading an HDF5 file with ``h5py``.

.. code:: python

    import h5py

    class HDFSequence(lgb.Sequence):
        def __init__(self, hdf_dataset, batch_size):
            self.data = hdf_dataset
            self.batch_size = batch_size

        def __getitem__(self, idx):
            return self.data[idx]

        def __len__(self):
            return len(self.data)

    f = h5py.File('train.hdf5', 'r')
    train_data = lgb.Dataset(HDFSequence(f['X'], 8192), label=f['Y'][:])

Features of using the ``Sequence`` interface:

- Data sampling uses random access, so it does not need to go through the whole dataset
- Data is read in batches, which saves memory when constructing the ``Dataset`` object
- Supports creating ``Dataset`` from multiple data files

Please refer to the ``Sequence`` `API doc <./Python-API.rst#data-structure-api>`__.

`dataset_from_multi_hdf5.py <https://github.com/microsoft/LightGBM/blob/master/examples/python-guide/dataset_from_multi_hdf5.py>`__ is a detailed example.
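For the multi-file case mentioned in the list above, a minimal sketch of passing a list of ``Sequence`` objects to the ``Dataset`` constructor (it reuses the ``HDFSequence`` class defined above and assumes two HDF5 files with ``X`` and ``Y`` datasets; the file names are illustrative):

.. code:: python

    import h5py
    import numpy as np
    import lightgbm as lgb

    files = [h5py.File('train1.hdf5', 'r'), h5py.File('train2.hdf5', 'r')]
    seqs = [HDFSequence(f['X'], 8192) for f in files]          # one Sequence per file
    label = np.concatenate([f['Y'][:] for f in files])          # labels are loaded in full
    train_data = lgb.Dataset(seqs, label=label)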

**Saving Dataset into a LightGBM binary file will make loading faster:**

.. code:: python
@@ -61,3 +61,6 @@ Examples include:
  - Plot split value histogram
  - Plot one specified tree
  - Plot one specified tree with Graphviz
- [dataset_from_multi_hdf5.py](https://github.com/microsoft/LightGBM/blob/master/examples/python-guide/dataset_from_multi_hdf5.py)
  - Construct Dataset from multiple HDF5 files
  - Avoid loading all data into memory
@@ -0,0 +1,106 @@
import h5py
import numpy as np
import pandas as pd

import lightgbm as lgb


class HDFSequence(lgb.Sequence):
    def __init__(self, hdf_dataset, batch_size):
        """
        Construct a sequence object from HDF5 with required interface.

        Parameters
        ----------
        hdf_dataset : h5py.Dataset
            Dataset in HDF5 file.
        batch_size : int
            Size of a batch. When reading data to construct lightgbm Dataset, each read reads batch_size rows.
        """
        # We can also open HDF5 file once and get access to
        self.data = hdf_dataset
        self.batch_size = batch_size

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)


def create_dataset_from_multiple_hdf(input_flist, batch_size):
    data = []
    ylist = []
    for f in input_flist:
        f = h5py.File(f, 'r')
        data.append(HDFSequence(f['X'], batch_size))
        ylist.append(f['Y'][:])

    params = {
        'bin_construct_sample_cnt': 200000,
        'max_bin': 255,
    }
    y = np.concatenate(ylist)
    dataset = lgb.Dataset(data, label=y, params=params)
    # With the binary dataset created, we can use either the Python API or the cmdline version to train.
    #
    # Note: in order to create exactly the same dataset as the one created in simple_example.py, we need
    # to modify simple_example.py to pass a numpy array instead of a pandas DataFrame to the Dataset constructor.
    # The reason is that DataFrame column names will be used in the Dataset. For a DataFrame with Int64Index
    # as columns, the Dataset will use column names like ["0", "1", "2", ...]. For a numpy array, column names
    # use the default ones assigned in C++ code (dataset_loader.cpp), like ["Column_0", "Column_1", ...].
    dataset.save_binary('regression.train.from_hdf.bin')


def save2hdf(input_data, fname, batch_size):
    """Store numpy array to HDF5 file.

    Please note the chunk size settings in the implementation for I/O performance optimization.
    """
    with h5py.File(fname, 'w') as f:
        for name, data in input_data.items():
            nrow, ncol = data.shape
            if ncol == 1:
                # Y has a single column and we read it in a single shot. So store it as a 1-d array.
                chunk = (nrow,)
                data = data.values.flatten()
            else:
                # We use random access for data sampling when creating a LightGBM Dataset from Sequence.
                # When accessing any element in an HDF5 chunk, it's read entirely.
                # To save I/O for sampling, we should keep the total number of chunks much larger than the sample count.
                # Here we are just creating a chunk size that matches batch_size.
                #
                # Also note that the data is stored in row-major order to avoid an extra copy when passing to
                # lightgbm Dataset.
                chunk = (batch_size, ncol)
            f.create_dataset(name, data=data, chunks=chunk, compression='lzf')


def generate_hdf(input_fname, output_basename, batch_size):
    # Save to 2 HDF5 files for demonstration.
    df = pd.read_csv(input_fname, header=None, sep='\t')

    mid = len(df) // 2
    df1 = df.iloc[:mid]
    df2 = df.iloc[mid:]

    # We can store multiple datasets inside a single HDF5 file.
    # X and Y are separated so that the best chunk size can be chosen for each when loading data.
    fname1 = f'{output_basename}1.h5'
    fname2 = f'{output_basename}2.h5'
    save2hdf({'Y': df1.iloc[:, :1], 'X': df1.iloc[:, 1:]}, fname1, batch_size)
    save2hdf({'Y': df2.iloc[:, :1], 'X': df2.iloc[:, 1:]}, fname2, batch_size)

    return [fname1, fname2]


def main():
    batch_size = 64
    output_basename = 'regression'
    hdf_files = generate_hdf('../regression/regression.train', output_basename, batch_size)

    create_dataset_from_multiple_hdf(hdf_files, batch_size=batch_size)


if __name__ == '__main__':
    main()
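Once the example above has produced ``regression.train.from_hdf.bin``, the binary file can be consumed like any other data source. A minimal sketch of the Python-API route (the parameter values here are illustrative, not taken from the example):

.. code:: python

    import lightgbm as lgb

    # Loading a LightGBM binary file skips re-binning, so Dataset construction is fast.
    train_data = lgb.Dataset('regression.train.from_hdf.bin')
    params = {'objective': 'regression', 'metric': 'l2'}
    booster = lgb.train(params, train_data, num_boost_round=10)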
@@ -53,6 +53,32 @@ LIGHTGBM_C_EXPORT const char* LGBM_GetLastError();
 */
LIGHTGBM_C_EXPORT int LGBM_RegisterLogCallback(void (*callback)(const char*));

/*!
 * \brief Get number of samples based on parameters and total number of rows of data.
 * \param num_total_row Number of total rows
 * \param parameters Additional parameters, namely, ``bin_construct_sample_cnt`` is used to calculate returned value
 * \param[out] out Number of samples. This value is used to pre-allocate memory to hold sample indices when calling ``LGBM_SampleIndices``
 * \return 0 when succeed, -1 when failure happens
 */
LIGHTGBM_C_EXPORT int LGBM_GetSampleCount(int32_t num_total_row,
                                          const char* parameters,
                                          int* out);

/*!
 * \brief Create sample indices for total number of rows.
 * \note
 * You should pre-allocate memory for ``out``, you can get its length by ``LGBM_GetSampleCount``.
 * \param num_total_row Number of total rows
 * \param parameters Additional parameters, namely, ``bin_construct_sample_cnt`` and ``data_random_seed`` are used to produce the output
 * \param[out] out Created indices, type is int32_t
 * \param[out] out_len Number of indices. This may be less than the one returned by ``LGBM_GetSampleCount``
 * \return 0 when succeed, -1 when failure happens
 */
LIGHTGBM_C_EXPORT int LGBM_SampleIndices(int32_t num_total_row,
                                         const char* parameters,
                                         void* out,
                                         int32_t* out_len);

// --- start Dataset interface
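The two declarations above form a pre-allocation contract: ``LGBM_GetSampleCount`` tells the caller how large the index buffer must be, and ``LGBM_SampleIndices`` fills it. A rough sketch of honoring that contract from Python with raw ``ctypes`` (the shared-library path and parameter string are placeholders; the packaged helpers in ``basic.py`` below do the same thing):

.. code:: python

    import ctypes

    import numpy as np

    lib = ctypes.CDLL('lib_lightgbm.so')  # placeholder path to the compiled LightGBM library

    num_total_row = 100000
    params = b'bin_construct_sample_cnt=20000'

    # Step 1: ask how many indices will be produced, to size the output buffer.
    sample_cnt = ctypes.c_int(0)
    lib.LGBM_GetSampleCount(ctypes.c_int32(num_total_row), params, ctypes.byref(sample_cnt))

    # Step 2: pre-allocate the int32 index buffer and let the library fill it.
    indices = np.empty(sample_cnt.value, dtype=np.int32)
    out_len = ctypes.c_int32(0)
    lib.LGBM_SampleIndices(
        ctypes.c_int32(num_total_row),
        params,
        indices.ctypes.data_as(ctypes.c_void_p),
        ctypes.byref(out_len),
    )
    indices = indices[:out_len.value]  # may be shorter than the pre-allocated size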

/*!
@@ -5,7 +5,7 @@ Contributors: https://github.com/microsoft/LightGBM/graphs/contributors.
"""
import os

from .basic import Booster, Dataset, register_logger
from .basic import Booster, Dataset, Sequence, register_logger
from .callback import early_stopping, print_evaluation, record_evaluation, reset_parameter
from .engine import CVBooster, cv, train

@@ -29,7 +29,7 @@ if os.path.isfile(os.path.join(dir_path, 'VERSION.txt')):
    with open(os.path.join(dir_path, 'VERSION.txt')) as version_file:
        __version__ = version_file.read().strip()

__all__ = ['Dataset', 'Booster', 'CVBooster',
__all__ = ['Dataset', 'Booster', 'CVBooster', 'Sequence',
           'register_logger',
           'train', 'cv',
           'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
@@ -1,5 +1,6 @@
# coding: utf-8
"""Wrapper for C API of LightGBM."""
import abc
import ctypes
import json
import os

@@ -9,7 +10,7 @@ from copy import deepcopy
from functools import wraps
from logging import Logger
from tempfile import NamedTemporaryFile
from typing import Any, Dict, List, Set, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

import numpy as np
import scipy.sparse

@@ -17,6 +18,18 @@ import scipy.sparse
from .compat import PANDAS_INSTALLED, concat, dt_DataTable, is_dtype_sparse, pd_DataFrame, pd_Series
from .libpath import find_lib_path

ZERO_THRESHOLD = 1e-35


def _get_sample_count(total_nrow: int, params: str):
    sample_cnt = ctypes.c_int(0)
    _safe_call(_LIB.LGBM_GetSampleCount(
        ctypes.c_int32(total_nrow),
        c_str(params),
        ctypes.byref(sample_cnt),
    ))
    return sample_cnt.value


class _DummyLogger:
    def info(self, msg):
@@ -593,6 +606,67 @@ def _load_pandas_categorical(file_name=None, model_str=None):
    return None


class Sequence(abc.ABC):
    """
    Generic data access interface.

    Object should support the following operations:

    .. code-block::

        # Get total row number.
        >>> len(seq)
        # Random access by row index. Used for data sampling.
        >>> seq[10]
        # Range data access. Used to read data in batch when constructing Dataset.
        >>> seq[0:100]
        # Optionally specify batch_size to control range data read size.
        >>> seq.batch_size

    - With random access, **data sampling does not need to go through all data**.
    - With range data access, there's **no need to read all data into memory, thus reducing memory usage**.

    Attributes
    ----------
    batch_size : int
        Default size of a batch.
    """

    batch_size = 4096  # Defaults to read 4K rows in each batch.

    @abc.abstractmethod
    def __getitem__(self, idx: Union[int, slice]) -> np.ndarray:
        """Return data for given row index.

        A basic implementation should look like this:

        .. code-block:: python

            if isinstance(idx, numbers.Integral):
                return self.__get_one_line__(idx)
            elif isinstance(idx, slice):
                return np.stack([self.__get_one_line__(i) for i in range(idx.start, idx.stop)])
            else:
                raise TypeError(f"Sequence index must be integer or slice, got {type(idx)}")

        Parameters
        ----------
        idx : int, slice[int]
            Item index.

        Returns
        -------
        result : numpy 1-D array, numpy 2-D array
            1-D array if idx is int, 2-D array if idx is slice.
        """
        raise NotImplementedError("Sub-classes of lightgbm.Sequence must implement __getitem__()")

    @abc.abstractmethod
    def __len__(self) -> int:
        """Return row count of this sequence."""
        raise NotImplementedError("Sub-classes of lightgbm.Sequence must implement __len__()")


class _InnerPredictor:
    """_InnerPredictor of LightGBM.
@@ -1057,7 +1131,7 @@ class Dataset:

        Parameters
        ----------
        data : string, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse or list of numpy arrays
        data : string, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays
            Data source of Dataset.
            If string, it represents the path to txt file.
        label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)

@@ -1113,6 +1187,7 @@ class Dataset:
        self.feature_penalty = None
        self.monotone_constraints = None
        self.version = 0
        self._start_row = 0  # Used when pushing rows one by one.

    def __del__(self):
        try:

@@ -1120,6 +1195,149 @@ class Dataset:
        except AttributeError:
            pass

    def _create_sample_indices(self, total_nrow: int) -> np.ndarray:
        """Get an array of randomly chosen indices from this ``Dataset``.

        Indices are sampled without replacement.

        Parameters
        ----------
        total_nrow : int
            Total number of rows to sample from.
            If this value is greater than the value of parameter ``bin_construct_sample_cnt``, only ``bin_construct_sample_cnt`` indices will be used.
            If Dataset has multiple input data, this should be the sum of rows of every file.

        Returns
        -------
        indices : numpy array
            Indices for sampled data.
        """
        param_str = param_dict_to_str(self.get_params())
        sample_cnt = _get_sample_count(total_nrow, param_str)
        indices = np.empty(sample_cnt, dtype=np.int32)
        ptr_data, _, _ = c_int_array(indices)
        actual_sample_cnt = ctypes.c_int32(0)

        _safe_call(_LIB.LGBM_SampleIndices(
            ctypes.c_int32(total_nrow),
            c_str(param_str),
            ptr_data,
            ctypes.byref(actual_sample_cnt),
        ))
        return indices[:actual_sample_cnt.value]

    def _init_from_ref_dataset(self, total_nrow: int, ref_dataset: 'Dataset') -> 'Dataset':
        """Create dataset from a reference dataset.

        Parameters
        ----------
        total_nrow : int
            Number of rows expected to add to dataset.
        ref_dataset : Dataset
            Reference dataset to extract meta from.

        Returns
        -------
        self : Dataset
            Constructed Dataset object.
        """
        self.handle = ctypes.c_void_p()
        _safe_call(_LIB.LGBM_DatasetCreateByReference(
            ref_dataset,
            ctypes.c_int64(total_nrow),
            ctypes.byref(self.handle),
        ))
        return self

    def _init_from_sample(
        self,
        sample_data: List[np.ndarray],
        sample_indices: List[np.ndarray],
        sample_cnt: int,
        total_nrow: int,
    ) -> "Dataset":
        """Create Dataset from sampled data structures.

        Parameters
        ----------
        sample_data : list of numpy arrays
            Sample data for each column.
        sample_indices : list of numpy arrays
            Sample data row index for each column.
        sample_cnt : int
            Number of samples.
        total_nrow : int
            Total number of rows for all input files.

        Returns
        -------
        self : Dataset
            Constructed Dataset object.
        """
        ncol = len(sample_indices)
        assert len(sample_data) == ncol, "#sample data column != #column indices"

        for i in range(ncol):
            if sample_data[i].dtype != np.double:
                raise ValueError(f"sample_data[{i}] type {sample_data[i].dtype} is not double")
            if sample_indices[i].dtype != np.int32:
                raise ValueError(f"sample_indices[{i}] type {sample_indices[i].dtype} is not int32")

        # c type: double**
        # each double* element points to start of each column of sample data.
        sample_col_ptr = (ctypes.POINTER(ctypes.c_double) * ncol)()
        # c type int**
        # each int* points to start of indices for each column
        indices_col_ptr = (ctypes.POINTER(ctypes.c_int32) * ncol)()
        for i in range(ncol):
            sample_col_ptr[i] = c_float_array(sample_data[i])[0]
            indices_col_ptr[i] = c_int_array(sample_indices[i])[0]

        num_per_col = np.array([len(d) for d in sample_indices], dtype=np.int32)
        num_per_col_ptr, _, _ = c_int_array(num_per_col)

        self.handle = ctypes.c_void_p()
        params_str = param_dict_to_str(self.get_params())
        _safe_call(_LIB.LGBM_DatasetCreateFromSampledColumn(
            ctypes.cast(sample_col_ptr, ctypes.POINTER(ctypes.POINTER(ctypes.c_double))),
            ctypes.cast(indices_col_ptr, ctypes.POINTER(ctypes.POINTER(ctypes.c_int32))),
            ctypes.c_int32(ncol),
            num_per_col_ptr,
            ctypes.c_int32(sample_cnt),
            ctypes.c_int32(total_nrow),
            c_str(params_str),
            ctypes.byref(self.handle),
        ))
        return self

    def _push_rows(self, data: np.ndarray) -> 'Dataset':
        """Add rows to Dataset.

        Parameters
        ----------
        data : numpy 2-D array
            New data to add to the Dataset.

        Returns
        -------
        self : Dataset
            Dataset object.
        """
        nrow, ncol = data.shape
        data = data.reshape(data.size)
        data_ptr, data_type, _ = c_float_array(data)

        _safe_call(_LIB.LGBM_DatasetPushRows(
            self.handle,
            data_ptr,
            data_type,
            ctypes.c_int32(nrow),
            ctypes.c_int32(ncol),
            ctypes.c_int32(self._start_row),
        ))
        self._start_row += nrow
        return self

    def get_params(self):
        """Get the used parameters in the Dataset.

@@ -1265,8 +1483,15 @@ class Dataset:
            self.__init_from_csc(data, params_str, ref_dataset)
        elif isinstance(data, np.ndarray):
            self.__init_from_np2d(data, params_str, ref_dataset)
        elif isinstance(data, list) and len(data) > 0 and all(isinstance(x, np.ndarray) for x in data):
            self.__init_from_list_np2d(data, params_str, ref_dataset)
        elif isinstance(data, list) and len(data) > 0:
            if all(isinstance(x, np.ndarray) for x in data):
                self.__init_from_list_np2d(data, params_str, ref_dataset)
            elif all(isinstance(x, Sequence) for x in data):
                self.__init_from_seqs(data, ref_dataset)
            else:
                raise TypeError('Data list can only be of ndarray or Sequence')
        elif isinstance(data, Sequence):
            self.__init_from_seqs([data], ref_dataset)
        elif isinstance(data, dt_DataTable):
            self.__init_from_np2d(data.to_numpy(), params_str, ref_dataset)
        else:

@@ -1294,6 +1519,77 @@ class Dataset:
        # set feature names
        return self.set_feature_name(feature_name)

    def __yield_row_from(self, seqs: List[Sequence], indices: Iterable[int]):
        offset = 0
        seq_id = 0
        seq = seqs[seq_id]
        for row_id in indices:
            assert row_id >= offset, "sample indices are expected to be monotonic"
            while row_id >= offset + len(seq):
                offset += len(seq)
                seq_id += 1
                seq = seqs[seq_id]
            id_in_seq = row_id - offset
            row = seq[id_in_seq]
            yield row if row.flags['OWNDATA'] else row.copy()

    def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """Sample data from seqs.

        Mimics behavior in c_api.cpp:LGBM_DatasetCreateFromMats()

        Returns
        -------
        sampled_rows, sampled_row_indices
        """
        indices = self._create_sample_indices(total_nrow)

        # Select sampled rows, transpose to column order.
        sampled = np.array([row for row in self.__yield_row_from(seqs, indices)])
        sampled = sampled.T

        filtered = []
        filtered_idx = []
        sampled_row_range = np.arange(len(indices), dtype=np.int32)
        for col in sampled:
            col_predicate = (np.abs(col) > ZERO_THRESHOLD) | np.isnan(col)
            filtered_col = col[col_predicate]
            filtered_row_idx = sampled_row_range[col_predicate]

            filtered.append(filtered_col)
            filtered_idx.append(filtered_row_idx)

        return filtered, filtered_idx

    def __init_from_seqs(self, seqs: List[Sequence], ref_dataset: Optional['Dataset'] = None):
        """
        Initialize data from list of Sequence objects.

        Sequence: Generic Data Access Object
        Supports random access and access by batch if properly defined by user

        Data scheme uniformity is trusted, not checked
        """
        total_nrow = sum(len(seq) for seq in seqs)

        # create validation dataset from ref_dataset
        if ref_dataset is not None:
            self._init_from_ref_dataset(total_nrow, ref_dataset)
        else:
            param_str = param_dict_to_str(self.get_params())
            sample_cnt = _get_sample_count(total_nrow, param_str)

            sample_data, col_indices = self.__sample(seqs, total_nrow)
            self._init_from_sample(sample_data, col_indices, sample_cnt, total_nrow)

        for seq in seqs:
            nrow = len(seq)
            batch_size = getattr(seq, 'batch_size', None) or Sequence.batch_size
            for start in range(0, nrow, batch_size):
                end = min(start + batch_size, nrow)
                self._push_rows(seq[start:end])
        return self

    def __init_from_np2d(self, mat, params_str, ref_dataset):
        """Initialize data from a 2-D numpy matrix."""
        if len(mat.shape) != 2:

@@ -1477,7 +1773,7 @@ class Dataset:

        Parameters
        ----------
        data : string, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse or list of numpy arrays
        data : string, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays
            Data source of Dataset.
            If string, it represents the path to txt file.
        label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
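As a rough illustration of the column-wise layout that ``__sample`` above hands to ``_init_from_sample`` (a sketch with made-up numbers, not part of the change): zeros are filtered out per column, NaNs are kept, and each column also records which sampled rows its remaining values came from.

.. code:: python

    import numpy as np

    ZERO_THRESHOLD = 1e-35  # same constant as in basic.py above

    # Two sampled rows, three columns; transpose to iterate column by column.
    sampled = np.array([[0.0, 1.5, np.nan],
                        [2.0, 0.0, 3.0]]).T
    row_range = np.arange(2, dtype=np.int32)

    filtered, filtered_idx = [], []
    for col in sampled:
        keep = (np.abs(col) > ZERO_THRESHOLD) | np.isnan(col)
        filtered.append(col[keep])            # non-zero / NaN values of this column
        filtered_idx.append(row_range[keep])  # local row indices of those values

    # filtered     -> [array([2.]), array([1.5]), array([nan, 3.])]
    # filtered_idx -> [array([1], dtype=int32), array([0], dtype=int32), array([0, 1], dtype=int32)]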
@@ -898,6 +898,51 @@ int LGBM_RegisterLogCallback(void (*callback)(const char*)) {
  API_END();
}

static inline int SampleCount(int32_t total_nrow, const Config& config) {
  return static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
}

static inline std::vector<int32_t> CreateSampleIndices(int32_t total_nrow, const Config& config) {
  Random rand(config.data_random_seed);
  int sample_cnt = SampleCount(total_nrow, config);
  return rand.Sample(total_nrow, sample_cnt);
}

int LGBM_GetSampleCount(int32_t num_total_row,
                        const char* parameters,
                        int* out) {
  API_BEGIN();
  if (out == nullptr) {
    Log::Fatal("LGBM_GetSampleCount output is nullptr");
  }
  auto param = Config::Str2Map(parameters);
  Config config;
  config.Set(param);

  *out = SampleCount(num_total_row, config);
  API_END();
}

int LGBM_SampleIndices(int32_t num_total_row,
                       const char* parameters,
                       void* out,
                       int32_t* out_len) {
  // This API is to keep python binding's behavior the same with C++ implementation.
  // Sample count, random seed etc. should be provided in parameters.
  API_BEGIN();
  if (out == nullptr) {
    Log::Fatal("LGBM_SampleIndices output is nullptr");
  }
  auto param = Config::Str2Map(parameters);
  Config config;
  config.Set(param);

  auto sample_indices = CreateSampleIndices(num_total_row, config);
  memcpy(out, sample_indices.data(), sizeof(int32_t) * sample_indices.size());
  *out_len = static_cast<int32_t>(sample_indices.size());
  API_END();
}

int LGBM_DatasetCreateFromFile(const char* filename,
                               const char* parameters,
                               const DatasetHandle reference,

@@ -1038,7 +1083,6 @@ int LGBM_DatasetCreateFromMat(const void* data,
                              out);
}

int LGBM_DatasetCreateFromMats(int32_t nmat,
                               const void** data,
                               int data_type,
@@ -1,4 +1,6 @@
# coding: utf-8
import filecmp
import numbers
import os

import numpy as np

@@ -89,6 +91,109 @@ def test_basic(tmp_path):
                  bst.predict, tname)


class NumpySequence(lgb.Sequence):
    def __init__(self, ndarray, batch_size):
        self.ndarray = ndarray
        self.batch_size = batch_size

    def __getitem__(self, idx):
        # The simple implementation is just a single "return self.ndarray[idx]"
        # The following is for demo and testing purpose.
        if isinstance(idx, numbers.Integral):
            return self.ndarray[idx]
        elif isinstance(idx, slice):
            if not (idx.step is None or idx.step == 1):
                raise NotImplementedError("No need to implement, caller will not set step by now")
            return self.ndarray[idx.start:idx.stop]
        else:
            raise TypeError(f"Sequence Index must be an integer/list/slice, got {type(idx)}")

    def __len__(self):
        return len(self.ndarray)


def _create_sequence_from_ndarray(data, num_seq, batch_size):
    if num_seq == 1:
        return NumpySequence(data, batch_size)

    nrow = data.shape[0]
    seqs = []
    seq_size = nrow // num_seq
    for start in range(0, nrow, seq_size):
        end = min(start + seq_size, nrow)
        seq = NumpySequence(data[start:end], batch_size)
        seqs.append(seq)
    return seqs


@pytest.mark.parametrize('sample_count', [11, 100, None])
@pytest.mark.parametrize('batch_size', [3, None])
@pytest.mark.parametrize('include_0_and_nan', [False, True])
@pytest.mark.parametrize('num_seq', [1, 3])
def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq):
    params = {'bin_construct_sample_cnt': sample_count}

    nrow = 50
    half_nrow = nrow // 2
    ncol = 11
    data = np.arange(nrow * ncol, dtype=np.float64).reshape((nrow, ncol))

    if include_0_and_nan:
        # whole col
        data[:, 0] = 0
        data[:, 1] = np.nan

        # half col
        data[:half_nrow, 3] = 0
        data[:half_nrow, 2] = np.nan

        data[half_nrow:-2, 4] = 0
        data[:half_nrow, 4] = np.nan

    X = data[:, :-1]
    Y = data[:, -1]

    npy_bin_fname = os.path.join(tmpdir, 'data_from_npy.bin')
    seq_bin_fname = os.path.join(tmpdir, 'data_from_seq.bin')

    # Create dataset from numpy array directly.
    ds = lgb.Dataset(X, label=Y, params=params)
    ds.save_binary(npy_bin_fname)

    # Create dataset using Sequence.
    seqs = _create_sequence_from_ndarray(X, num_seq, batch_size)
    seq_ds = lgb.Dataset(seqs, label=Y, params=params)
    seq_ds.save_binary(seq_bin_fname)

    assert filecmp.cmp(npy_bin_fname, seq_bin_fname)

    # Test for validation set.
    # Select some random rows as valid data.
    rng = np.random.default_rng()  # Pass integer to set seed when needed.
    valid_idx = (rng.random(10) * nrow).astype(np.int32)
    valid_data = data[valid_idx, :]
    valid_X = valid_data[:, :-1]
    valid_Y = valid_data[:, -1]

    valid_npy_bin_fname = os.path.join(tmpdir, 'valid_data_from_npy.bin')
    valid_seq_bin_fname = os.path.join(tmpdir, 'valid_data_from_seq.bin')
    valid_seq2_bin_fname = os.path.join(tmpdir, 'valid_data_from_seq2.bin')

    valid_ds = lgb.Dataset(valid_X, label=valid_Y, params=params, reference=ds)
    valid_ds.save_binary(valid_npy_bin_fname)

    # From Dataset constructor, with dataset from numpy array.
    valid_seqs = _create_sequence_from_ndarray(valid_X, num_seq, batch_size)
    valid_seq_ds = lgb.Dataset(valid_seqs, label=valid_Y, params=params, reference=ds)
    valid_seq_ds.save_binary(valid_seq_bin_fname)
    assert filecmp.cmp(valid_npy_bin_fname, valid_seq_bin_fname)

    # From Dataset.create_valid, with dataset from sequence.
    valid_seq_ds2 = seq_ds.create_valid(valid_seqs, label=valid_Y, params=params)
    valid_seq_ds2.save_binary(valid_seq2_bin_fname)
    assert filecmp.cmp(valid_npy_bin_fname, valid_seq2_bin_fname)


def test_chunked_dataset():
    X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1,
                                                        random_state=2)