* fixed class_weight

* fixed lint

* added test

* hotfix
This commit is contained in:
Nikita Titov 2019-06-04 08:44:26 +03:00 коммит произвёл Qiwei Ye
Родитель 7d03ced388
Коммит b6f6578368
2 изменённых файлов: 55 добавлений и 12 удалений

Просмотреть файл

@@ -10,7 +10,7 @@ from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase,
LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase,
_LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength,
_LGBMAssertAllFinite, _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight,
argc_, range_, string_type, DataFrame, DataTable)
argc_, range_, zip_, string_type, DataFrame, DataTable)
from .engine import train
@@ -320,6 +320,8 @@ class LGBMModel(_LGBMModelBase):
self._other_params = {}
self._objective = objective
self.class_weight = class_weight
self._class_weight = None
self._class_map = None
self._n_features = None
self._classes = None
self._n_classes = None
@@ -529,8 +531,10 @@ class LGBMModel(_LGBMModelBase):
else:
_X, _y = X, y
if self.class_weight is not None:
class_sample_weight = _LGBMComputeSampleWeight(self.class_weight, y)
if self._class_weight is None:
self._class_weight = self.class_weight
if self._class_weight is not None:
class_sample_weight = _LGBMComputeSampleWeight(self._class_weight, y)
if sample_weight is None or len(sample_weight) == 0:
sample_weight = class_sample_weight
else:
@@ -547,7 +551,7 @@ class LGBMModel(_LGBMModelBase):
valid_sets = []
if eval_set is not None:
def _get_meta_data(collection, i):
def _get_meta_data(collection, name, i):
if collection is None:
return None
elif isinstance(collection, list):
@@ -555,8 +559,7 @@
elif isinstance(collection, dict):
return collection.get(i, None)
else:
raise TypeError('eval_sample_weight, eval_class_weight, eval_init_score, and eval_group '
'should be dict or list')
raise TypeError('{} should be dict or list'.format(name))
if isinstance(eval_set, tuple):
eval_set = [eval_set]
@@ -565,16 +568,18 @@
if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set
else:
valid_weight = _get_meta_data(eval_sample_weight, i)
if _get_meta_data(eval_class_weight, i) is not None:
valid_class_sample_weight = _LGBMComputeSampleWeight(_get_meta_data(eval_class_weight, i),
valid_data[1])
valid_weight = _get_meta_data(eval_sample_weight, 'eval_sample_weight', i)
valid_class_weight = _get_meta_data(eval_class_weight, 'eval_class_weight', i)
if valid_class_weight is not None:
if isinstance(valid_class_weight, dict) and self._class_map is not None:
valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight
else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = _get_meta_data(eval_init_score, i)
valid_group = _get_meta_data(eval_group, i)
valid_init_score = _get_meta_data(eval_init_score, 'eval_init_score', i)
valid_group = _get_meta_data(eval_group, 'eval_group', i)
valid_set = _construct_dataset(valid_data[0], valid_data[1],
valid_weight, valid_init_score, valid_group, params)
valid_sets.append(valid_set)
@@ -750,6 +755,9 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
_LGBMCheckClassificationTargets(y)
self._le = _LGBMLabelEncoder().fit(y)
_y = self._le.transform(y)
self._class_map = dict(zip_(self._le.classes_, self._le.transform(self._le.classes_)))
if isinstance(self.class_weight, dict):
self._class_weight = {self._class_map[k]: v for k, v in self.class_weight.items()}
self._classes = self._le.classes_
self._n_classes = len(self._classes)

Просмотреть файл

@@ -1,5 +1,6 @@
# coding: utf-8
# pylint: skip-file
import itertools
import math
import os
import unittest
@@ -615,3 +616,37 @@ class TestSklearn(unittest.TestCase):
'verbose': False, 'early_stopping_rounds': 5}
gbm = lgb.LGBMRegressor(**params).fit(**params_fit)
np.testing.assert_array_equal(gbm.evals_result_['training']['l2'], np.nan)
def test_class_weight(self):
    """Check that per-eval-set class weights take effect and that string
    labels produce the same results as the equivalent integer labels."""
    X, y = load_digits(10, True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    y_train_str = y_train.astype('str')
    y_test_str = y_test.astype('str')

    def _check_all_curves_differ(evals_result):
        # Each eval set was fitted with a different class weighting,
        # so no two metric curves may coincide.
        for name_a, name_b in itertools.combinations(evals_result.keys(), 2):
            for metric in evals_result[name_a]:
                np.testing.assert_raises(AssertionError,
                                         np.testing.assert_allclose,
                                         evals_result[name_a][metric],
                                         evals_result[name_b][metric])

    # Integer class labels with a mix of 'balanced', None and dict weights.
    gbm = lgb.LGBMClassifier(n_estimators=10, class_weight='balanced', silent=True)
    gbm.fit(X_train, y_train,
            eval_set=[(X_train, y_train), (X_test, y_test), (X_test, y_test),
                      (X_test, y_test), (X_test, y_test)],
            eval_class_weight=['balanced', None, 'balanced', {1: 10, 4: 20}, {5: 30, 2: 40}],
            verbose=False)
    _check_all_curves_differ(gbm.evals_result_)

    # Same setup but with string labels and string-keyed weight dicts.
    gbm_str = lgb.LGBMClassifier(n_estimators=10, class_weight='balanced', silent=True)
    gbm_str.fit(X_train, y_train_str,
                eval_set=[(X_train, y_train_str), (X_test, y_test_str),
                          (X_test, y_test_str), (X_test, y_test_str), (X_test, y_test_str)],
                eval_class_weight=['balanced', None, 'balanced', {'1': 10, '4': 20}, {'5': 30, '2': 40}],
                verbose=False)
    _check_all_curves_differ(gbm_str.evals_result_)

    # The label encoding must not matter: int- and str-label runs agree.
    for eval_name in gbm.evals_result_:
        for metric in gbm.evals_result_[eval_name]:
            np.testing.assert_allclose(gbm.evals_result_[eval_name][metric],
                                       gbm_str.evals_result_[eval_name][metric])