Mirror of https://github.com/microsoft/LightGBM.git

Commit d9a96c90cb (parent 3c7e7e0b7e)

* adding pred_contrib support
* add tests
* linting
* remove raw_score
* add pred kwargs
* faster tests
* Apply suggestions from code review
* changes to tests
* Update tests/python_package_test/test_dask.py

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
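In practical terms, this commit lets the Dask estimators forward LightGBM's prediction keyword arguments (raw_score, pred_leaf, pred_contrib) to the underlying local model. A minimal usage sketch, assuming a local dask.distributed cluster and the `import lightgbm.dask as dlgbm` convention used in the tests; cluster size and data shapes are illustrative:

# Hypothetical end-to-end sketch of what this commit enables; assumes a local
# dask.distributed cluster, not this PR's exact test harness.
import dask.array as da
from distributed import Client, LocalCluster

import lightgbm.dask as dlgbm

cluster = LocalCluster(n_workers=2)
client = Client(cluster)

# 1000 rows, 10 features, chunked across workers
X = da.random.random((1000, 10), chunks=(100, 10))
y = (da.random.random((1000,), chunks=(100,)) > 0.5).astype(int)

model = dlgbm.DaskLGBMClassifier(n_estimators=10)
model.fit(X, y, client=client)

# New in this commit: pred_contrib is forwarded through to the local model,
# so each row yields per-feature contributions plus a trailing base value.
contribs = model.predict(X, pred_contrib=True).compute()
print(contribs.shape)  # (1000, 11) for binary: n_features + 1 columns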
python-package/lightgbm/dask.py

@@ -280,18 +280,30 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
     return results[0]
 
 
-def _predict_part(part, model, proba, **kwargs):
+def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, **kwargs):
     data = part.values if isinstance(part, pd.DataFrame) else part
 
     if data.shape[0] == 0:
         result = np.array([])
-    elif proba:
-        result = model.predict_proba(data, **kwargs)
+    elif pred_proba:
+        result = model.predict_proba(
+            data,
+            raw_score=raw_score,
+            pred_leaf=pred_leaf,
+            pred_contrib=pred_contrib,
+            **kwargs
+        )
     else:
-        result = model.predict(data, **kwargs)
+        result = model.predict(
+            data,
+            raw_score=raw_score,
+            pred_leaf=pred_leaf,
+            pred_contrib=pred_contrib,
+            **kwargs
+        )
 
     if isinstance(part, pd.DataFrame):
-        if proba:
+        if pred_proba or pred_contrib:
             result = pd.DataFrame(result, index=part.index)
         else:
             result = pd.Series(result, index=part.index, name='predictions')
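The key change in `_predict_part` is the DataFrame branch: `pred_contrib`, like `predict_proba`, produces a 2-D result (one column per feature plus a trailing base value), which must be wrapped in a `pd.DataFrame` rather than a `pd.Series`. A local sketch of that dispatch, using scikit-learn toy data; the dataset and its parameters are illustrative:

# Local illustration of the per-partition dispatch above: 2-D outputs
# (predict_proba or pred_contrib) become DataFrames, 1-D outputs Series.
import pandas as pd
from lightgbm import LGBMClassifier
from sklearn.datasets import make_blobs

X, y = make_blobs(n_samples=200, centers=2, n_features=5, random_state=42)
part = pd.DataFrame(X)

model = LGBMClassifier(n_estimators=10).fit(X, y)

labels = model.predict(part.values)                       # 1-D: class labels
contribs = model.predict(part.values, pred_contrib=True)  # 2-D: (200, 5 + 1)

print(pd.Series(labels, index=part.index, name='predictions').shape)  # (200,)
print(pd.DataFrame(contribs, index=part.index).shape)                 # (200, 6)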
@@ -299,7 +311,8 @@ def _predict_part(part, model, proba, **kwargs):
     return result
 
 
-def _predict(model, data, proba=False, dtype=np.float32, **kwargs):
+def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pred_contrib=False,
+             dtype=np.float32, **kwargs):
     """Inner predict routine.
 
     Parameters
@@ -307,20 +320,42 @@ def _predict(model, data, proba=False, dtype=np.float32, **kwargs):
     model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
     data : dask array of shape = [n_samples, n_features]
         Input feature matrix.
-    proba : bool
-        Should method return results of predict_proba (proba == True) or predict (proba == False).
+    pred_proba : bool, optional (default=False)
+        Should method return results of ``predict_proba`` (``pred_proba=True``) or ``predict`` (``pred_proba=False``).
+    pred_leaf : bool, optional (default=False)
+        Whether to predict leaf index.
+    pred_contrib : bool, optional (default=False)
+        Whether to predict feature contributions.
     dtype : np.dtype
         Dtype of the output.
-    kwargs : other parameters passed to predict or predict_proba method
+    kwargs : dict
+        Other parameters passed to ``predict`` or ``predict_proba`` method.
     """
     if isinstance(data, dd._Frame):
-        return data.map_partitions(_predict_part, model=model, proba=proba, **kwargs).values
+        return data.map_partitions(
+            _predict_part,
+            model=model,
+            raw_score=raw_score,
+            pred_proba=pred_proba,
+            pred_leaf=pred_leaf,
+            pred_contrib=pred_contrib,
+            **kwargs
+        ).values
     elif isinstance(data, da.Array):
-        if proba:
+        if pred_proba:
             kwargs['chunks'] = (data.chunks[0], (model.n_classes_,))
         else:
             kwargs['drop_axis'] = 1
-        return data.map_blocks(_predict_part, model=model, proba=proba, dtype=dtype, **kwargs)
+        return data.map_blocks(
+            _predict_part,
+            model=model,
+            raw_score=raw_score,
+            pred_proba=pred_proba,
+            pred_leaf=pred_leaf,
+            pred_contrib=pred_contrib,
+            dtype=dtype,
+            **kwargs
+        )
     else:
         raise TypeError('Data must be either Dask array or dataframe. Got %s.' % str(type(data)))
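The `chunks`/`drop_axis` bookkeeping on the `da.Array` branch exists because `map_blocks` cannot infer the output shape of an opaque function: `predict_proba` widens each block from `n_features` to `n_classes_` columns, while plain `predict` collapses each block to a 1-D vector. A standalone sketch of the two declarations, with dummy functions standing in for the model:

# Why the chunks/drop_axis declarations above are needed: map_blocks must be
# told how an opaque per-block function changes block shapes.
import dask.array as da
import numpy as np

X = da.random.random((1000, 10), chunks=(250, 10))
n_classes = 3

# predict_proba-like: each (rows, 10) block becomes (rows, n_classes),
# so the new column chunking is declared explicitly.
proba = X.map_blocks(
    lambda b: np.zeros((b.shape[0], n_classes)),
    chunks=(X.chunks[0], (n_classes,)),
    dtype=np.float32,
)
print(proba.shape)  # (1000, 3)

# predict-like: each block collapses to a 1-D vector, so axis 1 is dropped.
labels = X.map_blocks(
    lambda b: np.zeros(b.shape[0]),
    drop_axis=1,
    dtype=np.float32,
)
print(labels.shape)  # (1000,)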
@@ -370,7 +405,7 @@ class DaskLGBMClassifier(_LGBMModel, LGBMClassifier):
 
     def predict_proba(self, X, **kwargs):
         """Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
-        return _predict(self.to_local(), X, proba=True, **kwargs)
+        return _predict(self.to_local(), X, pred_proba=True, **kwargs)
     predict_proba.__doc__ = LGBMClassifier.predict_proba.__doc__
 
     def to_local(self):
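From the caller's side the `proba` to `pred_proba` rename is invisible. Continuing the hypothetical sketch from the top of this page:

# predict_proba still returns a lazy Dask array; only the internal flag
# name changed, and its shape metadata comes from the `chunks` declaration.
probs = model.predict_proba(X)   # dask array of shape (1000, 2)
print(probs.compute()[:3])       # materialize the first rows as numpy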
tests/python_package_test/test_dask.py

@@ -235,6 +235,55 @@ def test_classifier(output, centers, client, listen_port):
     client.close()
 
 
+@pytest.mark.parametrize('output', data_output)
+@pytest.mark.parametrize('centers', data_centers)
+def test_classifier_pred_contrib(output, centers, client, listen_port):
+    X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers)
+
+    dask_classifier = dlgbm.DaskLGBMClassifier(
+        time_out=5,
+        local_listen_port=listen_port,
+        tree_learner='data',
+        n_estimators=10,
+        num_leaves=10
+    )
+    dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
+    preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True).compute()
+
+    local_classifier = lightgbm.LGBMClassifier(
+        n_estimators=10,
+        num_leaves=10
+    )
+    local_classifier.fit(X, y, sample_weight=w)
+    local_preds_with_contrib = local_classifier.predict(X, pred_contrib=True)
+
+    if output == 'scipy_csr_matrix':
+        preds_with_contrib = np.array(preds_with_contrib.todense())
+
+    # shape depends on whether it is binary or multiclass classification
+    num_features = dask_classifier.n_features_
+    num_classes = dask_classifier.n_classes_
+    if num_classes == 2:
+        expected_num_cols = num_features + 1
+    else:
+        expected_num_cols = (num_features + 1) * num_classes
+
+    # * shape depends on whether it is binary or multiclass classification
+    # * matrix for binary classification is of the form [feature_contrib, base_value],
+    #   for multi-class it's [feat_contrib_class1, base_value_class1, feat_contrib_class2, base_value_class2, etc.]
+    # * contrib outputs for distributed training are different than from local training, so we can just test
+    #   that the output has the right shape and base values are in the right position
+    assert preds_with_contrib.shape[1] == expected_num_cols
+    assert preds_with_contrib.shape == local_preds_with_contrib.shape
+
+    if num_classes == 2:
+        assert len(np.unique(preds_with_contrib[:, num_features])) == 1
+    else:
+        for i in range(num_classes):
+            base_value_col = num_features * (i + 1) + i
+            assert len(np.unique(preds_with_contrib[:, base_value_col])) == 1
+
+
 def test_training_does_not_fail_on_port_conflicts(client):
     _, _, _, dX, dy, dw = _create_data('classification', output='array')
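The multiclass base-value index in the classifier test above can be checked by hand: class `i` occupies a block of `num_features + 1` columns with its base value last, so `num_features * (i + 1) + i` equals `i * (num_features + 1) + num_features`. A tiny sanity check:

# Sanity-check of the multiclass column layout asserted in the test above.
num_features, num_classes = 4, 3
for i in range(num_classes):
    base_value_col = num_features * (i + 1) + i
    assert base_value_col == i * (num_features + 1) + num_features
    print(i, base_value_col)  # 0 -> 4, 1 -> 9, 2 -> 14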
@@ -315,6 +364,37 @@ def test_regressor(output, client, listen_port):
     client.close()
 
 
+@pytest.mark.parametrize('output', data_output)
+def test_regressor_pred_contrib(output, client, listen_port):
+    X, y, w, dX, dy, dw = _create_data('regression', output=output)
+
+    dask_regressor = dlgbm.DaskLGBMRegressor(
+        time_out=5,
+        local_listen_port=listen_port,
+        tree_learner='data',
+        n_estimators=10,
+        num_leaves=10
+    )
+    dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client)
+    preds_with_contrib = dask_regressor.predict(dX, pred_contrib=True).compute()
+
+    local_regressor = lightgbm.LGBMRegressor(
+        n_estimators=10,
+        num_leaves=10
+    )
+    local_regressor.fit(X, y, sample_weight=w)
+    local_preds_with_contrib = local_regressor.predict(X, pred_contrib=True)
+
+    if output == "scipy_csr_matrix":
+        preds_with_contrib = np.array(preds_with_contrib.todense())
+
+    # contrib outputs for distributed training are different than from local training, so we can just test
+    # that the output has the right shape and base values are in the right position
+    num_features = dX.shape[1]
+    assert preds_with_contrib.shape[1] == num_features + 1
+    assert preds_with_contrib.shape == local_preds_with_contrib.shape
+
+
 @pytest.mark.parametrize('output', data_output)
 @pytest.mark.parametrize('alpha', [.1, .5, .9])
 def test_regressor_quantile(output, client, listen_port, alpha):
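A related local property, not asserted by this PR's tests but useful when reading them: each row of a `pred_contrib` output (feature contributions plus the trailing base value) sums to the model's raw-score prediction. A small non-distributed sketch; the dataset parameters are illustrative:

# Local check that contributions + base value sum to the raw-score
# prediction, which is why the trailing column matters.
import numpy as np
from lightgbm import LGBMRegressor
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=200, n_features=5, random_state=42)
model = LGBMRegressor(n_estimators=10).fit(X, y)

contribs = model.predict(X, pred_contrib=True)  # shape (200, 6)
raw = model.predict(X, raw_score=True)          # shape (200,)
assert np.allclose(contribs.sum(axis=1), raw)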