From 5e592fe6ff2b6eed83dd77942aab8e464768235c Mon Sep 17 00:00:00 2001
From: david-cortes
Date: Tue, 12 Sep 2023 04:53:36 +0200
Subject: [PATCH] [python-package] Fix misdetected objective after multiple calls to `LGBMClassifier.fit` (#6002)

---
 python-package/lightgbm/sklearn.py        |  2 ++
 tests/python_package_test/test_sklearn.py | 17 +++++++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index 7e909342c..c71c233df 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -1103,6 +1103,8 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
 
         self._classes = self._le.classes_
         self._n_classes = len(self._classes)  # type: ignore[arg-type]
+        if self.objective is None:
+            self._objective = None
 
         # adjust eval metrics to match whether binary or multiclass
         # classification is being performed
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py
index e41719845..2247c9a51 100644
--- a/tests/python_package_test/test_sklearn.py
+++ b/tests/python_package_test/test_sklearn.py
@@ -1561,3 +1561,20 @@ def test_ranking_minimally_works_with_all_all_accepted_data_types(X_type, y_type
     )
     preds = model.predict(X)
     assert spearmanr(preds, y).correlation >= 0.99
+
+
+def test_classifier_fit_detects_classes_every_time():
+    rng = np.random.default_rng(seed=123)
+    nrows = 1000
+    ncols = 20
+
+    X = rng.standard_normal(size=(nrows, ncols))
+    y_bin = (rng.random(size=nrows) <= .3).astype(np.float64)
+    y_multi = rng.integers(4, size=nrows)
+
+    model = lgb.LGBMClassifier(verbose=-1)
+    for _ in range(2):
+        model.fit(X, y_multi)
+        assert model.objective_ == "multiclass"
+        model.fit(X, y_bin)
+        assert model.objective_ == "binary"