Mirror of https://github.com/microsoft/LightGBM.git
[python] fix class_weight (#2199)
* fixed class_weight
* fixed lint
* added test
* hotfix
Parent: 7d03ced388
Commit: b6f6578368
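For orientation before the diff: the commit makes class_weight survive label encoding (so dict weights keyed by the original class labels, including string labels, resolve correctly) and adds per-validation-set handling of eval_class_weight. A minimal usage sketch, mirroring the new test added in this commit; the specific weight values are only illustrative:

import lightgbm as lgb
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

X, y = load_digits(10, True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 'balanced' or a dict keyed by the original class labels both work after this fix.
clf = lgb.LGBMClassifier(n_estimators=10, class_weight='balanced', silent=True)
clf.fit(X_train, y_train,
        eval_set=[(X_train, y_train), (X_test, y_test)],
        # one entry per eval set; dicts are keyed by the original labels as well
        eval_class_weight=[None, {1: 10, 4: 20}],
        verbose=False)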
python-package/lightgbm/sklearn.py
@@ -10,7 +10,7 @@ from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase,
                      LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase,
                      _LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength,
                      _LGBMAssertAllFinite, _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight,
-                     argc_, range_, string_type, DataFrame, DataTable)
+                     argc_, range_, zip_, string_type, DataFrame, DataTable)
 from .engine import train
@@ -320,6 +320,8 @@ class LGBMModel(_LGBMModelBase):
         self._other_params = {}
         self._objective = objective
         self.class_weight = class_weight
+        self._class_weight = None
+        self._class_map = None
         self._n_features = None
         self._classes = None
         self._n_classes = None
@@ -529,8 +531,10 @@ class LGBMModel(_LGBMModelBase):
         else:
             _X, _y = X, y
 
-        if self.class_weight is not None:
-            class_sample_weight = _LGBMComputeSampleWeight(self.class_weight, y)
+        if self._class_weight is None:
+            self._class_weight = self.class_weight
+        if self._class_weight is not None:
+            class_sample_weight = _LGBMComputeSampleWeight(self._class_weight, y)
             if sample_weight is None or len(sample_weight) == 0:
                 sample_weight = class_sample_weight
             else:
@@ -547,7 +551,7 @@ class LGBMModel(_LGBMModelBase):
         valid_sets = []
         if eval_set is not None:
 
-            def _get_meta_data(collection, i):
+            def _get_meta_data(collection, name, i):
                 if collection is None:
                     return None
                 elif isinstance(collection, list):
@@ -555,8 +559,7 @@ class LGBMModel(_LGBMModelBase):
                 elif isinstance(collection, dict):
                     return collection.get(i, None)
                 else:
-                    raise TypeError('eval_sample_weight, eval_class_weight, eval_init_score, and eval_group '
-                                    'should be dict or list')
+                    raise TypeError('{} should be dict or list'.format(name))
 
             if isinstance(eval_set, tuple):
                 eval_set = [eval_set]
@@ -565,16 +568,18 @@ class LGBMModel(_LGBMModelBase):
                 if valid_data[0] is X and valid_data[1] is y:
                     valid_set = train_set
                 else:
-                    valid_weight = _get_meta_data(eval_sample_weight, i)
-                    if _get_meta_data(eval_class_weight, i) is not None:
-                        valid_class_sample_weight = _LGBMComputeSampleWeight(_get_meta_data(eval_class_weight, i),
-                                                                             valid_data[1])
+                    valid_weight = _get_meta_data(eval_sample_weight, 'eval_sample_weight', i)
+                    valid_class_weight = _get_meta_data(eval_class_weight, 'eval_class_weight', i)
+                    if valid_class_weight is not None:
+                        if isinstance(valid_class_weight, dict) and self._class_map is not None:
+                            valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
+                        valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
                         if valid_weight is None or len(valid_weight) == 0:
                             valid_weight = valid_class_sample_weight
                         else:
                             valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
-                    valid_init_score = _get_meta_data(eval_init_score, i)
-                    valid_group = _get_meta_data(eval_group, i)
+                    valid_init_score = _get_meta_data(eval_init_score, 'eval_init_score', i)
+                    valid_group = _get_meta_data(eval_group, 'eval_group', i)
                     valid_set = _construct_dataset(valid_data[0], valid_data[1],
                                                    valid_weight, valid_init_score, valid_group, params)
                 valid_sets.append(valid_set)
@@ -750,6 +755,9 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
         _LGBMCheckClassificationTargets(y)
         self._le = _LGBMLabelEncoder().fit(y)
         _y = self._le.transform(y)
+        self._class_map = dict(zip_(self._le.classes_, self._le.transform(self._le.classes_)))
+        if isinstance(self.class_weight, dict):
+            self._class_weight = {self._class_map[k]: v for k, v in self.class_weight.items()}
 
         self._classes = self._le.classes_
         self._n_classes = len(self._classes)
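The key piece added to LGBMClassifier.fit above is _class_map: it records how the label encoder turns the original class labels into the integer codes LightGBM trains on, and dict-style class_weight / eval_class_weight are re-keyed through it before _LGBMComputeSampleWeight is applied. A standalone sketch of that re-keying, using scikit-learn's LabelEncoder directly rather than the wrapper's internal alias:

from sklearn.preprocessing import LabelEncoder

y = ['cat', 'dog', 'dog', 'fish']                    # original string labels
class_weight = {'cat': 1.0, 'dog': 2.0, 'fish': 5.0}  # keyed by original labels

le = LabelEncoder().fit(y)
# same construction as self._class_map in the diff above
class_map = dict(zip(le.classes_, le.transform(le.classes_)))
# re-key the user-supplied dict so it matches the encoded targets
encoded_weight = {class_map[k]: v for k, v in class_weight.items()}
print(encoded_weight)  # {0: 1.0, 1: 2.0, 2: 5.0}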
tests/python_package_test/test_sklearn.py
@@ -1,5 +1,6 @@
 # coding: utf-8
 # pylint: skip-file
+import itertools
 import math
 import os
 import unittest
@@ -615,3 +616,37 @@ class TestSklearn(unittest.TestCase):
                       'verbose': False, 'early_stopping_rounds': 5}
         gbm = lgb.LGBMRegressor(**params).fit(**params_fit)
         np.testing.assert_array_equal(gbm.evals_result_['training']['l2'], np.nan)
+
+    def test_class_weight(self):
+        X, y = load_digits(10, True)
+        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+        y_train_str = y_train.astype('str')
+        y_test_str = y_test.astype('str')
+        gbm = lgb.LGBMClassifier(n_estimators=10, class_weight='balanced', silent=True)
+        gbm.fit(X_train, y_train,
+                eval_set=[(X_train, y_train), (X_test, y_test), (X_test, y_test),
+                          (X_test, y_test), (X_test, y_test)],
+                eval_class_weight=['balanced', None, 'balanced', {1: 10, 4: 20}, {5: 30, 2: 40}],
+                verbose=False)
+        for eval_set1, eval_set2 in itertools.combinations(gbm.evals_result_.keys(), 2):
+            for metric in gbm.evals_result_[eval_set1]:
+                np.testing.assert_raises(AssertionError,
+                                         np.testing.assert_allclose,
+                                         gbm.evals_result_[eval_set1][metric],
+                                         gbm.evals_result_[eval_set2][metric])
+        gbm_str = lgb.LGBMClassifier(n_estimators=10, class_weight='balanced', silent=True)
+        gbm_str.fit(X_train, y_train_str,
+                    eval_set=[(X_train, y_train_str), (X_test, y_test_str),
+                              (X_test, y_test_str), (X_test, y_test_str), (X_test, y_test_str)],
+                    eval_class_weight=['balanced', None, 'balanced', {'1': 10, '4': 20}, {'5': 30, '2': 40}],
+                    verbose=False)
+        for eval_set1, eval_set2 in itertools.combinations(gbm_str.evals_result_.keys(), 2):
+            for metric in gbm_str.evals_result_[eval_set1]:
+                np.testing.assert_raises(AssertionError,
+                                         np.testing.assert_allclose,
+                                         gbm_str.evals_result_[eval_set1][metric],
+                                         gbm_str.evals_result_[eval_set2][metric])
+        for eval_set in gbm.evals_result_:
+            for metric in gbm.evals_result_[eval_set]:
+                np.testing.assert_allclose(gbm.evals_result_[eval_set][metric],
+                                           gbm_str.evals_result_[eval_set][metric])