[python][sklearn] add `n_estimators_` and `n_iter_` post-fit attributes (#4753)

* add n_estimators_ and n_iter_ post-fit attributes

* address review comments
This commit is contained in:
Nikita Titov 2021-11-05 20:29:49 +03:00 коммит произвёл GitHub
Родитель 99f0f3ecf1
Коммит aab212a782
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 33 добавлений и 0 удалений

Просмотреть файл

@ -847,6 +847,28 @@ class LGBMModel(_LGBMModelBase):
raise LGBMNotFittedError('No objective found. Need to call fit beforehand.')
return self._objective
@property
def n_estimators_(self) -> int:
""":obj:`int`: True number of boosting iterations performed.
This might be less than parameter ``n_estimators`` if early stopping was enabled or
if boosting stopped early due to limits on complexity like ``min_gain_to_split``.
"""
if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No n_estimators found. Need to call fit beforehand.')
return self._Booster.current_iteration()
@property
def n_iter_(self) -> int:
""":obj:`int`: True number of boosting iterations performed.
This might be less than parameter ``n_estimators`` if early stopping was enabled or
if boosting stopped early due to limits on complexity like ``min_gain_to_split``.
"""
if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No n_iter found. Need to call fit beforehand.')
return self._Booster.current_iteration()
@property
def booster_(self):
"""Booster: The underlying Booster of this model."""

Просмотреть файл

@ -1158,6 +1158,17 @@ def test_continue_training_with_model():
assert gbm.evals_result_['valid_0']['multi_logloss'][-1] < init_gbm.evals_result_['valid_0']['multi_logloss'][-1]
def test_actual_number_of_trees():
X = [[1, 2, 3], [1, 2, 3]]
y = [1, 1]
n_estimators = 5
gbm = lgb.LGBMRegressor(n_estimators=n_estimators).fit(X, y)
assert gbm.n_estimators == n_estimators
assert gbm.n_estimators_ == 1
assert gbm.n_iter_ == 1
np.testing.assert_array_equal(gbm.predict(np.array(X) * 10), y)
# sklearn < 0.22 requires passing "attributes" argument
@pytest.mark.skipif(sk_version < parse_version('0.22'), reason='scikit-learn version is less than 0.22')
def test_check_is_fitted():