[dask][docs] initial setup for Dask docs (#3822)

* initial Dask docs

* fix MRO

* address review comments
This commit is contained in:
Nikita Titov 2021-01-25 05:58:52 +03:00 коммит произвёл GitHub
Родитель 98a85a83ce
Коммит 36322ceeae
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
12 изменённых файлов: 50 добавлений и 16 удалений

Просмотреть файл

@ -87,8 +87,6 @@ ML.NET (.NET/C#-package): https://github.com/dotnet/machinelearning
LightGBM.NET (.NET/C#-package): https://github.com/rca22/LightGBM.Net
Dask-LightGBM (distributed and parallel Python-package): https://github.com/dask/dask-lightgbm
Ruby gem: https://github.com/ankane/lightgbm
LightGBM4j (Java high-level binding): https://github.com/metarank/lightgbm4j

Просмотреть файл

@ -24,7 +24,7 @@ You may also ping a member of the core team according to the relevant area of ex
- `@chivee <https://github.com/chivee>`__ **Qiwei Ye** (C++ code / Python-package)
- `@btrotta <https://github.com/btrotta>`__ **Belinda Trotta** (C++ code)
- `@Laurae2 <https://github.com/Laurae2>`__ **Damien Soukhavong** (R-package)
- `@jameslamb <https://github.com/jameslamb>`__ **James Lamb** (R-package)
- `@jameslamb <https://github.com/jameslamb>`__ **James Lamb** (R-package / Dask-package)
- `@wxchan <https://github.com/wxchan>`__ **Wenxuan Chen** (Python-package)
- `@henry0312 <https://github.com/henry0312>`__ **Tsukasa Omoto** (Python-package)
- `@StrikerRUS <https://github.com/StrikerRUS>`__ **Nikita Titov** (Python-package)

Просмотреть файл

@ -7,7 +7,7 @@ Follow the `Quick Start <./Quick-Start.rst>`__ to know how to use LightGBM first
**List of external libraries in which LightGBM can be used in a distributed fashion**
- `Dask-LightGBM`_ allows to create ML workflow on Dask distributed data structures.
- `Dask API of LightGBM <./Python-API.rst#dask-api>`__ (formerly a separate package) allows you to create ML workflows on Dask distributed data structures.
- `MMLSpark`_ integrates LightGBM into the Apache Spark ecosystem.
`The following example`_ demonstrates how easily you can utilize the great power of Spark.
@ -134,8 +134,6 @@ Example
- `A simple parallel example`_
.. _Dask-LightGBM: https://github.com/dask/dask-lightgbm
.. _MMLSpark: https://aka.ms/spark
.. _The following example: https://github.com/Azure/mmlspark/blob/master/notebooks/samples/LightGBM%20-%20Quantile%20Regression%20for%20Drug%20Discovery.ipynb

Просмотреть файл

@ -33,6 +33,16 @@ Scikit-learn API
LGBMRegressor
LGBMRanker
Dask API
--------
.. autosummary::
:toctree: pythonapi/
DaskLGBMClassifier
DaskLGBMRegressor
DaskLGBMRanker
Callbacks
---------

Просмотреть файл

@ -39,7 +39,7 @@ INTERNAL_REF_REGEX = compile(r"(?P<url>\.\/.+)(?P<extension>\.rst)(?P<anchor>$|#
# -- mock out modules
MOCK_MODULES = ['numpy', 'scipy', 'scipy.sparse',
'sklearn', 'matplotlib', 'pandas', 'graphviz']
'sklearn', 'matplotlib', 'pandas', 'graphviz', 'dask', 'dask.distributed']
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = Mock()

Просмотреть файл

@ -183,12 +183,22 @@ Run ``python setup.py install --bit32``, if you want to use 32-bit version. All
If you get any errors during installation, or for any other reason, you may want to build the dynamic library from sources by any method you prefer (see `Installation Guide <https://github.com/microsoft/LightGBM/blob/master/docs/Installation-Guide.rst>`__) and then just run ``python setup.py install --precompile``.
Build Wheel File
****************
You can use ``python setup.py bdist_wheel`` instead of ``python setup.py install`` to build a wheel file and use it for installation later. This might be useful for systems with restricted or no network access.
Install Dask-package
''''''''''''''''''''
To install all additional dependencies required for the Dask-package, you can append ``[dask]`` to the LightGBM package name:
.. code:: sh
pip install lightgbm[dask]
Or replace ``python setup.py install`` with ``pip install -e .[dask]`` if you are installing the package from source files.
Troubleshooting
---------------

Просмотреть файл

@ -19,6 +19,10 @@ try:
plot_tree, create_tree_digraph)
except ImportError:
pass
try:
from .dask import DaskLGBMRegressor, DaskLGBMClassifier, DaskLGBMRanker
except ImportError:
pass
dir_path = os.path.dirname(os.path.realpath(__file__))
@ -31,5 +35,6 @@ __all__ = ['Dataset', 'Booster', 'CVBooster',
'register_logger',
'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'DaskLGBMRegressor', 'DaskLGBMClassifier', 'DaskLGBMRanker',
'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
'plot_importance', 'plot_split_value_histogram', 'plot_metric', 'plot_tree', 'create_tree_digraph']

Просмотреть файл

@ -105,3 +105,12 @@ except ImportError:
_LGBMAssertAllFinite = None
_LGBMCheckClassificationTargets = None
_LGBMComputeSampleWeight = None
"""dask"""
try:
from dask import array
from dask import dataframe
from dask.distributed import Client
DASK_INSTALLED = True
except ImportError:
DASK_INSTALLED = False

Просмотреть файл

@ -21,7 +21,8 @@ from dask import dataframe as dd
from dask import delayed
from dask.distributed import Client, default_client, get_worker, wait
from .basic import _ConfigAliases, _LIB, _log_warning, _safe_call
from .basic import _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError
from .compat import DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED
from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker
@ -393,6 +394,9 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr
class _LGBMModel:
def __init__(self):
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
def _fit(self, model_factory, X, y=None, sample_weight=None, group=None, client=None, **kwargs):
"""Docstring is inherited from the LGBMModel."""
@ -431,7 +435,7 @@ class _LGBMModel:
setattr(dest, name, attributes[name])
class DaskLGBMClassifier(_LGBMModel, LGBMClassifier):
class DaskLGBMClassifier(LGBMClassifier, _LGBMModel):
"""Distributed version of lightgbm.LGBMClassifier."""
def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
@ -479,7 +483,7 @@ class DaskLGBMClassifier(_LGBMModel, LGBMClassifier):
return self._to_local(LGBMClassifier)
class DaskLGBMRegressor(_LGBMModel, LGBMRegressor):
class DaskLGBMRegressor(LGBMRegressor, _LGBMModel):
"""Docstring is inherited from the lightgbm.LGBMRegressor."""
def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
@ -515,7 +519,7 @@ class DaskLGBMRegressor(_LGBMModel, LGBMRegressor):
return self._to_local(LGBMRegressor)
class DaskLGBMRanker(_LGBMModel, LGBMRanker):
class DaskLGBMRanker(LGBMRanker, _LGBMModel):
"""Docstring is inherited from the lightgbm.LGBMRanker."""
def fit(self, X, y=None, sample_weight=None, init_score=None, group=None, client=None, **kwargs):

Просмотреть файл

@ -334,7 +334,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
"xe_ndcg", "xe_ndcg_mart", "xendcg_mart"}
for obj_alias in _ConfigAliases.get("objective")):
if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for ranking cv.')
raise LightGBMError('scikit-learn is required for ranking cv')
# ranking task, split according to groups
group_info = np.array(full_data.get_group(), dtype=np.int32, copy=False)
flatted_group = np.repeat(range(len(group_info)), repeats=group_info)
@ -342,7 +342,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
folds = group_kfold.split(X=np.zeros(num_data), groups=flatted_group)
elif stratified:
if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for stratified cv.')
raise LightGBMError('scikit-learn is required for stratified cv')
skf = _LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed)
folds = skf.split(X=np.zeros(num_data), y=full_data.get_label())
else:

Просмотреть файл

@ -289,7 +289,7 @@ class LGBMModel(_LGBMModelBase):
and you should group grad and hess in this way as well.
"""
if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for this module')
raise LightGBMError('scikit-learn is required for lightgbm.sklearn')
self.boosting_type = boosting_type
self.objective = objective

Просмотреть файл

@ -344,7 +344,7 @@ if __name__ == "__main__":
extras_require={
'dask': [
'dask[array]>=2.0.0',
'dask[dataframe]>=2.0.0'
'dask[dataframe]>=2.0.0',
'dask[distributed]>=2.0.0',
'pandas',
],