зеркало из https://github.com/microsoft/LightGBM.git
[python-package] Fix misdetected objective after multiple calls to `LGBMClassifier.fit` (#6002)
This commit is contained in:
Родитель
501ce1cb63
Коммит
5e592fe6ff
|
@ -1103,6 +1103,8 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
|
|||
|
||||
self._classes = self._le.classes_
|
||||
self._n_classes = len(self._classes) # type: ignore[arg-type]
|
||||
if self.objective is None:
|
||||
self._objective = None
|
||||
|
||||
# adjust eval metrics to match whether binary or multiclass
|
||||
# classification is being performed
|
||||
|
|
|
@ -1561,3 +1561,20 @@ def test_ranking_minimally_works_with_all_all_accepted_data_types(X_type, y_type
|
|||
)
|
||||
preds = model.predict(X)
|
||||
assert spearmanr(preds, y).correlation >= 0.99
|
||||
|
||||
|
||||
def test_classifier_fit_detects_classes_every_time():
|
||||
rng = np.random.default_rng(seed=123)
|
||||
nrows = 1000
|
||||
ncols = 20
|
||||
|
||||
X = rng.standard_normal(size=(nrows, ncols))
|
||||
y_bin = (rng.random(size=nrows) <= .3).astype(np.float64)
|
||||
y_multi = rng.integers(4, size=nrows)
|
||||
|
||||
model = lgb.LGBMClassifier(verbose=-1)
|
||||
for _ in range(2):
|
||||
model.fit(X, y_multi)
|
||||
assert model.objective_ == "multiclass"
|
||||
model.fit(X, y_bin)
|
||||
assert model.objective_ == "binary"
|
||||
|
|
Загрузка…
Ссылка в новой задаче