Mirror of https://github.com/microsoft/LightGBM.git
[python-package] Correctly recognize LGBMClassifier(num_class=2, objective="multiclass") as multiclass classification (#6524)
This commit is contained in:
Parent: 3d026629ab
Commit: f8ec57b8eb
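For context, a minimal sketch (not part of the diff) of the behavior this change establishes: an LGBMClassifier trained with objective="multiclass" and num_class=2 is now treated as a multiclass model, so the booster's per-class probability columns are passed through unchanged instead of being post-processed as single-column binary output.

    import numpy as np
    import lightgbm as lgb

    # Toy two-class data, purely illustrative.
    X = np.linspace(0, 10, 300).reshape(100, 3)
    y = np.concatenate([np.zeros(40), np.ones(60)])

    # An explicitly multiclass objective with only two classes: the case this PR fixes.
    clf = lgb.LGBMClassifier(objective="multiclass", num_class=2, verbose=-1).fit(X, y)

    proba = clf.predict_proba(X)
    assert proba.shape == (100, 2)  # one softmax column per class, passed through unchanged
    np.testing.assert_allclose(proba.sum(axis=1), 1.0, rtol=1e-6)  # rows sum to 1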
python-package/lightgbm/basic.py
@@ -157,6 +157,8 @@ _LGBM_SetFieldType = Union[
 
 ZERO_THRESHOLD = 1e-35
 
+_MULTICLASS_OBJECTIVES = {"multiclass", "multiclassova", "multiclass_ova", "ova", "ovr", "softmax"}
+
 
 def _is_zero(x: float) -> bool:
     return -ZERO_THRESHOLD <= x <= ZERO_THRESHOLD
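The new set collects the spellings LightGBM accepts for its two multiclass objectives ("multiclass"/"softmax" and "multiclassova"/"multiclass_ova"/"ova"/"ovr"), so a single membership test recognizes any of them. A quick illustration (the name is private, imported here only for demonstration):

    from lightgbm.basic import _MULTICLASS_OBJECTIVES

    for objective in ("multiclass", "softmax", "ovr", "binary"):
        print(objective, objective in _MULTICLASS_OBJECTIVES)
    # multiclass True
    # softmax True
    # ovr True
    # binary False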
python-package/lightgbm/sklearn.py
@@ -10,6 +10,7 @@ import numpy as np
 import scipy.sparse
 
 from .basic import (
+    _MULTICLASS_OBJECTIVES,
     Booster,
     Dataset,
     LightGBMError,
@@ -467,7 +468,7 @@ def _extract_evaluation_meta_data(
     # It's possible, for example, to pass 3 eval sets through `eval_set`,
     # but only 1 init_score through `eval_init_score`.
     #
-    # This if-else accounts for that possiblity.
+    # This if-else accounts for that possibility.
     if len(collection) > i:
         return collection[i]
     else:
@@ -1011,7 +1012,7 @@ class LGBMModel(_LGBMModelBase):
                 f"match the input. Model n_features_ is {self._n_features} and "
                 f"input n_features is {n_features}"
             )
-        # retrive original params that possibly can be used in both training and prediction
+        # retrieve original params that possibly can be used in both training and prediction
         # and then overwrite them (considering aliases) with params that were passed directly in prediction
         predict_params = self._process_params(stage="predict")
         for alias in _ConfigAliases.get_by_alias(
@@ -1251,7 +1252,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
            eval_metric_list = [eval_metric]
        else:
            eval_metric_list = []
-       if self._n_classes > 2:
+       if self.__is_multiclass:
            for index, metric in enumerate(eval_metric_list):
                if metric in {"logloss", "binary_logloss"}:
                    eval_metric_list[index] = "multi_logloss"
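The shown context covers only the logloss remap; the tests below also assert a "binary_error" → "multi_error" remap, which presumably sits just past the truncated hunk. A standalone sketch of the overall remapping this branch enables (hypothetical helper, assuming an analogous error branch):

    def remap_binary_metrics(eval_metric_list):
        # When the model is multiclass, binary metric names requested by the
        # user are silently upgraded to their multiclass counterparts.
        for index, metric in enumerate(eval_metric_list):
            if metric in {"logloss", "binary_logloss"}:
                eval_metric_list[index] = "multi_logloss"
            elif metric in {"error", "binary_error"}:
                eval_metric_list[index] = "multi_error"
        return eval_metric_list

    print(remap_binary_metrics(["binary_logloss", "binary_error", "auc"]))
    # ['multi_logloss', 'multi_error', 'auc']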
@@ -1361,7 +1362,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
                "Returning raw scores instead."
            )
            return result
-       elif self._n_classes > 2 or raw_score or pred_leaf or pred_contrib:  # type: ignore [operator]
+       elif self.__is_multiclass or raw_score or pred_leaf or pred_contrib:  # type: ignore [operator]
            return result
        else:
            return np.vstack((1.0 - result, result)).transpose()
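The final `else` branch is the binary path: the booster returns a single column of p(y = 1) per sample, and the wrapper rebuilds the two-column probability matrix from it. That reconstruction is exactly why a num_class=2 multiclass booster, which already returns one column per class, must take the `elif` branch instead. A small numeric illustration of the binary stacking:

    import numpy as np

    result = np.array([0.2, 0.7, 0.9])  # binary booster output: p(y = 1) per sample
    proba = np.vstack((1.0 - result, result)).transpose()
    print(proba)
    # [[0.8 0.2]
    #  [0.3 0.7]
    #  [0.1 0.9]]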
@@ -1389,6 +1390,11 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
            raise LGBMNotFittedError("No classes found. Need to call fit beforehand.")
        return self._n_classes
 
+    @property
+    def __is_multiclass(self) -> bool:
+        """:obj:`bool`: Indicator of whether the classifier is used for multiclass."""
+        return self._n_classes > 2 or (isinstance(self._objective, str) and self._objective in _MULTICLASS_OBJECTIVES)
+
 
 class LGBMRanker(LGBMModel):
     """LightGBM ranker.
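Two details worth noting. First, the decision rule: a model is multiclass when it has more than two classes, or when a two-class model was trained with an explicitly multiclass objective alias, which is precisely the case in the commit title. Second, the double leading underscore means Python name-mangles the property to `_LGBMClassifier__is_multiclass`, keeping it an internal detail rather than public API. A standalone restatement of the rule (hypothetical helper, not part of the package):

    from lightgbm.basic import _MULTICLASS_OBJECTIVES

    def is_multiclass(n_classes: int, objective) -> bool:
        # >2 classes is always multiclass; a 2-class model also counts when an
        # explicitly multiclass objective alias was requested (the PR's fix).
        return n_classes > 2 or (isinstance(objective, str) and objective in _MULTICLASS_OBJECTIVES)

    assert is_multiclass(2, "multiclass")   # the case this PR fixes
    assert not is_multiclass(2, "binary")
    assert is_multiclass(5, None)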
tests/python_package_test/test_sklearn.py
@@ -719,6 +719,25 @@ def test_predict():
     with pytest.raises(AssertionError):
         np.testing.assert_allclose(res_engine, res_sklearn_params)
 
+    # Test multiclass binary classification
+    num_samples = 100
+    num_classes = 2
+    X_train = np.linspace(start=0, stop=10, num=num_samples * 3).reshape(num_samples, 3)
+    y_train = np.concatenate([np.zeros(int(num_samples / 2 - 10)), np.ones(int(num_samples / 2 + 10))])
+
+    gbm = lgb.train({"objective": "multiclass", "num_class": num_classes, "verbose": -1}, lgb.Dataset(X_train, y_train))
+    clf = lgb.LGBMClassifier(objective="multiclass", num_classes=num_classes).fit(X_train, y_train)
+
+    res_engine = gbm.predict(X_train)
+    res_sklearn = clf.predict_proba(X_train)
+
+    assert res_engine.shape == (num_samples, num_classes)
+    assert res_sklearn.shape == (num_samples, num_classes)
+    np.testing.assert_allclose(res_engine, res_sklearn)
+
+    res_class_sklearn = clf.predict(X_train)
+    np.testing.assert_allclose(res_class_sklearn, y_train)
+
 
 def test_predict_with_params_from_init():
     X, y = load_iris(return_X_y=True)
@@ -1035,6 +1054,20 @@ def test_metrics():
     assert len(gbm.evals_result_["training"]) == 1
     assert "binary_logloss" in gbm.evals_result_["training"]
 
+    # the evaluation metric changes to multiclass metric even num classes is 2 for multiclass objective
+    gbm = lgb.LGBMClassifier(objective="multiclass", num_classes=2, **params).fit(
+        eval_metric="binary_logloss", **params_fit
+    )
+    assert len(gbm._evals_result["training"]) == 1
+    assert "multi_logloss" in gbm.evals_result_["training"]
+
+    # the evaluation metric changes to multiclass metric even num classes is 2 for ovr objective
+    gbm = lgb.LGBMClassifier(objective="ovr", num_classes=2, **params).fit(eval_metric="binary_error", **params_fit)
+    assert gbm.objective_ == "ovr"
+    assert len(gbm.evals_result_["training"]) == 2
+    assert "multi_logloss" in gbm.evals_result_["training"]
+    assert "multi_error" in gbm.evals_result_["training"]
+
 
 def test_multiple_eval_metrics():
     X, y = load_breast_cancer(return_X_y=True)
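For reference, `evals_result_` is a nested mapping of dataset name to metric name to per-iteration metric values, so the assertions above are checking which metric keys were recorded. A self-contained sketch of inspecting it (illustrative data; the eval set reuses the training arrays, which is why it is named "training"):

    import numpy as np
    import lightgbm as lgb

    X = np.linspace(0, 10, 150).reshape(50, 3)
    y = np.arange(50) % 2
    gbm = lgb.LGBMClassifier(objective="ovr", num_classes=2, verbose=-1).fit(X, y, eval_set=[(X, y)])

    # e.g. {"training": {"multi_logloss": [...], ...}}
    for dataset_name, metrics in gbm.evals_result_.items():
        for metric_name, values in metrics.items():
            print(dataset_name, metric_name, values[-1])  # last-iteration value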