зеркало из https://github.com/microsoft/LightGBM.git
[python][sklearn] add `n_estimators_` and `n_iter_` post-fit attributes (#4753)
* add n_estimators_ and n_iter_ post-fit attributes * address review comments
This commit is contained in:
Родитель
99f0f3ecf1
Коммит
aab212a782
|
@ -847,6 +847,28 @@ class LGBMModel(_LGBMModelBase):
|
|||
raise LGBMNotFittedError('No objective found. Need to call fit beforehand.')
|
||||
return self._objective
|
||||
|
||||
@property
|
||||
def n_estimators_(self) -> int:
|
||||
""":obj:`int`: True number of boosting iterations performed.
|
||||
|
||||
This might be less than parameter ``n_estimators`` if early stopping was enabled or
|
||||
if boosting stopped early due to limits on complexity like ``min_gain_to_split``.
|
||||
"""
|
||||
if not self.__sklearn_is_fitted__():
|
||||
raise LGBMNotFittedError('No n_estimators found. Need to call fit beforehand.')
|
||||
return self._Booster.current_iteration()
|
||||
|
||||
@property
|
||||
def n_iter_(self) -> int:
|
||||
""":obj:`int`: True number of boosting iterations performed.
|
||||
|
||||
This might be less than parameter ``n_estimators`` if early stopping was enabled or
|
||||
if boosting stopped early due to limits on complexity like ``min_gain_to_split``.
|
||||
"""
|
||||
if not self.__sklearn_is_fitted__():
|
||||
raise LGBMNotFittedError('No n_iter found. Need to call fit beforehand.')
|
||||
return self._Booster.current_iteration()
|
||||
|
||||
@property
|
||||
def booster_(self):
|
||||
"""Booster: The underlying Booster of this model."""
|
||||
|
|
|
@ -1158,6 +1158,17 @@ def test_continue_training_with_model():
|
|||
assert gbm.evals_result_['valid_0']['multi_logloss'][-1] < init_gbm.evals_result_['valid_0']['multi_logloss'][-1]
|
||||
|
||||
|
||||
def test_actual_number_of_trees():
|
||||
X = [[1, 2, 3], [1, 2, 3]]
|
||||
y = [1, 1]
|
||||
n_estimators = 5
|
||||
gbm = lgb.LGBMRegressor(n_estimators=n_estimators).fit(X, y)
|
||||
assert gbm.n_estimators == n_estimators
|
||||
assert gbm.n_estimators_ == 1
|
||||
assert gbm.n_iter_ == 1
|
||||
np.testing.assert_array_equal(gbm.predict(np.array(X) * 10), y)
|
||||
|
||||
|
||||
# sklearn < 0.22 requires passing "attributes" argument
|
||||
@pytest.mark.skipif(sk_version < parse_version('0.22'), reason='scikit-learn version is less than 0.22')
|
||||
def test_check_is_fitted():
|
||||
|
|
Загрузка…
Ссылка в новой задаче