Provide a high-level Dataset class for easy use.
Guolin Ke 2016-12-08 21:36:11 +08:00 committed by GitHub
Parent f3d33582ec
Commit b51c7be43e
12 changed files with 684 additions and 257 deletions

View file

@ -14,7 +14,7 @@ before_install:
install:
- sudo apt-get install -y libopenmpi-dev openmpi-bin build-essential
- conda install --yes atlas numpy scipy scikit-learn
- conda install --yes atlas numpy scipy scikit-learn pandas
script:
@ -22,12 +22,12 @@ script:
- mkdir build && cd build && cmake .. && make -j
- cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py
- cd $TRAVIS_BUILD_DIR/python-package && python setup.py install
- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_sklearn.py
- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_engine.py && python test_sklearn.py
- cd $TRAVIS_BUILD_DIR
- rm -rf build && mkdir build && cd build && cmake -DUSE_MPI=ON ..&& make -j
- cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py
- cd $TRAVIS_BUILD_DIR/python-package && python setup.py install
- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_sklearn.py
- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_engine.py && python test_sklearn.py
notifications:
email: false

View file

@ -76,7 +76,7 @@ add_executable(lightgbm src/main.cpp ${SOURCES})
add_library(_lightgbm SHARED src/c_api.cpp ${SOURCES})
if(MSVC)
set_target_properties(_lightgbm PROPERTIES OUTPUT_NAME "lightgbm")
set_target_properties(_lightgbm PROPERTIES OUTPUT_NAME "lib_lightgbm")
endif(MSVC)
if(USE_MPI)

View file

@ -17,13 +17,7 @@ X_test = df_test.drop(0, axis=1)
# create dataset for lightgbm
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
# ATTENTION: you should carefully use lightgbm.Dataset
# it requires setting up categorical_feature when you init it
# rather than passing from lightgbm.train
# instead, you can simply use a tuple of length=2 like below
# it will help you construct Datasets with parameters in lightgbm.train
lgb_train = (X_train, y_train)
lgb_eval = (X_test, y_test)
# specify your configurations as a dict
params = {
@ -43,9 +37,7 @@ params = {
gbm = lgb.train(params,
lgb_train,
num_boost_round=100,
valid_datas=lgb_eval,
# you can use a list to represent multiple valid_datas/valid_names
# don't use tuple, tuple is used to represent one dataset
valid_sets=lgb_eval,
early_stopping_rounds=10)
# save model to file
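For illustration (not part of this commit; data and parameter values are made up), a minimal sketch of the new valid_sets / valid_names interface used above, assuming the Python package as of this commit:

import numpy as np
import lightgbm as lgb

# purely illustrative data
X_train = np.random.rand(500, 10)
y_train = np.random.rand(500)
X_valid = np.random.rand(100, 10)
y_valid = np.random.rand(100)

lgb_train = lgb.Dataset(X_train, y_train)
# validation Datasets share the bin mappers of the training set via reference=
lgb_valid = lgb.Dataset(X_valid, y_valid, reference=lgb_train)

params = {'objective': 'regression', 'metric': 'l2', 'verbose': 0}
gbm = lgb.train(params,
                lgb_train,
                num_boost_round=20,
                # a single Dataset or a list of Datasets is accepted here
                valid_sets=[lgb_valid, lgb_train],
                valid_names=['valid', 'train'])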

View file

@ -230,6 +230,7 @@ struct OverallConfig: public ConfigBase {
public:
TaskType task_type = TaskType::kTrain;
NetworkConfig network_config;
int seed = 0;
int num_threads = 0;
bool is_parallel = false;
bool is_parallel_find_bin = false;
@ -317,6 +318,7 @@ struct ParameterAlias {
{
{ "config", "config_file" },
{ "nthread", "num_threads" },
{ "random_seed", "seed" },
{ "num_thread", "num_threads" },
{ "boosting", "boosting_type" },
{ "boost", "boosting_type" },

View file

@ -8,7 +8,7 @@ from __future__ import absolute_import
import os
from .basic import Predictor, Dataset, Booster
from .basic import Dataset, Booster
from .engine import train, cv
try:
from .sklearn import LGBMModel, LGBMRegressor, LGBMClassifier, LGBMRanker

View file

@ -14,7 +14,7 @@ import scipy.sparse
from .libpath import find_lib_path
# pandas
"""pandas"""
try:
from pandas import Series, DataFrame
IS_PANDAS_INSTALLED = True
@ -53,22 +53,27 @@ def _safe_call(ret):
raise LightGBMError(_LIB.LGBM_GetLastError())
def is_str(s):
"""Check is a str or not"""
if IS_PY3:
return isinstance(s, str)
else:
return isinstance(s, basestring)
def is_numpy_object(data):
"""Check is numpy object"""
return type(data).__module__ == np.__name__
def is_numpy_1d_array(data):
"""Check is 1d numpy array"""
return isinstance(data, np.ndarray) and len(data.shape) == 1
def is_1d_list(data):
"""Check is 1d list"""
return isinstance(data, list) and \
(not data or isinstance(data[0], (int, float, bool)))
def list_to_1d_numpy(data, dtype):
"""convert to 1d numpy array"""
if is_numpy_1d_array(data):
if data.dtype == dtype:
return data
@ -112,7 +117,7 @@ def param_dict_to_str(data):
return ""
pairs = []
for key, val in data.items():
if is_str(val) or isinstance(val, (int, float, bool)):
if is_str(val) or isinstance(val, (int, float, bool, np.integer, np.float, np.float32)):
pairs.append(str(key)+'='+str(val))
elif isinstance(val, (list, tuple, set)):
pairs.append(str(key)+'='+','.join(map(str, val)))
@ -126,20 +131,23 @@ C_API_DTYPE_FLOAT32 = 0
C_API_DTYPE_FLOAT64 = 1
C_API_DTYPE_INT32 = 2
C_API_DTYPE_INT64 = 3
"""Matric is row major in python"""
C_API_IS_ROW_MAJOR = 1
"""marco definition of prediction type in c_api of LightGBM"""
C_API_PREDICT_NORMAL = 0
C_API_PREDICT_RAW_SCORE = 1
C_API_PREDICT_LEAF_INDEX = 2
"""data type of data field"""
FIELD_TYPE_MAPPER = {"label": C_API_DTYPE_FLOAT32,
"weight": C_API_DTYPE_FLOAT32,
"init_score": C_API_DTYPE_FLOAT32,
"group": C_API_DTYPE_INT32}
def c_float_array(data):
"""Convert numpy array / list to c float array."""
"""get pointer of float numpy array / list"""
if is_1d_list(data):
data = np.array(data, copy=False)
if is_numpy_1d_array(data):
@ -157,7 +165,7 @@ def c_float_array(data):
return (ptr_data, type_data)
def c_int_array(data):
"""Convert numpy array to c int array."""
"""get pointer of int numpy array / list"""
if is_1d_list(data):
data = np.array(data, copy=False)
if is_numpy_1d_array(data):
@ -174,16 +182,21 @@ def c_int_array(data):
raise TypeError("Unknow type({})".format(type(data).__name__))
return (ptr_data, type_data)
class Predictor(object):
""""A Predictor of LightGBM.
class _InnerPredictor(object):
"""
def __init__(self, model_file=None, booster_handle=None, is_manage_handle=True):
"""Initialize the Predictor.
A _InnerPredictor of LightGBM.
Only used for prediction; usually used for continued training
Note: Can convert from Booster, but cannot convert to Booster
"""
def __init__(self, model_file=None, booster_handle=None):
"""Initialize the _InnerPredictor. Not expose to user
Parameters
----------
model_file : string
Path to the model file.
booster_handle : Handle of Booster
Use the handle to initialize the predictor
"""
self.handle = ctypes.c_void_p()
self.__is_manage_handle = True
@ -201,7 +214,7 @@ class Predictor(object):
self.num_class = out_num_class.value
self.num_total_iteration = out_num_iterations.value
elif booster_handle is not None:
self.__is_manage_handle = is_manage_handle
self.__is_manage_handle = False
self.handle = booster_handle
out_num_class = ctypes.c_int64(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses(
@ -214,7 +227,7 @@ class Predictor(object):
ctypes.byref(out_num_iterations)))
self.num_total_iteration = out_num_iterations.value
else:
raise TypeError('Need Model file to create a booster')
raise TypeError('Need Model file or Booster handle to create a predictor')
def __del__(self):
if self.__is_manage_handle:
@ -239,7 +252,7 @@ class Predictor(object):
pred_leaf : bool
True for predict leaf index
data_has_header : bool
Used for txt data
Used for txt data, True if txt data has header
is_reshape : bool
Reshape to (nrow, ncol) if true
@ -247,7 +260,7 @@ class Predictor(object):
-------
Prediction result
"""
if isinstance(data, Dataset):
if isinstance(data, (_InnerDataset, Dataset)):
raise TypeError("cannot use Dataset instance for prediction, please use raw data instead")
predict_type = C_API_PREDICT_NORMAL
if raw_score:
@ -299,6 +312,9 @@ class Predictor(object):
return preds
def __get_num_preds(self, num_iteration, nrow, predict_type):
"""
Get size of prediction result
"""
n_preds = self.num_class * nrow
if predict_type == C_API_PREDICT_LEAF_INDEX:
if num_iteration > 0:
@ -398,10 +414,10 @@ def _label_from_pandas(label):
label = label.values.astype('float')
return label
class Dataset(object):
"""Dataset used in LightGBM.
Dataset is a internal data structure that used by LightGBM
class _InnerDataset(object):
"""_InnerDataset used in LightGBM.
_InnerDataset is an internal data structure used by LightGBM.
This class is not exposed. Please use Dataset instead
"""
def __init__(self, data, label=None, max_bin=255, reference=None,
@ -409,23 +425,25 @@ class Dataset(object):
silent=False, feature_name=None,
categorical_feature=None, params=None):
"""
Dataset used in LightGBM.
_InnerDataset used in LightGBM.
Parameters
----------
data : string/numpy array/scipy.sparse
Data source of Dataset.
Data source of _InnerDataset.
When data type is string, it represents the path of txt file
label : list or numpy 1-D array, optional
Label of the data
max_bin : int, required
Max number of discrete bin for features
reference : Other Dataset, optional
reference : Other _InnerDataset, optional
If this is a dataset for validation, training data should be used as reference
weight : list or numpy 1-D array , optional
Weight for each instance.
group : list or numpy 1-D array , optional
Group/query size for dataset
predictor : _InnerPredictor
Used for continued training
silent : boolean, optional
Whether print messages during construction
feature_name : list of str
@ -436,10 +454,6 @@ class Dataset(object):
params: dict, optional
Other parameters
"""
self.__label = None
self.__weight = None
self.__init_score = None
self.__group = None
if data is None:
self.handle = None
return
@ -475,7 +489,7 @@ class Dataset(object):
params_str = param_dict_to_str(params)
"""process for reference dataset"""
ref_dataset = None
if isinstance(reference, Dataset):
if isinstance(reference, _InnerDataset):
ref_dataset = ctypes.byref(reference.handle)
elif reference is not None:
raise TypeError('Reference dataset should be None or dataset instance')
@ -500,7 +514,7 @@ class Dataset(object):
csr = scipy.sparse.csr_matrix(data)
self.__init_from_csr(csr, params_str, ref_dataset)
except:
raise TypeError('can not initialize Dataset from {}'.format(type(data).__name__))
raise TypeError('can not initialize _InnerDataset from {}'.format(type(data).__name__))
if label is not None:
self.set_label(label)
if self.get_label() is None:
@ -510,7 +524,7 @@ class Dataset(object):
if group is not None:
self.set_group(group)
# load init score
if self.predictor is not None and isinstance(self.predictor, Predictor):
if isinstance(self.predictor, _InnerPredictor):
init_score = self.predictor.predict(data,
raw_score=True,
data_has_header=self.data_has_header,
@ -524,6 +538,8 @@ class Dataset(object):
new_init_score[j * num_data + i] = init_score[i * self.predictor.num_class + j]
init_score = new_init_score
self.set_init_score(init_score)
elif self.predictor is not None:
raise TypeError('wrong predictor type {}'.format(type(self.predictor).__name__))
# set feature names
self.set_feature_name(feature_name)
@ -535,7 +551,7 @@ class Dataset(object):
Parameters
----------
data : string/numpy array/scipy.sparse
Data source of Dataset.
Data source of _InnerDataset.
When data type is string, it represents the path of txt file
label : list or numpy 1-D array, optional
Label of the training data.
@ -548,16 +564,16 @@ class Dataset(object):
params: dict, optional
Other parameters
"""
return Dataset(data, label=label, max_bin=self.max_bin, reference=self,
weight=weight, group=group, predictor=self.predictor,
silent=silent, params=params)
return _InnerDataset(data, label=label, max_bin=self.max_bin, reference=self,
weight=weight, group=group, predictor=self.predictor,
silent=silent, params=params)
def subset(self, used_indices, params=None):
"""
Get subset of current dataset
"""
used_indices = list_to_1d_numpy(used_indices, np.int32)
ret = Dataset(None)
ret = _InnerDataset(None)
ret.handle = ctypes.c_void_p()
params_str = param_dict_to_str(params)
_safe_call(_LIB.LGBM_DatasetGetSubset(
@ -573,6 +589,9 @@ class Dataset(object):
return ret
def set_feature_name(self, feature_name):
"""
set feature names
"""
if feature_name is None:
return
if len(feature_name) != self.num_feature():
@ -636,7 +655,7 @@ class Dataset(object):
_safe_call(_LIB.LGBM_DatasetFree(self.handle))
def get_field(self, field_name):
"""Get property from the Dataset.
"""Get property from the _InnerDataset.
Parameters
----------
@ -669,7 +688,7 @@ class Dataset(object):
raise TypeError("unknow type")
def set_field(self, field_name, data):
"""Set property into the Dataset.
"""Set property into the _InnerDataset.
Parameters
----------
@ -711,7 +730,7 @@ class Dataset(object):
type_data))
def save_binary(self, filename):
"""Save Dataset to binary file
"""Save _InnerDataset to binary file
Parameters
----------
@ -723,15 +742,14 @@ class Dataset(object):
c_str(filename)))
def set_label(self, label):
"""Set label of Dataset
"""Set label of _InnerDataset
Parameters
----------
label: numpy array or list or None
The label information to be set into Dataset
The label information to be set into _InnerDataset
"""
label = list_to_1d_numpy(label, np.float32)
self.__label = label
self.set_field('label', label)
def set_weight(self, weight):
@ -744,7 +762,6 @@ class Dataset(object):
"""
if weight is not None:
weight = list_to_1d_numpy(weight, np.float32)
self.__weight = weight
self.set_field('weight', weight)
def set_init_score(self, score):
@ -757,11 +774,10 @@ class Dataset(object):
"""
if score is not None:
score = list_to_1d_numpy(score, np.float32)
self.__init_score = score
self.set_field('init_score', score)
def set_group(self, group):
"""Set group size of Dataset (used for ranking).
"""Set group size of _InnerDataset (used for ranking).
Parameters
----------
@ -770,57 +786,46 @@ class Dataset(object):
"""
if group is not None:
group = list_to_1d_numpy(group, np.int32)
self.__group = group
self.set_field('group', group)
def get_label(self):
"""Get the label of the Dataset.
"""Get the label of the _InnerDataset.
Returns
-------
label : array
"""
if self.__label is None:
self.__label = self.get_field('label')
if self.__label is None:
raise TypeError("label should not be None")
return self.__label
return self.get_field('label')
def get_weight(self):
"""Get the weight of the Dataset.
"""Get the weight of the _InnerDataset.
Returns
-------
weight : array
"""
if self.__weight is None:
self.__weight = self.get_field('weight')
return self.__weight
return self.get_field('weight')
def get_init_score(self):
"""Get the initial score of the Dataset.
"""Get the initial score of the _InnerDataset.
Returns
-------
init_score : array
"""
if self.__init_score is None:
self.__init_score = self.get_field('init_score')
return self.__init_score
return self.get_field('init_score')
def get_group(self):
"""Get the initial score of the Dataset.
"""Get the initial score of the _InnerDataset.
Returns
-------
init_score : array
"""
if self.__group is None:
self.__group = self.get_field('group')
return self.__group
return self.get_field('group')
def num_data(self):
"""Get the number of rows in the Dataset.
"""Get the number of rows in the _InnerDataset.
Returns
-------
@ -832,7 +837,7 @@ class Dataset(object):
return ret.value
def num_feature(self):
"""Get the number of columns (features) in the Dataset.
"""Get the number of columns (features) in the _InnerDataset.
Returns
-------
@ -843,6 +848,326 @@ class Dataset(object):
ctypes.byref(ret)))
return ret.value
class Dataset(object):
"""High level Dataset used in LightGBM.
"""
def __init__(self, data, label=None, max_bin=255, reference=None,
weight=None, group=None, silent=False,
feature_name=None, categorical_feature=None, params=None,
free_raw_data=True):
"""
Parameters
----------
data : string/numpy array/scipy.sparse
Data source of Dataset.
When data type is string, it represents the path of txt file
label : list or numpy 1-D array, optional
Label of the data
max_bin : int, required
Max number of discrete bin for features
reference : Other Dataset, optional
If this is a dataset for validation, training data should be used as reference
weight : list or numpy 1-D array , optional
Weight for each instance.
group : list or numpy 1-D array , optional
Group/query size for dataset
silent : boolean, optional
Whether print messages during construction
feature_name : list of str
Feature names
categorical_feature : list of str or int
Categorical features, type int represents index, \
type str represents feature names (need to specify feature_name as well)
params: dict, optional
Other parameters
free_raw_data : bool
True if the raw data should be freed after constructing the inner dataset
"""
self.data = data
self.label = label
self.max_bin = max_bin
self.reference = reference
self.weight = weight
self.group = group
self.silent = silent
self.feature_name = feature_name
self.categorical_feature = categorical_feature
self.params = params
self.free_raw_data = free_raw_data
self.inner_dataset = None
self.used_indices = None
self._predictor = None
def create_valid(self, data, label=None, weight=None, group=None,
silent=False, params=None):
"""
Create validation data aligned with the current dataset
Parameters
----------
data : string/numpy array/scipy.sparse
Data source of Dataset.
When data type is string, it represents the path of txt file
label : list or numpy 1-D array, optional
Label of the training data.
weight : list or numpy 1-D array , optional
Weight for each instance.
group : list or numpy 1-D array , optional
Group/query size for dataset
silent : boolean, optional
Whether print messages during construction
params: dict, optional
Other parameters
"""
ret = Dataset(data, label=label, max_bin=self.max_bin, reference=self,
weight=weight, group=group,
silent=silent, params=params, free_raw_data=self.free_raw_data)
ret._set_predictor(self._predictor)
return ret
def construct(self):
"""Lazy init"""
if self.inner_dataset is None:
if self.reference is not None:
if self.used_indices is None:
self.inner_dataset = self.reference._get_inner_dataset().create_valid(
self.data, self.label,
self.weight, self.group,
self.silent, self.params)
else:
"""construct subset"""
self.inner_dataset = self.reference._get_inner_dataset().subset(
self.used_indices, self.params)
else:
self.inner_dataset = _InnerDataset(self.data, self.label, self.max_bin,
None, self.weight, self.group, self._predictor,
self.silent, self.feature_name, self.categorical_feature, self.params)
if self.free_raw_data:
self.data = None
def _get_inner_dataset(self):
"""get inner dataset"""
self.construct()
return self.inner_dataset
def __is_constructed(self):
"""check inner_dataset is constructed or not"""
return self.inner_dataset is not None
def set_categorical_feature(self, categorical_feature):
"""
Set categorical features
Parameters
----------
categorical_feature : list of int or str
Name/index of categorical features
"""
if self.categorical_feature == categorical_feature:
return
if self.data is not None:
self.categorical_feature = categorical_feature
self.inner_dataset = None
else:
raise LightGBMError("Cannot set categorical feature after freed raw data,\
Set free_raw_data=False when construct Dataset to avoid this.")
def _set_predictor(self, predictor):
"""
Set predictor for continued training. It is not recommended for users to call this function.
Please set init_model in engine.train or engine.cv
"""
if predictor is self._predictor:
return
if self.data is not None:
self._predictor = predictor
self.inner_dataset = None
else:
raise LightGBMError("Cannot set predictor after freed raw data,\
Set free_raw_data=False when construct Dataset to avoid this.")
def set_reference(self, reference):
"""
Set reference dataset
Parameters
----------
reference : Dataset
reference will be used as a template to construct the current dataset
"""
self.set_categorical_feature(reference.categorical_feature)
self.set_feature_name(reference.feature_name)
self._set_predictor(reference._predictor)
if self.reference is reference:
return
if self.data is not None:
self.reference = reference
self.inner_dataset = None
else:
raise LightGBMError("Cannot set reference after freed raw data,\
Set free_raw_data=False when construct Dataset to avoid this.")
def set_feature_name(self, feature_name):
"""
Set feature name
Parameters
----------
feature_name : list of str
feature names
"""
self.feature_name = feature_name
if self.__is_constructed():
self.inner_dataset.set_feature_name(self.feature_name)
def subset(self, used_indices, params=None):
"""
Get subset of current dataset
Parameters
----------
used_indices : list of int
indices used to create the subset
params : dict
other parameters
"""
ret = Dataset(None)
ret.feature_name = self.feature_name
ret.categorical_feature = self.categorical_feature
ret.reference = self
ret._predictor = self._predictor
ret.used_indices = used_indices
ret.params = params
return ret
def save_binary(self, filename):
"""Save Dataset to binary file
Parameters
----------
filename : string
Name of the output file.
"""
self._get_inner_dataset().save_binary(filename)
def set_label(self, label):
"""Set label of Dataset
Parameters
----------
label: numpy array or list or None
The label information to be set into Dataset
"""
self.label = label
if self.__is_constructed():
self.inner_dataset.set_label(self.label)
def set_weight(self, weight):
""" Set weight of each instance.
Parameters
----------
weight : numpy array or list or None
Weight for each data point
"""
self.weight = weight
if self.__is_constructed():
self.inner_dataset.set_weight(self.weight)
def set_init_score(self, init_score):
""" Set init score of booster to start from.
Parameters
----------
init_score: numpy array or list or None
Init score for booster
"""
self.init_score = init_score
if self.__is_constructed():
self.inner_dataset.set_init_score(self.init_score)
def set_group(self, group):
"""Set group size of Dataset (used for ranking).
Parameters
----------
group : numpy array or list or None
Group size of each group
"""
self.group = group
if self.__is_constructed():
self.inner_dataset.set_group(self.group)
def get_label(self):
"""Get the label of the Dataset.
Returns
-------
label : array
"""
if self.label is None and self.__is_constructed():
self.label = self.inner_dataset.get_label()
return self.label
def get_weight(self):
"""Get the weight of the Dataset.
Returns
-------
weight : array
"""
if self.weight is None and self.__is_constructed():
self.weight = self.inner_dataset.get_weight()
return self.weight
def get_init_score(self):
"""Get the initial score of the Dataset.
Returns
-------
init_score : array
"""
if self.init_score is None and self.__is_constructed():
self.init_score = self.inner_dataset.get_init_score()
return self.init_score
def get_group(self):
"""Get the initial score of the Dataset.
Returns
-------
init_score : array
"""
if self.group is None and self.__is_constructed():
self.group = self.inner_dataset.get_group()
return self.group
def num_data(self):
"""Get the number of rows in the Dataset.
Returns
-------
number of rows : int
"""
if self.__is_constructed():
return self.inner_dataset.num_data()
else:
raise LightGBMError("Cannot call num_data before construct, please call it explicitly")
def num_feature(self):
"""Get the number of columns (features) in the Dataset.
Returns
-------
number of columns : int
"""
if self.__is_constructed():
return self.inner_dataset.num_feature()
else:
raise LightGBMError("Cannot call num_feature before construct, please call it explicitly")
class Booster(object):
""""A Booster of LightGBM.
"""
@ -862,7 +1187,6 @@ class Booster(object):
"""
self.handle = ctypes.c_void_p()
self.__need_reload_eval_info = True
self.__is_manage_handle = True
self.__train_data_name = "training"
self.__attr = {}
self.best_iteration = -1
@ -878,7 +1202,7 @@ class Booster(object):
params_str = param_dict_to_str(params)
"""construct booster object"""
_safe_call(_LIB.LGBM_BoosterCreate(
train_set.handle,
train_set._get_inner_dataset().handle,
c_str(params_str),
ctypes.byref(self.handle)))
"""save reference to data"""
@ -886,11 +1210,11 @@ class Booster(object):
self.valid_sets = []
self.name_valid_sets = []
self.__num_dataset = 1
self.init_predictor = train_set.predictor
if self.init_predictor is not None:
self.__init_predictor = train_set._predictor
if self.__init_predictor is not None:
_safe_call(_LIB.LGBM_BoosterMerge(
self.handle,
self.init_predictor.handle))
self.__init_predictor.handle))
out_num_class = ctypes.c_int64(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses(
self.handle,
@ -916,7 +1240,7 @@ class Booster(object):
raise TypeError('At least need training dataset or model file to create booster instance')
def __del__(self):
if self.handle is not None and self.__is_manage_handle:
if self.handle is not None:
_safe_call(_LIB.LGBM_BoosterFree(self.handle))
def set_train_data_name(self, name):
@ -932,11 +1256,11 @@ class Booster(object):
name : String
Name of validation data
"""
if data.predictor is not self.init_predictor:
raise Exception("Add validation data failed, you should use same predictor for these data")
if data._predictor is not self.__init_predictor:
raise LightGBMError("Add validation data failed, you should use same predictor for these data")
_safe_call(_LIB.LGBM_BoosterAddValidData(
self.handle,
data.handle))
data._get_inner_dataset().handle))
self.valid_sets.append(data)
self.name_valid_sets.append(name)
self.__num_dataset += 1
@ -982,12 +1306,12 @@ class Booster(object):
"""need reset training data"""
if train_set is not None and train_set is not self.train_set:
if train_set.predictor is not self.init_predictor:
raise Exception("Replace training data failed, you should use same predictor for these data")
if train_set._predictor is not self.__init_predictor:
raise LightGBMError("Replace training data failed, you should use same predictor for these data")
self.train_set = train_set
_safe_call(_LIB.LGBM_BoosterResetTrainingData(
self.handle,
self.train_set.handle))
self.train_set._get_inner_dataset().handle))
self.__inner_predict_buffer[0] = None
is_finished = ctypes.c_int(0)
if fobj is None:
@ -1063,7 +1387,7 @@ class Booster(object):
Parameters
----------
data : Dataset object
data : _InnerDataset object
name :
Name of data
feval : function
@ -1073,8 +1397,8 @@ class Booster(object):
result: list
Evaluation result list.
"""
if not isinstance(data, Dataset):
raise TypeError("Can only eval for Dataset instance")
if not isinstance(data, _InnerDataset):
raise TypeError("Can only eval for _InnerDataset instance")
data_idx = -1
if data is self.train_set:
data_idx = 0
@ -1187,15 +1511,13 @@ class Booster(object):
-------
Prediction result
"""
predictor = Predictor(booster_handle=self.handle, is_manage_handle=False)
predictor = _InnerPredictor(booster_handle=self.handle)
return predictor.predict(data, num_iteration, raw_score, pred_leaf, data_has_header, is_reshape)
def to_predictor(self):
def _to_predictor(self):
"""Convert to predictor
Note: Predictor will manage the handle after doing this
"""
predictor = Predictor(booster_handle=self.handle, is_manage_handle=True)
self.__is_manage_handle = False
predictor = _InnerPredictor(booster_handle=self.handle)
return predictor
def feature_importance(self, importance_type='split'):
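For illustration (not part of this commit; synthetic data, behaviour inferred from the diff above), a sketch of the lazy construction and free_raw_data semantics of the new high-level Dataset:

import numpy as np
import lightgbm as lgb

X = np.random.rand(1000, 20)
y = np.random.rand(1000)

ds = lgb.Dataset(X, y)                   # free_raw_data defaults to True
# nothing is passed to the C++ side until construct() (or a train call) happens
ds.construct()
print(ds.num_data(), ds.num_feature())   # 1000 20

# once the raw data is freed, anything that requires re-construction raises
try:
    ds.set_categorical_feature([0])
except Exception as e:
    print(e)                             # "Cannot set categorical feature after ..."

# keep the raw data if such attributes may still change later
ds2 = lgb.Dataset(X, y, free_raw_data=False)
ds2.set_categorical_feature([0])         # fine: the inner dataset is rebuilt lazily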

Просмотреть файл

@ -6,52 +6,12 @@ from __future__ import absolute_import
import collections
from operator import attrgetter
import numpy as np
from .basic import LightGBMError, Predictor, Dataset, Booster, is_str
from .basic import LightGBMError, _InnerPredictor, Dataset, Booster, is_str
from . import callback
def _construct_dataset(X_y, reference=None,
params=None, other_fields=None,
feature_name=None, categorical_feature=None,
predictor=None):
if 'max_bin' in params:
max_bin = int(params['max_bin'])
else:
max_bin = 255
weight = None
group = None
init_score = None
if other_fields is not None:
if not isinstance(other_fields, dict):
raise TypeError("type of other filed data should be dict")
weight = other_fields.get('weight', None)
group = other_fields.get('group', None)
init_score = other_fields.get('init_score', None)
if is_str(X_y):
data = X_y
label = None
else:
if len(X_y) != 2:
raise TypeError("should pass (data, label) tuple for dataset")
data = X_y[0]
label = X_y[1]
if reference is None:
ret = Dataset(data, label=label, max_bin=max_bin,
weight=weight, group=group,
predictor=predictor,
feature_name=feature_name,
categorical_feature=categorical_feature,
params=params)
else:
ret = reference.create_valid(data, label=label, weight=weight,
group=group, params=params)
if init_score is not None:
ret.set_init_score(init_score)
return ret
def train(params, train_data, num_boost_round=100,
valid_datas=None, valid_names=None,
def train(params, train_set, num_boost_round=100,
valid_sets=None, valid_names=None,
fobj=None, feval=None, init_model=None,
train_fields=None, valid_fields=None,
feature_name=None, categorical_feature=None,
early_stopping_rounds=None, evals_result=None,
verbose_eval=True, learning_rates=None, callbacks=None):
@ -61,14 +21,14 @@ def train(params, train_data, num_boost_round=100,
----------
params : dict
Parameters for training.
train_data : Dataset, tuple (X, y) or filename of data
train_set : Dataset
Data to be trained.
num_boost_round: int
Number of boosting iterations.
valid_datas: list of Datasets, tuples (valid_X, valid_y) or filenames of data
valid_sets: list of Datasets
List of data to be evaluated during training
valid_names: list of string
Names of valid_datas
Names of valid_sets
fobj : function
Customized objective function.
feval : function
@ -76,13 +36,6 @@ def train(params, train_data, num_boost_round=100,
Note: should return (eval_name, eval_result, is_higher_better) or a list of these
init_model : file name of lightgbm model or 'Booster' instance
model used for continued training
train_fields : dict
Other data file in training data. e.g. train_fields['weight'] is weight data
Support fields: weight, group, init_score
valid_fields : dict
Other data file in training data. \
e.g. valid_fields[0]['weight'] is weight data for first valid data
Support fields: weight, group, init_score
feature_name : list of str
Feature names
categorical_feature : list of str or int
@ -95,8 +48,8 @@ def train(params, train_data, num_boost_round=100,
Returns the model with (best_iter + early_stopping_rounds)
If early stopping occurs, the model will add 'best_iteration' field
evals_result: dict or None
This dictionary used to store all evaluation results of all the items in valid_datas.
Example: with a valid_datas containing [valid_set, train_set] \
This dictionary is used to store all evaluation results of all the items in valid_sets.
Example: with a valid_sets containing [valid_set, train_set] \
and valid_names containing ['eval', 'train'] and a params entry ('metric': 'logloss')
Returns: {'train': {'logloss': ['0.48253', '0.35953', ...]},
'eval': {'logloss': ['0.480385', '0.357756', ...]}}
@ -127,58 +80,40 @@ def train(params, train_data, num_boost_round=100,
"""
"""create predictor first"""
if is_str(init_model):
predictor = Predictor(model_file=init_model)
predictor = _InnerPredictor(model_file=init_model)
elif isinstance(init_model, Booster):
predictor = init_model.to_predictor()
elif isinstance(init_model, Predictor):
predictor = init_model
predictor = init_model._to_predictor()
else:
predictor = None
init_iteration = predictor.num_total_iteration if predictor else 0
"""create dataset"""
if isinstance(train_data, Dataset):
train_set = train_data
if train_fields is not None:
for field, data in train_fields.items():
train_set.set_field(field, data)
else:
train_set = _construct_dataset(train_data, None, params,
other_fields=train_fields,
feature_name=feature_name,
categorical_feature=categorical_feature,
predictor=predictor)
"""check dataset"""
if not isinstance(train_set, Dataset):
raise TypeError("only can accept Dataset instance for traninig")
train_set._set_predictor(predictor)
train_set.set_feature_name(feature_name)
train_set.set_categorical_feature(categorical_feature)
is_valid_contain_train = False
train_data_name = "training"
valid_sets = []
reduced_valid_sets = []
name_valid_sets = []
if valid_datas:
if isinstance(valid_datas, (Dataset, tuple)):
valid_datas = [valid_datas]
if valid_sets:
if isinstance(valid_sets, Dataset):
valid_sets = [valid_sets]
if isinstance(valid_names, str):
valid_names = [valid_names]
for i, valid_data in enumerate(valid_datas):
other_fields = None if valid_fields is None else valid_fields.get(i, None)
for i, valid_data in enumerate(valid_sets):
"""reduce cost for prediction training data"""
if valid_data[0] is train_data[0] and valid_data[1] is train_data[1]:
if valid_data is train_set:
is_valid_contain_train = True
if valid_names is not None:
train_data_name = valid_names[i]
continue
if isinstance(valid_data, Dataset):
valid_set = valid_data
if other_fields is not None:
for field, data in other_fields.items():
valid_set.set_field(field, data)
else:
valid_set = _construct_dataset(
valid_data,
train_set,
params,
other_fields=other_fields,
feature_name=feature_name,
categorical_feature=categorical_feature,
predictor=predictor)
valid_sets.append(valid_set)
if not isinstance(valid_data, Dataset):
raise TypeError("only can accept Dataset instance for traninig")
valid_data.set_reference(train_set)
reduced_valid_sets.append(valid_data)
if valid_names is not None and len(valid_names) > i:
name_valid_sets.append(valid_names[i])
else:
@ -217,7 +152,7 @@ def train(params, train_data, num_boost_round=100,
booster = Booster(params=params, train_set=train_set)
if is_valid_contain_train:
booster.set_train_data_name(train_data_name)
for valid_set, name_valid_set in zip(valid_sets, name_valid_sets):
for valid_set, name_valid_set in zip(reduced_valid_sets, name_valid_sets):
booster.add_valid(valid_set, name_valid_set)
"""start training"""
@ -294,6 +229,7 @@ def _make_n_folds(full_data, nfold, params, seed, fpreproc=None, stratified=Fals
else:
raise LightGBMError('sklearn needs to be installed in order to use stratified cv')
else:
full_data.construct()
randidx = np.random.permutation(full_data.num_data())
kstep = int(len(randidx) / nfold)
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)]
@ -322,8 +258,8 @@ def _agg_cv_result(raw_results):
cvmap[one_line[1]].append(one_line[2])
return [('cv_agg', k, np.mean(v), metric_type[k], np.std(v)) for k, v in cvmap.items()]
def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
metrics=(), fobj=None, feval=None, train_fields=None,
def cv(params, train_set, num_boost_round=10, nfold=5, stratified=False,
metrics=(), fobj=None, feval=None, init_model=None,
feature_name=None, categorical_feature=None,
early_stopping_rounds=None, fpreproc=None,
verbose_eval=None, show_stdv=True, seed=0,
@ -334,7 +270,7 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
----------
params : dict
Booster params.
train_data : tuple (X, y) or filename of data
train_set : Dataset
Data to be trained.
num_boost_round : int
Number of boosting iterations.
@ -350,9 +286,8 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
Custom objective function.
feval : function
Custom evaluation function.
train_fields : dict
Other data file in training data. e.g. train_fields['weight'] is weight data
Support fields: weight, group, init_score
init_model : file name of lightgbm model or 'Booster' instance
model used for continued training
feature_name : list of str
Feature names
categorical_feature : list of str or int
@ -382,6 +317,20 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
-------
evaluation history : list(string)
"""
if not isinstance(train_set, Dataset):
raise TypeError("only can accept Dataset instance for traninig")
if is_str(init_model):
predictor = _InnerPredictor(model_file=init_model)
elif isinstance(init_model, Booster):
predictor = init_model._to_predictor()
else:
predictor = None
train_set._set_predictor(predictor)
train_set.set_feature_name(feature_name)
train_set.set_categorical_feature(categorical_feature)
if metrics:
params.setdefault('metric', [])
if is_str(metrics):
@ -389,11 +338,6 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
else:
params['metric'].extend(metrics)
train_set = _construct_dataset(train_data, None, params,
other_fields=train_fields,
feature_name=feature_name,
categorical_feature=categorical_feature)
results = collections.defaultdict(list)
cvfolds = _make_n_folds(train_set, nfold, params, seed, fpreproc, stratified)
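For illustration (not part of this commit; synthetic data and parameter values are assumptions), a sketch of the reworked train()/cv() entry points, which now take a Dataset plus an init_model for continued training; free_raw_data=False keeps the raw data so a predictor can still be attached after the first run, mirroring the new test_engine.py at the end of this diff:

import numpy as np
import lightgbm as lgb

X = np.random.rand(500, 10)
y = (np.random.rand(500) > 0.5).astype(float)

# keep raw data so the Dataset can be re-constructed with a new predictor
train_set = lgb.Dataset(X, y, free_raw_data=False)
params = {'objective': 'binary', 'verbose': 0}

base = lgb.train(params, train_set, num_boost_round=5)
# continue training inside cross validation from the pre-trained model
results = lgb.cv(params, train_set, num_boost_round=20, nfold=5,
                 metrics='binary_logloss', init_model=base, seed=42)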

View file

@ -19,6 +19,7 @@ def find_lib_path():
if os.name == 'nt':
dll_path.append(os.path.join(curr_path, '../../windows/x64/Dll/'))
dll_path.append(os.path.join(curr_path, './windows/x64/Dll/'))
dll_path.append(os.path.join(curr_path, '../../Release/'))
dll_path = [os.path.join(p, 'lib_lightgbm.dll') for p in dll_path]
else:
dll_path = [os.path.join(p, 'lib_lightgbm.so') for p in dll_path]

View file

@ -4,7 +4,7 @@
from __future__ import absolute_import
import numpy as np
from .basic import LightGBMError, is_str
from .basic import LightGBMError, Dataset, is_str
from .engine import train
# sklearn
try:
@ -195,9 +195,12 @@ class LGBMModel(LGBMModelBase):
params.pop('nthread', None)
return params
def fit(self, X, y, eval_set=None, eval_metric=None,
def fit(self, X, y,
sample_weight=None, init_score=None, group=None,
eval_set=None, eval_sample_weight=None,
eval_init_score=None, eval_group=None,
eval_metric=None,
early_stopping_rounds=None, verbose=True,
train_fields=None, valid_fields=None,
feature_name=None, categorical_feature=None,
other_params=None):
"""
@ -209,24 +212,29 @@ class LGBMModel(LGBMModelBase):
Feature matrix
y : array_like
Labels
sample_weight : array_like
weight of training data
init_score : array_like
init score of training data
group : array_like
group data of training data
eval_set : list, optional
A list of (X, y) tuple pairs to use as a validation set for early-stopping
eval_sample_weight : List of array
weight of eval data
eval_init_score : List of array
init score of eval data
eval_group : List of array
group data of eval data
eval_metric : str, list of str, callable, optional
If a str, should be a built-in evaluation metric to use.
If callable, a custom evaluation metric. The call \
signature is func(y_predicted, dataset) where dataset will be a \
Dataset fobject such that you may need to call the get_label \
Dataset object such that you may need to call the get_label \
method. And it must return (eval_name->str, eval_result->float, is_bigger_better->Bool)
early_stopping_rounds : int
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
train_fields : dict
Other data file in training data. e.g. train_fields['weight'] is weight data
Support fields: weight, group, init_score
valid_fields : dict
Other data file in training data. \
e.g. valid_fields[0]['weight'] is weight data for first valid data
Support fields: weight, group, init_score
feature_name : list of str
Feature names
categorical_feature : list of str or int
@ -263,12 +271,33 @@ class LGBMModel(LGBMModelBase):
feval = None
feval = eval_metric if callable(eval_metric) else None
self._Booster = train(params, (X, y),
self.n_estimators, valid_datas=eval_set,
def _construct_dataset(X, y, sample_weight, init_score, group):
ret = Dataset(X, label=y, weight=sample_weight, group=group)
ret.set_init_score(init_score)
return ret
train_set = _construct_dataset(X, y, sample_weight, init_score, group)
valid_sets = []
if eval_set is not None:
if isinstance(eval_set, tuple):
eval_set = [eval_set]
for i, valid_data in enumerate(eval_set):
"""reduce cost for prediction training data"""
if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set
else:
valid_weight = None if eval_sample_weight is None else eval_sample_weight.get(i, None)
valid_init_score = None if eval_init_score is None else eval_init_score.get(i, None)
valid_group = None if eval_group is None else eval_group.get(i, None)
valid_set = _construct_dataset(valid_data[0], valid_data[1], valid_weight, valid_init_score, valid_group)
valid_sets.append(valid_set)
self._Booster = train(params, train_set,
self.n_estimators, valid_sets=valid_sets,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, fobj=self.fobj, feval=feval,
verbose_eval=verbose, train_fields=train_fields,
valid_fields=valid_fields, feature_name=feature_name,
verbose_eval=verbose, feature_name=feature_name,
categorical_feature=categorical_feature)
if evals_result:
@ -331,14 +360,48 @@ class LGBMRegressor(LGBMModel, LGBMRegressorBase):
__doc__ = """Implementation of the scikit-learn API for LightGBM regression.
""" + '\n'.join(LGBMModel.__doc__.split('\n')[2:])
def fit(self, X, y,
sample_weight=None, init_score=None,
eval_set=None, eval_sample_weight=None,
eval_init_score=None,
eval_metric=None,
early_stopping_rounds=None, verbose=True,
feature_name=None, categorical_feature=None,
other_params=None):
super(LGBMRegressor, self).fit(X, y, sample_weight, init_score, None,
eval_set, eval_sample_weight, eval_init_score, None,
eval_metric, early_stopping_rounds,
verbose, feature_name, categorical_feature,
other_params)
return self
class LGBMClassifier(LGBMModel, LGBMClassifierBase):
__doc__ = """Implementation of the scikit-learn API for LightGBM classification.
""" + '\n'.join(LGBMModel.__doc__.split('\n')[2:])
def fit(self, X, y, eval_set=None, eval_metric=None,
def __init__(self, num_leaves=31, max_depth=-1,
learning_rate=0.1, n_estimators=10, max_bin=255,
silent=True, objective="binary",
nthread=-1, min_split_gain=0, min_child_weight=5, min_child_samples=10,
subsample=1, subsample_freq=1, colsample_bytree=1,
reg_alpha=0, reg_lambda=0, scale_pos_weight=1,
is_unbalance=False, seed=0):
super(LGBMClassifier, self).__init__(num_leaves, max_depth,
learning_rate, n_estimators, max_bin,
silent, objective,
nthread, min_split_gain, min_child_weight, min_child_samples,
subsample, subsample_freq, colsample_bytree,
reg_alpha, reg_lambda, scale_pos_weight,
is_unbalance, seed)
def fit(self, X, y,
sample_weight=None, init_score=None,
eval_set=None, eval_sample_weight=None,
eval_init_score=None,
eval_metric=None,
early_stopping_rounds=None, verbose=True,
train_fields=None, valid_fields=None,
feature_name=None, categorical_feature=None,
other_params=None):
@ -350,12 +413,6 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
# Switch to using a multiclass objective in the underlying LGBM instance
self.objective = "multiclass"
other_params['num_class'] = self.n_classes_
if eval_metric is None and eval_set is not None:
eval_metric = "multi_logloss"
else:
self.objective = "binary"
if eval_metric is None and eval_set is not None:
eval_metric = "binary_logloss"
self._le = LGBMLabelEncoder().fit(y)
training_labels = self._le.transform(y)
@ -363,10 +420,10 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
if eval_set is not None:
eval_set = list((x[0], self._le.transform(x[1])) for x in eval_set)
super(LGBMClassifier, self).fit(X, training_labels, eval_set,
super(LGBMClassifier, self).fit(X, training_labels, sample_weight, init_score, None,
eval_set, eval_sample_weight, eval_init_score, None,
eval_metric, early_stopping_rounds,
verbose, train_fields, valid_fields,
feature_name, categorical_feature,
verbose, feature_name, categorical_feature,
other_params)
return self
@ -442,34 +499,59 @@ class LGBMRanker(LGBMModel):
""" + '\n'.join(LGBMModel.__doc__.split('\n')[2:])
def fit(self, X, y, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True,
train_fields=None, valid_fields=None, other_params=None):
"""check group data"""
if "group" not in train_fields:
raise ValueError("should set group in train_fields for ranking task")
if eval_set is not None:
if valid_fields is None:
raise ValueError("valid_fields cannot be None when eval_set is not None")
elif len(valid_fields) != len(eval_set):
raise ValueError("lenght of valid_fields should equal with eval_set")
else:
for inner in valid_fields:
if "group" not in inner:
raise ValueError("should set group in valid_fields for ranking task")
def __init__(self, num_leaves=31, max_depth=-1,
learning_rate=0.1, n_estimators=10, max_bin=255,
silent=True, objective="lambdarank",
nthread=-1, min_split_gain=0, min_child_weight=5, min_child_samples=10,
subsample=1, subsample_freq=1, colsample_bytree=1,
reg_alpha=0, reg_lambda=0, scale_pos_weight=1,
is_unbalance=False, seed=0):
super(LGBMRanker, self).__init__(num_leaves, max_depth,
learning_rate, n_estimators, max_bin,
silent, objective,
nthread, min_split_gain, min_child_weight, min_child_samples,
subsample, subsample_freq, colsample_bytree,
reg_alpha, reg_lambda, scale_pos_weight,
is_unbalance, seed)
if callable(self.objective):
self.fobj = _group_wise_objective(self.objective)
else:
self.objective = "lambdarank"
self.fobj = None
if eval_metric is None and eval_set is not None:
eval_metric = "ndcg"
super(LGBMRanker, self).fit(X, y, eval_set, eval_metric,
early_stopping_rounds, verbose,
train_fields, valid_fields,
def fit(self, X, y,
sample_weight=None, init_score=None, group=None,
eval_set=None, eval_sample_weight=None,
eval_init_score=None, eval_group=None,
eval_metric=None, eval_at=None,
early_stopping_rounds=None, verbose=True,
feature_name=None, categorical_feature=None,
other_params=None):
"""
Most arguments are the same as for LGBMModel.fit except the following:
eval_at : list of int
The evaluation positions of NDCG
"""
"""check group data"""
if group is None:
raise ValueError("should use group for ranking task")
if eval_set is not None:
if eval_group is None:
raise ValueError("eval_group cannot be None when eval_set is not None")
elif len(eval_group) != len(eval_set):
raise ValueError("length of eval_group should equal with eval_set")
else:
for inner_group in eval_group:
if inner_group is None:
raise ValueError("should set group for all eval data for ranking task")
if eval_at is not None:
other_params = {} if other_params is None else other_params
other_params['ndcg_eval_at'] = list(eval_at)
super(LGBMRanker, self).fit(X, y, sample_weight, init_score, group,
eval_set, eval_sample_weight, eval_init_score, eval_group,
eval_metric, early_stopping_rounds,
verbose, feature_name, categorical_feature,
other_params)
return self
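For illustration (not part of this commit; synthetic data, illustrative values), a sketch of the updated scikit-learn style fit(), which now accepts sample_weight / init_score / group and per-eval-set counterparts instead of the removed train_fields / valid_fields dicts:

import numpy as np
from lightgbm import LGBMClassifier

X = np.random.rand(300, 8)
y = (np.random.rand(300) > 0.5).astype(int)
w = np.random.rand(300)                 # per-instance training weights

clf = LGBMClassifier(n_estimators=20)
clf.fit(X, y,
        sample_weight=w,
        eval_set=[(X, y)],              # early stopping is evaluated on this set
        eval_metric='binary_logloss',
        early_stopping_rounds=5,
        verbose=False)
print(clf.predict(X[:5]))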

View file

@ -1,12 +1,14 @@
#include <LightGBM/config.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/log.h>
#include <vector>
#include <string>
#include <unordered_set>
#include <algorithm>
#include <limits>
namespace LightGBM {
@ -22,7 +24,7 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par
continue;
}
params[key] = value;
} else {
} else if(Common::Trim(arg).size() > 0){
Log::Warning("Unknown parameter %s", arg.c_str());
}
}
@ -33,12 +35,21 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
// load main config types
GetInt(params, "num_threads", &num_threads);
// generate seeds by seed.
if (GetInt(params, "seed", &seed)) {
Random rand(seed);
int int_max = std::numeric_limits<int>::max();
io_config.data_random_seed = static_cast<int>(rand.NextInt(0, int_max));
boosting_config.bagging_seed = static_cast<int>(rand.NextInt(0, int_max));
boosting_config.drop_seed = static_cast<int>(rand.NextInt(0, int_max));
boosting_config.tree_config.feature_fraction_seed = static_cast<int>(rand.NextInt(0, int_max));
}
GetTaskType(params);
GetBoostingType(params);
GetObjectiveType(params);
GetMetricType(params);
// sub-config setup
network_config.Set(params);
io_config.Set(params);

View file

@ -8,10 +8,6 @@ x_train, x_test, y_train, y_test = model_selection.train_test_split(X, Y, test_s
train_data = lgb.Dataset(x_train, max_bin=255, label=y_train)
num_features = train_data.num_feature()
names = ["name_%d" %(i) for i in range(num_features)]
train_data.set_feature_name(names)
valid_data = train_data.create_valid(x_test, label=y_test)
config={"objective":"binary","metric":"auc", "min_data":1, "num_leaves":15}

View file

@ -0,0 +1,77 @@
# coding: utf-8
# pylint: disable = invalid-name, C0111
import json
import lightgbm as lgb
import pandas as pd
from sklearn.metrics import mean_squared_error
# load or create your dataset
df_train = pd.read_csv('../../examples/regression/regression.train', header=None, sep='\t')
df_test = pd.read_csv('../../examples/regression/regression.test', header=None, sep='\t')
y_train = df_train[0]
y_test = df_test[0]
X_train = df_train.drop(0, axis=1)
X_test = df_test.drop(0, axis=1)
# create dataset for lightgbm
lgb_train = lgb.Dataset(X_train, y_train, free_raw_data=False)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train, free_raw_data=False)
# specify your configurations as a dict
params = {
'task' : 'train',
'boosting_type' : 'gbdt',
'objective' : 'regression',
'metric' : {'l2', 'auc'},
'num_leaves' : 31,
'learning_rate' : 0.05,
'feature_fraction' : 0.9,
'bagging_fraction' : 0.8,
'bagging_freq': 5,
'verbose' : 0
}
# train
init_gbm = lgb.train(params,
lgb_train,
num_boost_round=5,
valid_sets=lgb_eval)
print('Start continue train')
gbm = lgb.train(params,
lgb_train,
num_boost_round=100,
valid_sets=lgb_eval,
early_stopping_rounds=10,
init_model=init_gbm)
# save model to file
gbm.save_model('model.txt')
# predict
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
# eval
print('The rmse of prediction is:', mean_squared_error(y_test, y_pred) ** 0.5)
# dump model to json (and save to file)
model_json = gbm.dump_model()
with open('model.json', 'w+') as f:
json.dump(model_json, f, indent=4)
# feature importances
print('Feature importances:', gbm.feature_importance())
print('Feature importances:', gbm.feature_importance("gain"))
print('Start test cv')
lgb.cv(params,
lgb_train,
num_boost_round=100,
nfold=5,
verbose_eval=5,
init_model=init_gbm)