зеркало из https://github.com/microsoft/LightGBM.git
move pandas support into basic.py
This commit is contained in:
Родитель
b12f99682e
Коммит
c67d289086
|
@ -358,6 +358,39 @@ class Predictor(object):
|
|||
raise ValueError("incorrect number for predict result")
|
||||
return preds, nrow
|
||||
|
||||
# pandas
|
||||
try:
|
||||
from pandas import DataFrame
|
||||
except ImportError:
|
||||
class DataFrame(object):
|
||||
pass
|
||||
|
||||
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
|
||||
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
|
||||
'float16': 'float', 'float32': 'float', 'float64': 'float',
|
||||
'bool': 'i'}
|
||||
|
||||
def _data_from_pandas(data):
|
||||
if isinstance(data, DataFrame):
|
||||
data_dtypes = data.dtypes
|
||||
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes):
|
||||
bad_fields = [data.columns[i] for i, dtype in
|
||||
enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]
|
||||
|
||||
msg = """DataFrame.dtypes for data must be int, float or bool. Did not expect the data types in fields """
|
||||
raise ValueError(msg + ', '.join(bad_fields))
|
||||
data = data.values.astype('float')
|
||||
return data
|
||||
|
||||
def _label_from_pandas(label):
|
||||
if isinstance(label, DataFrame):
|
||||
if len(label.columns) > 1:
|
||||
raise ValueError('DataFrame for label cannot have multiple columns')
|
||||
label_dtypes = label.dtypes
|
||||
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in label_dtypes):
|
||||
raise ValueError('DataFrame.dtypes for label must be int, float or bool')
|
||||
label = label.values.astype('float')
|
||||
return label
|
||||
|
||||
class Dataset(object):
|
||||
"""Dataset used in LightGBM.
|
||||
|
@ -398,6 +431,8 @@ class Dataset(object):
|
|||
if data is None:
|
||||
self.handle = None
|
||||
return
|
||||
data = _data_from_pandas(data)
|
||||
label = _label_from_pandas(label)
|
||||
self.data_has_header = False
|
||||
"""process for args"""
|
||||
params = {} if params is None else params
|
||||
|
|
|
@ -6,40 +6,6 @@ import numpy as np
|
|||
from .basic import LightGBMError, Predictor, Dataset, Booster, is_str
|
||||
from . import callback
|
||||
|
||||
# pandas
|
||||
try:
|
||||
from pandas import DataFrame
|
||||
except ImportError:
|
||||
class DataFrame(object):
|
||||
pass
|
||||
|
||||
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
|
||||
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
|
||||
'float16': 'float', 'float32': 'float', 'float64': 'float',
|
||||
'bool': 'i'}
|
||||
|
||||
def _data_from_pandas(data):
|
||||
if isinstance(data, DataFrame):
|
||||
data_dtypes = data.dtypes
|
||||
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes):
|
||||
bad_fields = [data.columns[i] for i, dtype in
|
||||
enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]
|
||||
|
||||
msg = """DataFrame.dtypes for data must be int, float or bool. Did not expect the data types in fields """
|
||||
raise ValueError(msg + ', '.join(bad_fields))
|
||||
data = data.values.astype('float')
|
||||
return data
|
||||
|
||||
def _label_from_pandas(label):
|
||||
if isinstance(label, DataFrame):
|
||||
if len(label.columns) > 1:
|
||||
raise ValueError('DataFrame for label cannot have multiple columns')
|
||||
label_dtypes = label.dtypes
|
||||
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in label_dtypes):
|
||||
raise ValueError('DataFrame.dtypes for label must be int, float or bool')
|
||||
label = label.values.astype('float')
|
||||
return label
|
||||
|
||||
def _construct_dataset(X_y, reference=None,
|
||||
params=None, other_fields=None, predictor=None):
|
||||
if 'max_bin' in params:
|
||||
|
@ -61,8 +27,8 @@ def _construct_dataset(X_y, reference=None,
|
|||
else:
|
||||
if len(X_y) != 2:
|
||||
raise TypeError("should pass (data, label) pair")
|
||||
data = _data_from_pandas(X_y[0])
|
||||
label = _label_from_pandas(X_y[1])
|
||||
data = X_y[0]
|
||||
label = X_y[1]
|
||||
if reference is None:
|
||||
ret = Dataset(data, label=label, max_bin=max_bin,
|
||||
weight=weight, group=group, predictor=predictor, params=params)
|
||||
|
|
|
@ -54,7 +54,7 @@ def test_regression():
|
|||
x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
|
||||
lgb_model = lgb.LGBMRegressor().fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='l2')
|
||||
preds = lgb_model.predict(x_test)
|
||||
assert mean_squared_error(preds, y_test) < 30
|
||||
assert mean_squared_error(preds, y_test) < 40
|
||||
|
||||
def test_regression_with_custom_objective():
|
||||
from sklearn.metrics import mean_squared_error
|
||||
|
@ -71,7 +71,7 @@ def test_regression_with_custom_objective():
|
|||
x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
|
||||
lgb_model = lgb.LGBMRegressor(objective=objective_ls).fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='l2')
|
||||
preds = lgb_model.predict(x_test)
|
||||
assert mean_squared_error(preds, y_test) < 30
|
||||
assert mean_squared_error(preds, y_test) < 40
|
||||
|
||||
|
||||
def test_binary_classification_with_custom_objective():
|
||||
|
|
Загрузка…
Ссылка в новой задаче