зеркало из https://github.com/microsoft/LightGBM.git
use json instead of repr/eval for pandas_categorical (#247)
* use json instead of repr/eval for pandas_categorical * fix json dumps with numpy data * add more test cases
This commit is contained in:
Родитель
9c5dbdde5c
Коммит
a4a0235d17
|
@ -11,8 +11,9 @@ from tempfile import NamedTemporaryFile
|
|||
import numpy as np
|
||||
import scipy.sparse
|
||||
|
||||
from .compat import (DataFrame, Series, integer_types, json, numeric_types,
|
||||
range_, string_type)
|
||||
from .compat import (DataFrame, Series, integer_types, json,
|
||||
json_default_with_numpy, numeric_types, range_,
|
||||
string_type)
|
||||
from .libpath import find_lib_path
|
||||
|
||||
|
||||
|
@ -271,6 +272,19 @@ def _label_from_pandas(label):
|
|||
return label
|
||||
|
||||
|
||||
def _save_pandas_categorical(file_name, pandas_categorical):
    """Append the pandas categorical data to ``file_name`` as one JSON line.

    The data is written on a new line prefixed with ``pandas_categorical:``.
    numpy values are converted through ``json_default_with_numpy``.

    Parameters
    ----------
    file_name : str
        Path of the file to append to (opened in ``'a'`` mode).
    pandas_categorical : object
        Value to serialize with ``json.dumps``.
    """
    with open(file_name, 'a') as model_file:
        payload = json.dumps(pandas_categorical,
                             default=json_default_with_numpy)
        model_file.write('\npandas_categorical:' + payload)
|
||||
|
||||
|
||||
def _load_pandas_categorical(file_name):
    """Read the pandas categorical data stored on the last line of a file.

    Parameters
    ----------
    file_name : str
        Path of the file to read.

    Returns
    -------
    object or None
        The JSON-decoded value that follows the ``pandas_categorical:``
        prefix on the last line, or ``None`` when the file is empty or its
        last line does not start with that prefix.
    """
    prefix = 'pandas_categorical:'
    last_line = None
    with open(file_name, 'r') as f:
        # Stream the file and keep only the final line instead of loading
        # every line into memory; an empty file leaves last_line as None.
        for last_line in f:
            pass
    if last_line is not None and last_line.startswith(prefix):
        return json.loads(last_line[len(prefix):])
    return None
|
||||
|
||||
|
||||
class _InnerPredictor(object):
|
||||
"""
|
||||
A _InnerPredictor of LightGBM.
|
||||
|
@ -302,12 +316,7 @@ class _InnerPredictor(object):
|
|||
ctypes.byref(out_num_class)))
|
||||
self.num_class = out_num_class.value
|
||||
self.num_total_iteration = out_num_iterations.value
|
||||
with open(model_file, 'r') as f:
|
||||
last_line = f.readlines()[-1]
|
||||
if last_line.startswith('pandas_categorical:'):
|
||||
self.pandas_categorical = eval(last_line[len('pandas_categorical:'):])
|
||||
else:
|
||||
self.pandas_categorical = None
|
||||
self.pandas_categorical = _load_pandas_categorical(model_file)
|
||||
elif booster_handle is not None:
|
||||
self.__is_manage_handle = False
|
||||
self.handle = booster_handle
|
||||
|
@ -1207,12 +1216,7 @@ class Booster(object):
|
|||
self.handle,
|
||||
ctypes.byref(out_num_class)))
|
||||
self.__num_class = out_num_class.value
|
||||
with open(model_file, 'r') as f:
|
||||
last_line = f.readlines()[-1]
|
||||
if last_line.startswith('pandas_categorical:'):
|
||||
self.pandas_categorical = eval(last_line[len('pandas_categorical:'):])
|
||||
else:
|
||||
self.pandas_categorical = None
|
||||
self.pandas_categorical = _load_pandas_categorical(model_file)
|
||||
elif 'model_str' in params:
|
||||
self.__load_model_from_string(params['model_str'])
|
||||
else:
|
||||
|
@ -1468,8 +1472,7 @@ class Booster(object):
|
|||
self.handle,
|
||||
ctypes.c_int(num_iteration),
|
||||
c_str(filename)))
|
||||
with open(filename, 'a') as f:
|
||||
f.write('\npandas_categorical:' + repr(self.pandas_categorical))
|
||||
_save_pandas_categorical(filename, self.pandas_categorical)
|
||||
|
||||
def __load_model_from_string(self, model_str):
|
||||
"""[Private] Load model from string"""
|
||||
|
|
|
@ -6,6 +6,8 @@ from __future__ import absolute_import
|
|||
import inspect
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
is_py3 = (sys.version_info[0] == 3)
|
||||
|
||||
"""compatibility between python2 and python3"""
|
||||
|
@ -36,6 +38,16 @@ except (ImportError, SyntaxError):
|
|||
# because of u'...' Unicode literals.
|
||||
import json
|
||||
|
||||
|
||||
def json_default_with_numpy(obj):
    """Convert numpy values to JSON-serializable Python objects.

    Intended to be passed as the ``default`` argument of ``json.dumps``.

    Parameters
    ----------
    obj : object
        Object that the JSON encoder could not serialize on its own.

    Returns
    -------
    object
        ``obj.item()`` for numpy integer, floating and bool scalars;
        ``obj.tolist()`` for numpy arrays.

    Raises
    ------
    TypeError
        If ``obj`` is not one of the supported numpy types.
    """
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # A ``default`` hook must return an encodable object or raise TypeError.
    # Handing the same object back makes the encoder fail with a misleading
    # "Circular reference detected" error instead.
    raise TypeError('Object of type %s is not JSON serializable'
                    % type(obj).__name__)
|
||||
|
||||
|
||||
"""pandas"""
|
||||
try:
|
||||
from pandas import Series, DataFrame
|
||||
|
|
|
@ -146,15 +146,18 @@ class TestEngine(unittest.TestCase):
|
|||
|
||||
@unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas not installed')
|
||||
def test_pandas_categorical(self):
|
||||
X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75),
|
||||
"B": np.random.permutation([1, 2, 3] * 100)})
|
||||
X["A"] = X["A"].astype('category')
|
||||
X["B"] = X["B"].astype('category')
|
||||
X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str
|
||||
"B": np.random.permutation([1, 2, 3] * 100), # int
|
||||
"C": np.random.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float
|
||||
"D": np.random.permutation([True, False] * 150)}) # bool
|
||||
y = np.random.permutation([0, 1] * 150)
|
||||
X_test = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'e'] * 20),
|
||||
"B": np.random.permutation([1, 3] * 30)})
|
||||
X_test["A"] = X_test["A"].astype('category')
|
||||
X_test["B"] = X_test["B"].astype('category')
|
||||
"B": np.random.permutation([1, 3] * 30),
|
||||
"C": np.random.permutation([0.1, -0.1, 0.2, 0.2] * 15),
|
||||
"D": np.random.permutation([True, False] * 30)})
|
||||
for col in ["A", "B", "C", "D"]:
|
||||
X[col] = X[col].astype('category')
|
||||
X_test[col] = X_test[col].astype('category')
|
||||
params = {
|
||||
'objective': 'binary',
|
||||
'metric': 'binary_logloss',
|
||||
|
@ -173,7 +176,7 @@ class TestEngine(unittest.TestCase):
|
|||
pred2 = list(gbm2.predict(X_test))
|
||||
lgb_train = lgb.Dataset(X, y)
|
||||
gbm3 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False,
|
||||
categorical_feature=['A', 'B'])
|
||||
categorical_feature=['A', 'B', 'C', 'D'])
|
||||
pred3 = list(gbm3.predict(X_test))
|
||||
lgb_train = lgb.Dataset(X, y)
|
||||
gbm3.save_model('categorical.model')
|
||||
|
|
Загрузка…
Ссылка в новой задаче