[python-package] Fix misdetected objective after multiple calls to `LGBMClassifier.fit` (#6002)

This commit is contained in:
david-cortes 2023-09-12 04:53:36 +02:00 коммит произвёл GitHub
Родитель 501ce1cb63
Коммит 5e592fe6ff
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 19 добавлений и 0 удалений

Просмотреть файл

@ -1103,6 +1103,8 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
self._classes = self._le.classes_
self._n_classes = len(self._classes) # type: ignore[arg-type]
if self.objective is None:
self._objective = None
# adjust eval metrics to match whether binary or multiclass
# classification is being performed

Просмотреть файл

@ -1561,3 +1561,20 @@ def test_ranking_minimally_works_with_all_all_accepted_data_types(X_type, y_type
)
preds = model.predict(X)
assert spearmanr(preds, y).correlation >= 0.99
def test_classifier_fit_detects_classes_every_time():
rng = np.random.default_rng(seed=123)
nrows = 1000
ncols = 20
X = rng.standard_normal(size=(nrows, ncols))
y_bin = (rng.random(size=nrows) <= .3).astype(np.float64)
y_multi = rng.integers(4, size=nrows)
model = lgb.LGBMClassifier(verbose=-1)
for _ in range(2):
model.fit(X, y_multi)
assert model.objective_ == "multiclass"
model.fit(X, y_bin)
assert model.objective_ == "binary"