diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 3b4db41c5..b985292dc 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -11,8 +11,9 @@ from tempfile import NamedTemporaryFile import numpy as np import scipy.sparse -from .compat import (DataFrame, Series, integer_types, json, numeric_types, - range_, string_type) +from .compat import (DataFrame, Series, integer_types, json, + json_default_with_numpy, numeric_types, range_, + string_type) from .libpath import find_lib_path @@ -271,6 +272,19 @@ def _label_from_pandas(label): return label +def _save_pandas_categorical(file_name, pandas_categorical): + with open(file_name, 'a') as f: + f.write('\npandas_categorical:' + json.dumps(pandas_categorical, default=json_default_with_numpy)) + + +def _load_pandas_categorical(file_name): + with open(file_name, 'r') as f: + last_line = f.readlines()[-1] + if last_line.startswith('pandas_categorical:'): + return json.loads(last_line[len('pandas_categorical:'):]) + return None + + class _InnerPredictor(object): """ A _InnerPredictor of LightGBM. @@ -302,12 +316,7 @@ class _InnerPredictor(object): ctypes.byref(out_num_class))) self.num_class = out_num_class.value self.num_total_iteration = out_num_iterations.value - with open(model_file, 'r') as f: - last_line = f.readlines()[-1] - if last_line.startswith('pandas_categorical:'): - self.pandas_categorical = eval(last_line[len('pandas_categorical:'):]) - else: - self.pandas_categorical = None + self.pandas_categorical = _load_pandas_categorical(model_file) elif booster_handle is not None: self.__is_manage_handle = False self.handle = booster_handle @@ -1207,12 +1216,7 @@ class Booster(object): self.handle, ctypes.byref(out_num_class))) self.__num_class = out_num_class.value - with open(model_file, 'r') as f: - last_line = f.readlines()[-1] - if last_line.startswith('pandas_categorical:'): - self.pandas_categorical = eval(last_line[len('pandas_categorical:'):]) - else: - self.pandas_categorical = None + self.pandas_categorical = _load_pandas_categorical(model_file) elif 'model_str' in params: self.__load_model_from_string(params['model_str']) else: @@ -1468,8 +1472,7 @@ class Booster(object): self.handle, ctypes.c_int(num_iteration), c_str(filename))) - with open(filename, 'a') as f: - f.write('\npandas_categorical:' + repr(self.pandas_categorical)) + _save_pandas_categorical(filename, self.pandas_categorical) def __load_model_from_string(self, model_str): """[Private] Load model from string""" diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 19fe82dee..26af67e43 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -6,6 +6,8 @@ from __future__ import absolute_import import inspect import sys +import numpy as np + is_py3 = (sys.version_info[0] == 3) """compatibility between python2 and python3""" @@ -36,6 +38,16 @@ except (ImportError, SyntaxError): # because of u'...' Unicode literals. import json + +def json_default_with_numpy(obj): + if isinstance(obj, (np.integer, np.floating, np.bool_)): + return obj.item() + elif isinstance(obj, np.ndarray): + return obj.tolist() + else: + return obj + + """pandas""" try: from pandas import Series, DataFrame diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index ffc160be4..0a275e0d3 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -146,15 +146,18 @@ class TestEngine(unittest.TestCase): @unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas not installed') def test_pandas_categorical(self): - X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), - "B": np.random.permutation([1, 2, 3] * 100)}) - X["A"] = X["A"].astype('category') - X["B"] = X["B"].astype('category') + X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str + "B": np.random.permutation([1, 2, 3] * 100), # int + "C": np.random.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float + "D": np.random.permutation([True, False] * 150)}) # bool y = np.random.permutation([0, 1] * 150) X_test = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'e'] * 20), - "B": np.random.permutation([1, 3] * 30)}) - X_test["A"] = X_test["A"].astype('category') - X_test["B"] = X_test["B"].astype('category') + "B": np.random.permutation([1, 3] * 30), + "C": np.random.permutation([0.1, -0.1, 0.2, 0.2] * 15), + "D": np.random.permutation([True, False] * 30)}) + for col in ["A", "B", "C", "D"]: + X[col] = X[col].astype('category') + X_test[col] = X_test[col].astype('category') params = { 'objective': 'binary', 'metric': 'binary_logloss', @@ -173,7 +176,7 @@ class TestEngine(unittest.TestCase): pred2 = list(gbm2.predict(X_test)) lgb_train = lgb.Dataset(X, y) gbm3 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False, - categorical_feature=['A', 'B']) + categorical_feature=['A', 'B', 'C', 'D']) pred3 = list(gbm3.predict(X_test)) lgb_train = lgb.Dataset(X, y) gbm3.save_model('categorical.model')