From e43e5f002efea8b837993111ed4010cb70e12adc Mon Sep 17 00:00:00 2001
From: Guolin Ke
Date: Mon, 21 Nov 2016 19:11:31 +0800
Subject: [PATCH 01/60] Add folder for python package

---
 python-package/lightgbm/__init__.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 python-package/lightgbm/__init__.py

diff --git a/python-package/lightgbm/__init__.py b/python-package/lightgbm/__init__.py
new file mode 100644
index 000000000..e69de29bb

From 6837efe74185b258a092b2330857ceef869ea06e Mon Sep 17 00:00:00 2001
From: Guolin Ke
Date: Tue, 22 Nov 2016 13:54:02 +0800
Subject: [PATCH 02/60] Add draft for dataset

---
 python-package/lightgbm/basic.py | 538 +++++++++++++++++++++++++++++++
 1 file changed, 538 insertions(+)
 create mode 100644 python-package/lightgbm/basic.py

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
new file mode 100644
index 000000000..0803a1207
--- /dev/null
+++ b/python-package/lightgbm/basic.py
@@ -0,0 +1,538 @@
+"""Wrapper for the C API of LightGBM."""
+from __future__ import absolute_import
+
+import sys
+import os
+import ctypes
+import collections
+import re
+
+import numpy as np
+import scipy.sparse
+
+
+IS_PY3 = (sys.version_info[0] == 3)
+
+
+def find_lib_path():
+    """Find the path to LightGBM library files.
+    Returns
+    -------
+    lib_path: list(string)
+        List of all found library paths to LightGBM.
+    """
+    curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+    dll_path = [curr_path, os.path.join(curr_path, '../../lib/'),
+                os.path.join(curr_path, './lib/'),
+                os.path.join(sys.prefix, 'lightgbm')]
+    if os.name == 'nt':
+        dll_path.append(os.path.join(curr_path, '../../windows/x64/Dll/'))
+        dll_path.append(os.path.join(curr_path, './windows/x64/Dll/'))
+        dll_path = [os.path.join(p, 'lib_lightgbm.dll') for p in dll_path]
+    else:
+        dll_path = [os.path.join(p, 'lib_lightgbm.so') for p in dll_path]
+    lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
+    if not lib_path:
+        raise Exception('Cannot find LightGBM library')
+    return lib_path
+
+def _load_lib():
+    """Load LightGBM library."""
+    lib_path = find_lib_path()
+    if len(lib_path) == 0:
+        return None
+    lib = ctypes.cdll.LoadLibrary(lib_path[0])
+    lib.LGBM_GetLastError.restype = ctypes.c_char_p
+    return lib
+
+_LIB = _load_lib()
+
+class LightGBMError(Exception):
+    """Error thrown by LightGBM."""
+    pass
+
+def _safe_call(ret):
+    """Check the return value of a C API call
+    Parameters
+    ----------
+    ret : int
+        return value from API calls
+    """
+    if ret != 0:
+        raise LightGBMError(_LIB.LGBM_GetLastError().decode())
+
+def is_str(s):
+    if IS_PY3:
+        return isinstance(s, str)
+    else:
+        return isinstance(s, basestring)
+
+def is_numpy_object(data):
+    return type(data).__module__ == np.__name__
+
+def is_numpy_1d_array(data):
+    if isinstance(data, np.ndarray) and len(data.shape) == 1:
+        return True
+    else:
+        return False
+
+def list_to_1d_numpy(data, dtype):
+    if is_numpy_1d_array(data):
+        return data
+    elif isinstance(data, list):
+        return np.array(data, dtype=dtype, copy=False)
+    else:
+        raise TypeError("Unknown type({})".format(type(data).__name__))
+
+def cfloat32_array_to_numpy(cptr, length):
+    """Convert a ctypes float pointer array to a numpy array.
+    """
+    if isinstance(cptr, ctypes.POINTER(ctypes.c_float)):
+        res = np.fromiter(cptr, dtype=np.float32, count=length)
+        return res
+    else:
+        raise RuntimeError('expected float pointer')
+
+def cint32_array_to_numpy(cptr, length):
+    """Convert a ctypes int pointer array to a numpy array.
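+
+    Example (illustrative only):
+        arr = (ctypes.c_int32 * 3)(1, 2, 3)
+        ptr = ctypes.cast(arr, ctypes.POINTER(ctypes.c_int32))
+        cint32_array_to_numpy(ptr, 3)  # -> array([1, 2, 3], dtype=int32)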
+ """ + if isinstance(cptr, ctypes.POINTER(ctypes.c_int32)): + res = np.fromiter(cptr, dtype=np.int32, count=length) + return res + else: + raise RuntimeError('expected int pointer') + +def c_str(string): + """Convert a python string to cstring.""" + return ctypes.c_char_p(string.encode('utf-8')) + +def c_array(ctype, values): + """Convert a python array to c array.""" + return (ctype * len(values))(*values) + +"""marco definition of data type in c_api of LightGBM""" +C_API_DTYPE_FLOAT32 =0 +C_API_DTYPE_FLOAT64 =1 +C_API_DTYPE_INT32 =2 +C_API_DTYPE_INT64 =3 +"""Matric is row major in python""" +C_API_IS_ROW_MAJOR =1 + +def c_float_array(data): + """Convert numpy array / list to c float array.""" + if isinstance(data, list): + data = np.array(data, copy=False) + if is_numpy_1d_array(data): + if data.dtype == np.float32: + ptr_data = c_array(ctypes.c_float, data) + type_data = C_API_DTYPE_FLOAT32 + elif data.dtype == np.float64: + ptr_data = c_array(ctypes.c_double, data) + type_data = C_API_DTYPE_FLOAT64 + else: + raise TypeError("expected np.float32 or np.float64, met type({})".format(data.dtype)) + else: + raise TypeError("Unknow type({})".format(type(data).__name__)) + return (ptr_data, type_data) + +def c_int_array(data): + """Convert numpy array to c int array.""" + if isinstance(data, list): + data = np.array(data, copy=False) + if is_numpy_1d_array(data): + if data.dtype == np.int32: + ptr_data = c_array(ctypes.c_int32, data) + type_data = C_API_DTYPE_INT32 + elif data.dtype == np.int64: + ptr_data = c_array(ctypes.c_int64, data) + type_data = C_API_DTYPE_INT64 + else: + raise TypeError("expected np.int32 or np.int64, met type({})".format(data.dtype)) + else: + raise TypeError("Unknow type({})".format(type(data).__name__)) + return (ptr_data, type_data) + +class Dataset(object): + """Dataset used in LightGBM. + + Dataset is a internal data structure that used by LightGBM + You can construct Dataset from numpy.arrays + """ + + _feature_names = None + + def __init__(self, data, max_bin=255, reference=None, + label=None, weight=None, group_id=None, + silent=False, feature_names=None, + other_args=None): + """ + Dataset used in LightGBM. + + Parameters + ---------- + data : string/numpy array/scipy.sparse + Data source of Dataset. + When data is string type, it represents the path of txt file, + max_bin : int, required + max number of discrete bin for features + reference : Other Dataset, optional + If this dataset validation, need to use training data as reference + label : list or numpy 1-D array, optional + Label of the training data. + weight : list or numpy 1-D array , optional + Weight for each instance. + group_id : list or numpy 1-D array , optional + group/query id for each instance. Note: if having group/query id, data should group by this id + silent : boolean, optional + Whether print messages during construction + feature_names : list, optional + Set names for features. 
+        other_args: list, optional
+            other parameters, format: ['key1=val1','key2=val2']
+        """
+
+        if data is None:
+            self.handle = None
+            return
+        """process for args"""
+        pass_args = ["max_bin={}".format(max_bin)]
+        if silent:
+            pass_args.append("verbose=0")
+        if other_args:
+            pass_args += other_args
+        pass_args_str = ' '.join(pass_args)
+        """process for reference dataset"""
+        ref_dataset = None
+        if isinstance(reference, Dataset):
+            ref_dataset = ctypes.byref(reference.handle)
+        elif reference is not None:
+            raise TypeError('Reference dataset should be None or dataset instance')
+        """start construct data"""
+        if is_str(data):
+            self.handle = ctypes.c_void_p()
+            _safe_call(_LIB.LGBM_CreateDatasetFromFile(
+                c_str(data),
+                c_str(pass_args_str),
+                ref_dataset,
+                ctypes.byref(self.handle)))
+        elif isinstance(data, scipy.sparse.csr_matrix):
+            self._init_from_csr(data, pass_args_str, ref_dataset)
+        elif isinstance(data, scipy.sparse.csc_matrix):
+            self._init_from_csc(data, pass_args_str, ref_dataset)
+        elif isinstance(data, np.ndarray):
+            self._init_from_npy2d(data, pass_args_str, ref_dataset)
+        else:
+            try:
+                csr = scipy.sparse.csr_matrix(data)
+                self._init_from_csr(csr, pass_args_str, ref_dataset)
+            except:
+                raise TypeError('can not initialize Dataset from {}'.format(type(data).__name__))
+        if label is not None:
+            self.set_label(label)
+        if weight is not None:
+            self.set_weight(weight)
+        if group_id is not None:
+            self.set_group_id(group_id)
+        self.feature_names = feature_names
+
+    def _init_from_csr(self, csr, pass_args_str, ref_dataset):
+        """
+        Initialize data from a CSR matrix.
+        """
+        if len(csr.indices) != len(csr.data):
+            raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data)))
+        self.handle = ctypes.c_void_p()
+
+        ptr_indptr, type_ptr_indptr = c_int_array(csr.indptr)
+        ptr_data, type_ptr_data = c_float_array(csr.data)
+
+        _safe_call(_LIB.LGBM_CreateDatasetFromCSR(
+            ptr_indptr,
+            type_ptr_indptr,
+            c_array(ctypes.c_int32, csr.indices),
+            ptr_data,
+            type_ptr_data,
+            len(csr.indptr),
+            len(csr.data),
+            csr.shape[1],
+            c_str(pass_args_str),
+            ref_dataset,
+            ctypes.byref(self.handle)))
+
+    def _init_from_csc(self, csc, pass_args_str, ref_dataset):
+        """
+        Initialize data from a CSC matrix.
+        """
+        if len(csc.indices) != len(csc.data):
+            raise ValueError('length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data)))
+        self.handle = ctypes.c_void_p()
+
+        ptr_indptr, type_ptr_indptr = c_int_array(csc.indptr)
+        ptr_data, type_ptr_data = c_float_array(csc.data)
+
+        _safe_call(_LIB.LGBM_CreateDatasetFromCSC(
+            ptr_indptr,
+            type_ptr_indptr,
+            c_array(ctypes.c_int32, csc.indices),
+            ptr_data,
+            type_ptr_data,
+            len(csc.indptr),
+            len(csc.data),
+            csc.shape[0],
+            c_str(pass_args_str),
+            ref_dataset,
+            ctypes.byref(self.handle)))
+
+    def _init_from_npy2d(self, mat, pass_args_str, ref_dataset):
+        """
+        Initialize data from a 2-D numpy matrix.
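+
+        Example input (illustrative only):
+            mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)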
+ """ + if len(mat.shape) != 2: + raise ValueError('Input numpy.ndarray must be 2 dimensional') + + self.handle = ctypes.c_void_p() + if mat.dtype == np.float32 or mat.dtype == np.float64: + data = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False) + else: + """change non-float data to float data, need to copy""" + data = np.array(mat.reshape(mat.size), dtype=np.float32) + + ptr_data, type_ptr_data = c_float_array(data) + _safe_call(LIB.LGBM_CreateDatasetFromMat( + ptr_data, + type_ptr_data, + mat.shape[0], + mat.shape[1], + C_API_IS_ROW_MAJOR, + c_str(pass_args_str), + ref_dataset, + ctypes.byref(self.handle))) + + def __del__(self): + _safe_call(_LIB.LGBM_DatasetFree(self.handle)) + + def get_field(self, field_name): + """Get property from the Dataset. + + Parameters + ---------- + field_name: str + The field name of the information + + Returns + ------- + info : array + a numpy array of information of the data + """ + out_len = ctypes.c_int32() + out_type = ctypes.c_int32() + ret = ctypes.POINTER(ctypes.c_void_p)() + _safe_call(_LIB.LGBM_DatasetGetField( + self.handle, + c_str(field_name), + ctypes.byref(out_len), + ctypes.byref(ret), + ctypes.byref(out_type))) + if out_type.value == C_API_DTYPE_INT32: + return cint32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(c_int32), out_len.value)) + elif out_type.value == C_API_DTYPE_FLOAT32: + return cfloat32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(c_float), out_len.value)) + else: + raise TypeError("unknow type") + + def set_field(self, field_name, data): + """Set property into the Dataset. + + Parameters + ---------- + field_name: str + The field name of the information + + data: numpy array or list + The array ofdata to be set + """ + if not is_numpy_1d_array(data): + raise TypeError("Unknow type({})".format(type(data).__name__)) + if data.dtype == np.float32: + ptr_data = c_array(ctypes.c_float, data) + type_data = C_API_DTYPE_FLOAT32 + elif data.dtype == np.int32: + ptr_data = c_array(ctypes.c_int32, data) + type_data = C_API_DTYPE_INT32 + else: + raise TypeError("excepted np.float32 or np.int32, met type({})".format(data.dtype)) + _safe_call(_LIB.LGBM_DatasetSetField( + self.handle, + c_str(field_name), + ptr_data, + len(data), + type_data)) + + + def save_binary(self, filename): + """Save Dataset to binary file + + Parameters + ---------- + filename : string + Name of the output file. + """ + _safe_call(_LIB.LGBM_DatasetSaveBinary( + self.handle, + c_str(filename))) + + def set_label(self, label): + """Set label of Dataset + + Parameters + ---------- + label: array like + The label information to be set into Dataset + """ + label = list_to_1d_numpy(label, np.float32) + if label.dtype != np.float32: + label = label.astype(np.float32, copy=False) + self.set_field('label', label) + + def set_weight(self, weight): + """ Set weight of each instance. + + Parameters + ---------- + weight : array like + Weight for each data point + """ + weight = list_to_1d_numpy(weight, np.float32) + if weight.dtype != np.float32: + weight = weight.astype(np.float32, copy=False) + self.set_field('weight', weight) + + def set_init_score(self, score): + """ Set init score of booster to start from. + Parameters + ---------- + score: array like + + """ + score = list_to_1d_numpy(score, np.float32) + if score.dtype != np.float32: + score = score.astype(np.float32, copy=False) + self.set_field('init_score', score) + + def set_group(self, group): + """Set group size of Dataset (used for ranking). 
+
+        Parameters
+        ----------
+        group : array like
+            Group size of each group
+        """
+        group = list_to_1d_numpy(group, np.int32)
+        if group.dtype != np.int32:
+            group = group.astype(np.int32, copy=False)
+        self.set_field('group', group)
+
+    def set_group_id(self, group_id):
+        """Set group_id of Dataset (used for ranking).
+
+        Parameters
+        ----------
+        group_id : array like
+            group/query id for each instance.
+        """
+        group_id = list_to_1d_numpy(group_id, np.int32)
+        if group_id.dtype != np.int32:
+            group_id = group_id.astype(np.int32, copy=False)
+        self.set_field('group_id', group_id)
+
+    def get_label(self):
+        """Get the label of the Dataset.
+
+        Returns
+        -------
+        label : array
+        """
+        return self.get_field('label')
+
+    def get_weight(self):
+        """Get the weight of the Dataset.
+
+        Returns
+        -------
+        weight : array
+        """
+        return self.get_field('weight')
+
+    def get_init_score(self):
+        """Get the initial score of the Dataset.
+
+        Returns
+        -------
+        init_score : array
+        """
+        return self.get_field('init_score')
+
+    def num_data(self):
+        """Get the number of rows in the Dataset.
+
+        Returns
+        -------
+        number of rows : int
+        """
+        ret = ctypes.c_int64()
+        _safe_call(_LIB.LGBM_DatasetGetNumData(self.handle,
+                                               ctypes.byref(ret)))
+        return ret.value
+
+    def num_feature(self):
+        """Get the number of columns (features) in the Dataset.
+
+        Returns
+        -------
+        number of columns : int
+        """
+        ret = ctypes.c_int64()
+        _safe_call(_LIB.LGBM_DatasetGetNumFeature(self.handle,
+                                                  ctypes.byref(ret)))
+        return ret.value
+
+    @property
+    def feature_names(self):
+        """Get feature names (column labels).
+
+        Returns
+        -------
+        feature_names : list
+        """
+        if self._feature_names is None:
+            self._feature_names = ['Column_{0}'.format(i) for i in range(self.num_feature())]
+        return self._feature_names
+
+    @feature_names.setter
+    def feature_names(self, feature_names):
+        """Set feature names (column labels).
+
+        Parameters
+        ----------
+        feature_names : list
+            Labels for features
+        """
+        if feature_names is not None:
+            # validate feature name
+            if not isinstance(feature_names, list):
+                feature_names = list(feature_names)
+            if len(feature_names) != len(set(feature_names)):
+                raise ValueError('feature_names must be unique')
+            if len(feature_names) != self.num_feature():
+                msg = 'feature_names must have the same length as data'
+                raise ValueError(msg)
+            # prohibit symbols that may break parsing, e.g. [, ] and <
+            if not all(is_str(f) and
+                       not any(x in f for x in set(('[', ']', '<')))
+                       for f in feature_names):
+                raise ValueError('feature_names may not contain [, ] or <')
+            self._feature_names = feature_names
+        else:
+            self._feature_names = None
+

From a178b75b676735a1ec300cf82eba2b31a50cf8bc Mon Sep 17 00:00:00 2001
From: Guolin Ke
Date: Tue, 22 Nov 2016 14:39:03 +0800
Subject: [PATCH 03/60] change some c_api interfaces for better compatibility

---
 include/LightGBM/boosting.h |  1 +
 include/LightGBM/c_api.h    |  8 +++++---
 include/LightGBM/dataset.h  |  2 ++
 src/boosting/boosting.cpp   |  2 +-
 src/c_api.cpp               | 14 ++++++++++----
 src/io/dataset.cpp          |  2 ++
 src/io/metadata.cpp         | 33 +++++++++++++++++++++++++++++++++
 tests/c_api_test/test.py    |  5 ++---
 8 files changed, 56 insertions(+), 11 deletions(-)

diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h
index 6c889569c..0faee90d8 100644
--- a/include/LightGBM/boosting.h
+++ b/include/LightGBM/boosting.h
@@ -151,6 +151,7 @@ public:
   /*! \brief Disable copy */
   Boosting(const Boosting&) = delete;
+  static void LoadFileToBoosting(Boosting* boosting, const char* filename);
   /*!
   * \brief Create boosting object
   * \param type Type of boosting
diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index 28c23cf0b..3027a33a7 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -165,7 +165,7 @@ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
 * \param field_name field name, can be label, weight, group
 * \param field_data pointer to vector
 * \param num_element number of element in field_data
-* \param type float_32:0, int32_t:1
+* \param type float32 or int32
 * \return 0 when success, -1 when failure happens
 */
 DllExport int LGBM_DatasetSetField(DatesetHandle handle,
@@ -180,7 +180,7 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle,
 * \param field_name field name
 * \param out_len used to set result length
 * \param out_ptr pointer to the result
-* \param out_type float_32:0, int32_t:1
+* \param out_type float32 or int32
 * \return 0 when success, -1 when failure happens
 */
 DllExport int LGBM_DatasetGetField(DatesetHandle handle,
@@ -216,6 +216,7 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
 * \param valid_names names of validation data sets
 * \param n_valid_datas number of validation set
 * \param parameters format: 'key1=value1 key2=value2'
+* \param init_model_filename filename of model
 * \param out handle of created Booster
 * \return 0 when success, -1 when failure happens
 */
 DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
     const DatesetHandle valid_datas[],
     const char* valid_names[],
     int n_valid_datas,
     const char* parameters,
+    const char* init_model_filename,
     BoosterHandle* out);
 
 /*!
 * \brief load an existing boosting from model file
 * \param filename filename of model
 * \param out handle of created Booster
 * \return 0 when success, -1 when failure happens
 */
-DllExport int LGBM_BoosterLoadFromModelfile(
+DllExport int LGBM_BoosterCreateFromModelfile(
     const char* filename,
     BoosterHandle* out);
 
diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h
index 6cbef3300..5ee25ff02 100644
--- a/include/LightGBM/dataset.h
+++ b/include/LightGBM/dataset.h
@@ -83,6 +83,8 @@ public:
   void SetQueryBoundaries(const data_size_t* query_boundaries, data_size_t len);
 
+  void SetQueryId(const data_size_t* query_id, data_size_t len);
+
   /*!
   * \brief Set initial scores
   * \param init_score Initial scores, this class will manage memory for init_score.
diff --git a/src/boosting/boosting.cpp b/src/boosting/boosting.cpp
index 337c22ddf..c3721f57f 100644
--- a/src/boosting/boosting.cpp
+++ b/src/boosting/boosting.cpp
@@ -15,7 +15,7 @@ BoostingType GetBoostingTypeFromModelFile(const char* filename) {
   return BoostingType::kUnknow;
 }
 
-void LoadFileToBoosting(Boosting* boosting, const char* filename) {
+void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
   if (boosting != nullptr) {
     TextReader<size_t> model_reader(filename, true);
     model_reader.ReadAllLines();
diff --git a/src/c_api.cpp b/src/c_api.cpp
index 2d7c4ff1f..317ed4d22 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -82,11 +82,12 @@ public:
         Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
     }
   }
-
+  void LoadModelFromFile(const char* filename) {
+    Boosting::LoadFileToBoosting(boosting_.get(), filename);
+  }
   ~Booster() {
 
   }
-
   bool TrainOneIter() {
     return boosting_->TrainOneIter(nullptr, nullptr, false);
   }
@@ -414,6 +415,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
     const char* valid_names[],
     int n_valid_datas,
     const char* parameters,
+    const char* init_model_filename,
     BoosterHandle* out) {
   API_BEGIN();
   const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
   std::vector<const Dataset*> p_valid_datas;
   for (int i = 0; i < n_valid_datas; ++i) {
     p_valid_datas.emplace_back(reinterpret_cast<const Dataset*>(valid_datas[i]));
     p_valid_names.emplace_back(valid_names[i]);
   }
-  *out = new Booster(p_train_data, p_valid_datas, p_valid_names, parameters);
+  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, p_valid_datas, p_valid_names, parameters));
+  if (init_model_filename != nullptr) {
+    ret->LoadModelFromFile(init_model_filename);
+  }
+  *out = ret.release();
   API_END();
 }
 
-DllExport int LGBM_BoosterLoadFromModelfile(
+DllExport int LGBM_BoosterCreateFromModelfile(
     const char* filename,
     BoosterHandle* out) {
   API_BEGIN();
diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp
index d61aaa774..efbad4e21 100644
--- a/src/io/dataset.cpp
+++ b/src/io/dataset.cpp
@@ -78,6 +78,8 @@ bool Dataset::SetIntField(const char* field_name, const int* field_data, data_si
   name = Common::Trim(name);
   if (name == std::string("query") || name == std::string("group")) {
     metadata_.SetQueryBoundaries(field_data, num_element);
+  } else if (name == std::string("query_id") || name == std::string("group_id")) {
+    metadata_.SetQueryId(field_data, num_element);
   } else {
     return false;
   }
diff --git a/src/io/metadata.cpp b/src/io/metadata.cpp
index 7203842b9..b61f3395f 100644
--- a/src/io/metadata.cpp
+++ b/src/io/metadata.cpp
@@ -248,6 +248,39 @@ void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size
   LoadQueryWeights();
 }
 
+void Metadata::SetQueryId(const data_size_t* query_id, data_size_t len) {
+  if (num_data_ != len) {
+    Log::Fatal("Length of query id is not the same as #data");
+  }
+  if (queries_.size() > 0) { queries_.clear(); }
+  queries_ = std::vector<data_size_t>(num_data_);
+  for (data_size_t i = 0; i < num_data_; ++i) {
+    queries_[i] = query_id[i];
+  }
+  // need to convert query_id to boundaries
+  std::vector<data_size_t> tmp_buffer;
+  data_size_t last_qid = -1;
+  data_size_t cur_cnt = 0;
+  for (data_size_t i = 0; i < num_data_; ++i) {
+    if (last_qid != queries_[i]) {
+      if (cur_cnt > 0) {
+        tmp_buffer.push_back(cur_cnt);
+      }
+      cur_cnt = 0;
+      last_qid = queries_[i];
+    }
+    ++cur_cnt;
+  }
+  tmp_buffer.push_back(cur_cnt);
+  query_boundaries_ = std::vector<data_size_t>(tmp_buffer.size() + 1);
+  num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
+  query_boundaries_[0] = 0;
+  for (size_t i = 0; i < tmp_buffer.size();
++i) { + query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i]; + } + queries_.clear(); + LoadQueryWeights(); +} void Metadata::LoadWeights() { num_weights_ = 0; diff --git a/tests/c_api_test/test.py b/tests/c_api_test/test.py index 1cd93e3ee..92c8d4999 100644 --- a/tests/c_api_test/test.py +++ b/tests/c_api_test/test.py @@ -178,7 +178,7 @@ def test_booster(): name = [c_str('test')] booster = ctypes.c_void_p() LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test), c_array(ctypes.c_char_p, name), - len(test), c_str("app=binary metric=auc num_leaves=31 verbose=0"), ctypes.byref(booster)) + len(test), c_str("app=binary metric=auc num_leaves=31 verbose=0"),None, ctypes.byref(booster)) is_finished = ctypes.c_int(0) for i in range(100): LIB.LGBM_BoosterUpdateOneIter(booster,ctypes.byref(is_finished)) @@ -191,7 +191,7 @@ def test_booster(): test_free_dataset(train) test_free_dataset(test[0]) booster2 = ctypes.c_void_p() - LIB.LGBM_BoosterLoadFromModelfile(c_str('model.txt'), ctypes.byref(booster2)) + LIB.LGBM_BoosterCreateFromModelfile(c_str('model.txt'), ctypes.byref(booster2)) data = [] inp = open('../../examples/binary_classification/binary.test', 'r') for line in inp.readlines(): @@ -214,4 +214,3 @@ def test_booster(): test_dataset() test_booster() - From de114be5c90185757212922adb6449c99ff8985d Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Tue, 22 Nov 2016 15:23:35 +0800 Subject: [PATCH 04/60] add nullptr check for get_field --- include/LightGBM/dataset.h | 34 ++++++++++++++++++++++++++++------ src/c_api.cpp | 2 +- src/io/dataset.cpp | 14 +++++++++++--- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index 5ee25ff02..53235c25a 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -143,8 +143,13 @@ public: * \brief Get weights, if not exists, will return nullptr * \return Pointer of weights */ - inline const float* weights() - const { return weights_.data(); } + inline const float* weights() const { + if (weights_.size() > 0) { + return weights_.data(); + } else { + return nullptr; + } + } /*! * \brief Get data boundaries on queries, if not exists, will return nullptr @@ -153,8 +158,13 @@ public: * is the data indices for query i. * \return Pointer of data boundaries on queries */ - inline const data_size_t* query_boundaries() - const { return query_boundaries_.data(); } + inline const data_size_t* query_boundaries() const { + if (query_boundaries_.size() > 0) { + return query_boundaries_.data(); + } else { + return nullptr; + } + } /*! * \brief Get Number of queries @@ -166,13 +176,25 @@ public: * \brief Get weights for queries, if not exists, will return nullptr * \return Pointer of weights for queries */ - inline const float* query_weights() const { return query_weights_.data(); } + inline const float* query_weights() const { + if (query_weights_.size() > 0) { + return query_weights_.data(); + } else { + return nullptr; + } + } /*! * \brief Get initial scores, if not exists, will return nullptr * \return Pointer of initial scores */ - inline const float* init_score() const { return init_score_.data(); } + inline const float* init_score() const { + if (init_score_.size() > 0) { + return init_score_.data(); + } else { + return nullptr; + } + } /*! 
\brief Disable copy */
  Metadata& operator=(const Metadata&) = delete;
diff --git a/src/c_api.cpp b/src/c_api.cpp
index 317ed4d22..2cb5c3511 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -387,7 +387,7 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle,
     *out_type = C_API_DTYPE_INT32;
     is_success = true;
   }
-  if (!is_success) { throw std::runtime_error("Field not found"); }
+  if (!is_success) { throw std::runtime_error("Field not found or does not exist"); }
   API_END();
 }
 
diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp
index efbad4e21..3024d1177 100644
--- a/src/io/dataset.cpp
+++ b/src/io/dataset.cpp
@@ -101,7 +101,11 @@ bool Dataset::GetFloatField(const char* field_name, int64_t* out_len, const floa
   } else {
     return false;
   }
-  return true;
+  if (*out_ptr != nullptr) {
+    return true;
+  } else {
+    return false;
+  }
 }
 
 bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int** out_ptr) {
   name = Common::Trim(name);
   if (name == std::string("query") || name == std::string("group")) {
     *out_ptr = metadata_.query_boundaries();
-    *out_len = num_data_;
+    *out_len = metadata_.num_queries();
+  } else {
+    return false;
+  }
+  if (*out_ptr != nullptr) {
+    return true;
   } else {
     return false;
   }
-  return true;
 }
 
 void Dataset::SaveBinaryFile(const char* bin_filename) {

From fa4ecfda239f084b04a554ff864ab73f755e14d5 Mon Sep 17 00:00:00 2001
From: Guolin Ke
Date: Tue, 22 Nov 2016 16:45:25 +0800
Subject: [PATCH 05/60] add constructor for booster

---
 python-package/lightgbm/basic.py | 129 ++++++++++++++++++++++++++-----
 1 file changed, 111 insertions(+), 18 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 0803a1207..cc72fd400 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -110,6 +110,13 @@ def c_array(ctype, values):
     """Convert a Python array to a C array."""
     return (ctype * len(values))(*values)
 
+def dict_to_str(data):
+    if len(data) == 0:
+        return ""
+    pairs = []
+    for key in data:
+        pairs.append(str(key) + '=' + str(data[key]))
+    return ' '.join(pairs)
 """macro definition of data type in c_api of LightGBM"""
 C_API_DTYPE_FLOAT32 =0
 C_API_DTYPE_FLOAT64 =1
 C_API_DTYPE_INT32 =2
@@ -164,7 +171,7 @@ class Dataset(object):
     def __init__(self, data, max_bin=255, reference=None,
                  label=None, weight=None, group_id=None,
                  silent=False, feature_names=None,
-                 other_args=None):
+                 other_params=None, is_continue_train=False):
         """
         Dataset used in LightGBM.
 
         Parameters
         ----------
         data : string/numpy array/scipy.sparse
             Data source of Dataset.
             When data is of string type, it represents the path of a text file.
         max_bin : int, required
             Max number of discrete bins for features.
         reference : Other Dataset, optional
             If this dataset is for validation, training data should be used as reference.
         label : list or numpy 1-D array, optional
             Label of the training data.
         weight : list or numpy 1-D array, optional
             Weight for each instance.
         group_id : list or numpy 1-D array, optional
             group/query id for each instance. Note: if having group/query id, data should be grouped by this id.
         silent : boolean, optional
             Whether to print messages during construction.
         feature_names : list, optional
             Set names for features.
- other_args: list, optional - other parameters, format: ['key1=val1','key2=val2'] + other_params: dict, optional + other parameters """ if data is None: self.handle = None return + """save raw data for continue train """ + if is_continue_train: + self.raw_data = data + else: + self.raw_data = None """process for args""" - pass_args = ["max_bin={}".format(max_bin)] + params = {} + params["max_bin"] = max_bin if silent: - pass_args.append("verbose=0") - if other_args: - pass_args += other_args - pass_args_str = ' '.join(pass_args) + params["verbose"] = 0 + if other_params: + other_params.update(params) + params = other_params + params_str = dict_to_str(params) """process for reference dataset""" ref_dataset = None if isinstance(reference, Dataset): @@ -212,15 +226,15 @@ class Dataset(object): self.handle = ctypes.c_void_p() _safe_call(_LIB.LGBM_CreateDatasetFromFile( c_str(data), - c_str(pass_args_str), + c_str(params_str), ref_dataset, ctypes.byref(self.handle))) elif isinstance(data, scipy.sparse.csr_matrix): - self._init_from_csr(data, pass_args_str, ref_dataset) + self._init_from_csr(data, params_str, ref_dataset) elif isinstance(data, scipy.sparse.csc_matrix): - self._init_from_csc(data, pass_args_str, ref_dataset) + self._init_from_csc(data, params_str, ref_dataset) elif isinstance(data, np.ndarray): - self._init_from_npy2d(data, pass_args_str, ref_dataset) + self._init_from_npy2d(data, params_str, ref_dataset) else: try: csr = scipy.sparse.csr_matrix(data) @@ -235,7 +249,10 @@ class Dataset(object): self.set_group_id(group_id) self.feature_names = feature_names - def _init_from_csr(self, csr, pass_args_str, ref_dataset): + def free_raw_data(self): + self.raw_data = None + + def _init_from_csr(self, csr, params_str, ref_dataset): """ Initialize data from a CSR matrix. """ @@ -255,11 +272,11 @@ class Dataset(object): len(csr.indptr), len(csr.data), csr.shape[1], - c_str(pass_args_str), + c_str(params_str), ref_dataset, ctypes.byref(self.handle))) - def _init_from_csc(self, csr, pass_args_str, ref_dataset): + def _init_from_csc(self, csr, params_str, ref_dataset): """ Initialize data from a CSC matrix. """ @@ -279,11 +296,11 @@ class Dataset(object): len(csc.indptr), len(csc.data), csc.shape[0], - c_str(pass_args_str), + c_str(params_str), ref_dataset, ctypes.byref(self.handle))) - def _init_from_npy2d(self, mat, pass_args_str, ref_dataset): + def _init_from_npy2d(self, mat, params_str, ref_dataset): """ Initialize data from a 2-D numpy matrix. """ @@ -304,7 +321,7 @@ class Dataset(object): mat.shape[0], mat.shape[1], C_API_IS_ROW_MAJOR, - c_str(pass_args_str), + c_str(params_str), ref_dataset, ctypes.byref(self.handle))) @@ -536,3 +553,79 @@ class Dataset(object): else: self._feature_names = None + +class Booster(object): + """"A Booster of of LightGBM. + """ + + feature_names = None + + def __init__(self, params=None, + train_set=None, + valid_sets=None, + name_valid_sets=None, + model_file=None, + fobj=None): + # pylint: disable=invalid-name + """Initialize the Booster. + + Parameters + ---------- + params : dict + Parameters for boosters. + train_set : Dataset + training dataset + valid_sets : List of Dataset or None + validation datasets + name_valid_sets : List of string + name of validation datasets + model_file : string + Path to the model file. 
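+
+        Example (illustrative only):
+            train_data = Dataset('train.txt')
+            bst = Booster(params={'num_leaves': 31}, train_set=train_data)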
+ """ + self.handle = ctypes.c_void_p() + if train_set is not None: + if not isinstance(train_set, Dataset): + raise TypeError('training data should be Dataset instance, met{}'.format(type(train_set).__name__)) + + valid_handles = None + valid_cnames = None + n_valid = 0 + if valid_sets is not None: + for valid in valid_sets: + if not isinstance(valid, Dataset): + raise TypeError('valid data should be Dataset instance, met{}'.format(type(valid).__name__)) + valid_handles = c_array(ctypes.c_void_p, [valid.handle for valid in valid_sets]) + if name_valid_sets is None: + name_valid_sets = ["valid_{}".format(x) for x in range(len(valid_sets)) ] + if len(valid_sets) != len(name_valid_sets): + raise Exception('len of valid_sets should be equal with len of name_valid_sets') + valid_cnames = c_array(ctypes.c_char_p, [c_str(x) for x in name_valid_sets]) + n_valid = len(valid_sets) + ref_input_model = None + params_str = dict_to_str(params) + if model_file is not None: + ref_input_model = c_str(model_file) + """construct booster object""" + _safe_call(LIB.LGBM_BoosterCreate( + train_set.handle, + valid_handles, + valid_cnames, + n_valid, + params_str, + ref_input_model, + ctypes.byref(self.handle))) + """if need to continue train""" + if model_file is not None: + self.init_continue_train(train_set) + if valid_sets is not None: + for valid in valid_sets: + self.init_continue_train(valid) + + elif model_file is not None: + _safe_call(_LIB.LGBM_BoosterCreateFromModelfile(c_str(model_file), ctypes.byref(self.handle))) + else: + raise TypeError('At least need training dataset or model file to create booster instance') + + def __del__(self): + _LIB.LGBM_BoosterFree(self.handle) + From 8639107f191a6b923a38ce16b5319f974badc12e Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Wed, 23 Nov 2016 10:55:46 +0800 Subject: [PATCH 06/60] Add Fatal when machine file format error --- src/network/linkers_socket.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/network/linkers_socket.cpp b/src/network/linkers_socket.cpp index b9ad584d1..22b2c18a3 100644 --- a/src/network/linkers_socket.cpp +++ b/src/network/linkers_socket.cpp @@ -28,10 +28,6 @@ Linkers::Linkers(NetworkConfig config) { // parser clients from file ParseMachineList(config.machine_list_filename.c_str()); - if (num_machines_ <= 1) { - return; - } - if (rank_ == -1) { // get ip list of local machine std::unordered_set local_ip_list = TcpSocket::GetLocalIpList(); @@ -101,10 +97,15 @@ void Linkers::ParseMachineList(const char * filename) { client_ips_.push_back(str_after_split[0]); client_ports_.push_back(atoi(str_after_split[1].c_str())); } + if (client_ips_.size() == 0) { + Log::Fatal("Machine list file doesn't contain any ip and port. 
\ + Please check it again"); + } if (client_ips_.size() != static_cast(num_machines_)) { Log::Warning("World size is larger than the machine_list size, change world size to %d", client_ips_.size()); num_machines_ = static_cast(client_ips_.size()); } + } void Linkers::TryBind(int port) { From fc383361b771828b7871bc6fbeca8970a4d891fe Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Wed, 23 Nov 2016 16:49:02 +0800 Subject: [PATCH 07/60] remove data name in metric --- include/LightGBM/c_api.h | 17 +++++++-- include/LightGBM/metric.h | 3 +- src/application/application.cpp | 8 ++-- src/boosting/gbdt.cpp | 4 +- src/c_api.cpp | 63 +++++++++++++++++++++++++------- src/metric/binary_metric.hpp | 13 ++----- src/metric/multiclass_metric.hpp | 7 ++-- src/metric/rank_metric.hpp | 7 +--- src/metric/regression_metric.hpp | 6 +-- tests/c_api_test/test.py | 5 +-- 10 files changed, 83 insertions(+), 50 deletions(-) diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 3027a33a7..d9bcebab5 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -222,7 +222,6 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle, */ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, const DatesetHandle valid_datas[], - const char* valid_names[], int n_valid_datas, const char* parameters, const char* init_model_filename, @@ -267,6 +266,18 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, const float* hess, int* is_finished); +/*! +* \brief Get number of eval +* \return total number of eval result +*/ +DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len); + +/*! +* \brief Get number of eval +* \return total number of eval result +*/ +DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, const char*** out_strs); + /*! 
* \brief get evaluation for training data and validation data * \param handle handle @@ -275,7 +286,7 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, * \param out_result the string containing evaluation statistics, should allocate memory before call this function * \return 0 when success, -1 when failure happens */ -DllExport int LGBM_BoosterEval(BoosterHandle handle, +DllExport int LGBM_BoosterGetEval(BoosterHandle handle, int data, int64_t* out_len, float* out_results); @@ -287,7 +298,7 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle, * \param out_result used to set a pointer to array * \return 0 when success, -1 when failure happens */ -DllExport int LGBM_BoosterGetScore(BoosterHandle handle, +DllExport int LGBM_BoosterGetTrainingScore(BoosterHandle handle, int64_t* out_len, const float** out_result); diff --git a/include/LightGBM/metric.h b/include/LightGBM/metric.h index c3ab30879..e54a2598b 100644 --- a/include/LightGBM/metric.h +++ b/include/LightGBM/metric.h @@ -24,8 +24,7 @@ public: * \param metadata Label data * \param num_data Number of data */ - virtual void Init(const char* test_name, - const Metadata& metadata, data_size_t num_data) = 0; + virtual void Init(const Metadata& metadata, data_size_t num_data) = 0; virtual const std::vector& GetName() const = 0; diff --git a/src/application/application.cpp b/src/application/application.cpp index 5e37d3743..ef53a7adb 100644 --- a/src/application/application.cpp +++ b/src/application/application.cpp @@ -139,8 +139,7 @@ void Application::LoadData() { for (auto metric_type : config_.metric_types) { auto metric = std::unique_ptr(Metric::CreateMetric(metric_type, config_.metric_config)); if (metric == nullptr) { continue; } - metric->Init("training", train_data_->metadata(), - train_data_->num_data()); + metric->Init(train_data_->metadata(), train_data_->num_data()); train_metric_.push_back(std::move(metric)); } } @@ -164,9 +163,8 @@ void Application::LoadData() { for (auto metric_type : config_.metric_types) { auto metric = std::unique_ptr(Metric::CreateMetric(metric_type, config_.metric_config)); if (metric == nullptr) { continue; } - metric->Init(config_.io_config.valid_data_filenames[i].c_str(), - valid_datas_.back()->metadata(), - valid_datas_.back()->num_data()); + metric->Init(valid_datas_.back()->metadata(), + valid_datas_.back()->num_data()); valid_metrics_.back().push_back(std::move(metric)); } valid_metrics_.back().shrink_to_fit(); diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index 52b64f005..d34e045cb 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -236,7 +236,7 @@ bool GBDT::OutputMetric(int iter) { auto name = sub_metric->GetName(); auto scores = sub_metric->Eval(train_score_updater_->score()); for (size_t k = 0; k < name.size(); ++k) { - Log::Info("Iteration: %d, %s : %f", iter, name[k].c_str(), scores[k]); + Log::Info("Iteration:%d, training %s : %f", iter, name[k].c_str(), scores[k]); } } } @@ -248,7 +248,7 @@ bool GBDT::OutputMetric(int iter) { if ((iter % gbdt_config_->output_freq) == 0) { auto name = valid_metrics_[i][j]->GetName(); for (size_t k = 0; k < name.size(); ++k) { - Log::Info("Iteration: %d, %s : %f", iter, name[k].c_str(), test_scores[k]); + Log::Info("Iteration:%d, valid_%d %s : %f", iter, i + 1, name[k].c_str(), test_scores[k]); } } if (!ret && early_stopping_round_ > 0) { diff --git a/src/c_api.cpp b/src/c_api.cpp index 2cb5c3511..a060c03fd 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -29,7 +29,6 @@ public: Booster(const Dataset* 
train_data, std::vector valid_data, - std::vector valid_names, const char* parameters) :train_data_(train_data), valid_datas_(valid_data) { config_.LoadFromString(parameters); @@ -50,8 +49,7 @@ public: auto metric = std::unique_ptr( Metric::CreateMetric(metric_type, config_.metric_config)); if (metric == nullptr) { continue; } - metric->Init("training", train_data_->metadata(), - train_data_->num_data()); + metric->Init(train_data_->metadata(), train_data_->num_data()); train_metric_.push_back(std::move(metric)); } train_metric_.shrink_to_fit(); @@ -61,9 +59,7 @@ public: for (auto metric_type : config_.metric_types) { auto metric = std::unique_ptr(Metric::CreateMetric(metric_type, config_.metric_config)); if (metric == nullptr) { continue; } - metric->Init(valid_names[i].c_str(), - valid_datas_[i]->metadata(), - valid_datas_[i]->num_data()); + metric->Init(valid_datas_[i]->metadata(), valid_datas_[i]->num_data()); valid_metrics_.back().push_back(std::move(metric)); } valid_metrics_.back().shrink_to_fit(); @@ -82,12 +78,15 @@ public: Common::ConstPtrInVectorWrapper(valid_metrics_[i])); } } + void LoadModelFromFile(const char* filename) { Boosting::LoadFileToBoosting(boosting_.get(), filename); } + ~Booster() { } + bool TrainOneIter() { return boosting_->TrainOneIter(nullptr, nullptr, false); } @@ -121,7 +120,25 @@ public: void SaveModelToFile(int num_used_model, const char* filename) { boosting_->SaveModelToFile(num_used_model, true, filename); } - + + int GetEvalCounts() const { + int ret = 0; + for (const auto& metric : train_metric_) { + ret += static_cast(metric->GetName().size()); + } + return ret; + } + + int GetEvalNames(const char*** out_strs) const { + int idx = 0; + for (const auto& metric : train_metric_) { + for (const auto& name : metric->GetName()) { + *(out_strs[idx++]) = name.c_str(); + } + } + return idx; + } + const Boosting* GetBoosting() const { return boosting_.get(); } const float* GetTrainingScore(int* out_len) const { return boosting_->GetTrainingScore(out_len); } @@ -412,7 +429,6 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle, DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, const DatesetHandle valid_datas[], - const char* valid_names[], int n_valid_datas, const char* parameters, const char* init_model_filename, @@ -420,12 +436,10 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, API_BEGIN(); const Dataset* p_train_data = reinterpret_cast(train_data); std::vector p_valid_datas; - std::vector p_valid_names; for (int i = 0; i < n_valid_datas; ++i) { p_valid_datas.emplace_back(reinterpret_cast(valid_datas[i])); - p_valid_names.emplace_back(valid_names[i]); } - auto ret = std::unique_ptr(new Booster(p_train_data, p_valid_datas, p_valid_names, parameters)); + auto ret = std::unique_ptr(new Booster(p_train_data, p_valid_datas, parameters)); if (init_model_filename != nullptr) { ret->LoadModelFromFile(init_model_filename); } @@ -472,7 +486,30 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, API_END(); } -DllExport int LGBM_BoosterEval(BoosterHandle handle, +/*! +* \brief Get number of eval +* \return total number of eval result +*/ +DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) { + API_BEGIN(); + Booster* ref_booster = reinterpret_cast(handle); + *out_len = ref_booster->GetEvalCounts(); + API_END(); +} + +/*! 
+* \brief Get number of eval +* \return total number of eval result +*/ +DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, const char*** out_strs) { + API_BEGIN(); + Booster* ref_booster = reinterpret_cast(handle); + *out_len = ref_booster->GetEvalNames(out_strs); + API_END(); +} + + +DllExport int LGBM_BoosterGetEval(BoosterHandle handle, int data, int64_t* out_len, float* out_results) { @@ -487,7 +524,7 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle, API_END(); } -DllExport int LGBM_BoosterGetScore(BoosterHandle handle, +DllExport int LGBM_BoosterGetTrainingScore(BoosterHandle handle, int64_t* out_len, const float** out_result) { API_BEGIN(); diff --git a/src/metric/binary_metric.hpp b/src/metric/binary_metric.hpp index 179caa93b..c1f2c2982 100644 --- a/src/metric/binary_metric.hpp +++ b/src/metric/binary_metric.hpp @@ -29,11 +29,8 @@ public: } - void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { - - std::stringstream str_buf; - str_buf << test_name << "'s : " << PointWiseLossCalculator::Name(); - name_.emplace_back(str_buf.str()); + void Init(const Metadata& metadata, data_size_t num_data) override { + name_.emplace_back(PointWiseLossCalculator::Name()); num_data_ = num_data; // get label @@ -162,10 +159,8 @@ public: return 1.0f; } - void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { - std::stringstream str_buf; - str_buf << test_name << "'s : AUC"; - name_.emplace_back(str_buf.str()); + void Init(const Metadata& metadata, data_size_t num_data) override { + name_.emplace_back("AUC"); num_data_ = num_data; // get label diff --git a/src/metric/multiclass_metric.hpp b/src/metric/multiclass_metric.hpp index 9681240d6..9b5c3c7b6 100644 --- a/src/metric/multiclass_metric.hpp +++ b/src/metric/multiclass_metric.hpp @@ -23,10 +23,9 @@ public: } - void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { - std::stringstream str_buf; - str_buf << test_name << " : " << PointWiseLossCalculator::Name(); - name_.emplace_back(str_buf.str()); + void Init(const Metadata& metadata, data_size_t num_data) override { + + name_.emplace_back(PointWiseLossCalculator::Name()); num_data_ = num_data; // get label label_ = metadata.label(); diff --git a/src/metric/rank_metric.hpp b/src/metric/rank_metric.hpp index 75fa472b8..bc5ae96c3 100644 --- a/src/metric/rank_metric.hpp +++ b/src/metric/rank_metric.hpp @@ -33,12 +33,9 @@ public: ~NDCGMetric() { } - void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { + void Init(const Metadata& metadata, data_size_t num_data) override { for (auto k : eval_at_) { - std::stringstream str_buf; - str_buf << test_name << "'s : "; - str_buf << "NDCG@" + std::to_string(k) + " "; - name_.emplace_back(str_buf.str()); + name_.emplace_back(std::string("NDCG@") + std::to_string(k)); } num_data_ = num_data; // get label diff --git a/src/metric/regression_metric.hpp b/src/metric/regression_metric.hpp index 1bce8ac91..7e7f21241 100644 --- a/src/metric/regression_metric.hpp +++ b/src/metric/regression_metric.hpp @@ -31,10 +31,8 @@ public: return -1.0f; } - void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override { - std::stringstream str_buf; - str_buf << test_name << " : " << PointWiseLossCalculator::Name(); - name_.emplace_back(str_buf.str()); + void Init(const Metadata& metadata, data_size_t num_data) override { + name_.emplace_back(PointWiseLossCalculator::Name()); 
num_data_ = num_data; // get label diff --git a/tests/c_api_test/test.py b/tests/c_api_test/test.py index 92c8d4999..b7db8b7fc 100644 --- a/tests/c_api_test/test.py +++ b/tests/c_api_test/test.py @@ -175,16 +175,15 @@ def test_dataset(): def test_booster(): train = test_load_from_mat('../../examples/binary_classification/binary.train', None) test = [test_load_from_mat('../../examples/binary_classification/binary.test', train)] - name = [c_str('test')] booster = ctypes.c_void_p() - LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test), c_array(ctypes.c_char_p, name), + LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test), len(test), c_str("app=binary metric=auc num_leaves=31 verbose=0"),None, ctypes.byref(booster)) is_finished = ctypes.c_int(0) for i in range(100): LIB.LGBM_BoosterUpdateOneIter(booster,ctypes.byref(is_finished)) result = np.array([0.0], dtype=np.float32) out_len = ctypes.c_ulong(0) - LIB.LGBM_BoosterEval(booster, 0, ctypes.byref(out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_float))) + LIB.LGBM_BoosterGetEval(booster, 0, ctypes.byref(out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_float))) print ('%d Iteration test AUC %f' %(i, result[0])) LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt')) LIB.LGBM_BoosterFree(booster) From 422c0ef728e1748d96c017f8de41105cc4a1df67 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Wed, 23 Nov 2016 22:04:46 +0800 Subject: [PATCH 08/60] almost finish, need some tests --- include/LightGBM/boosting.h | 6 +- include/LightGBM/c_api.h | 53 ++-- include/LightGBM/config.h | 2 +- python-package/lightgbm/basic.py | 499 ++++++++++++++++++++++++++----- src/application/application.cpp | 4 +- src/boosting/dart.hpp | 15 +- src/boosting/gbdt.cpp | 32 +- src/boosting/gbdt.h | 21 +- src/c_api.cpp | 111 ++++--- src/io/config.cpp | 2 +- tests/c_api_test/test.py | 9 +- 11 files changed, 564 insertions(+), 190 deletions(-) diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h index 0faee90d8..e325b789a 100644 --- a/include/LightGBM/boosting.h +++ b/include/LightGBM/boosting.h @@ -73,7 +73,7 @@ public: * \param result used to store prediction result, should allocate memory before call this function * \param out_len lenght of returned score */ - virtual void GetPredictAt(int data_idx, score_t* result, data_size_t* out_len) const = 0; + virtual void GetPredictAt(int data_idx, score_t* result, data_size_t* out_len) = 0; /*! * \brief Prediction for one record, not sigmoid transform @@ -127,7 +127,7 @@ public: * \brief Get number of weak sub-models * \return Number of weak sub-models */ - virtual int NumberOfSubModels() const = 0; + virtual int NumberOfTotalModel() const = 0; /*! * \brief Get number of classes @@ -138,7 +138,7 @@ public: /*! * \brief Set number of used model for prediction */ - virtual void SetNumUsedModel(int num_used_model) = 0; + virtual void SetNumIterationForPred(int num_iteration) = 0; /*! * \brief Get Type name of this boosting object diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index d9bcebab5..a3aeb90a6 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -230,11 +230,13 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, /*! 
* \brief load an existing boosting from model file * \param filename filename of model +* \param out_num_total_model number of total models * \param out handle of created Booster * \return 0 when success, -1 when failure happens */ DllExport int LGBM_BoosterCreateFromModelfile( const char* filename, + int64_t* out_num_total_model, BoosterHandle* out); /*! @@ -244,6 +246,12 @@ DllExport int LGBM_BoosterCreateFromModelfile( */ DllExport int LGBM_BoosterFree(BoosterHandle handle); +/*! +* \brief Get number of class +* \return number of class +*/ +DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len); + /*! * \brief update the model in one round * \param handle handle @@ -276,7 +284,7 @@ DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len); * \brief Get number of eval * \return total number of eval result */ -DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, const char*** out_strs); +DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs); /*! * \brief get evaluation for training data and validation data @@ -291,17 +299,6 @@ DllExport int LGBM_BoosterGetEval(BoosterHandle handle, int64_t* out_len, float* out_results); -/*! -* \brief get raw score for training data, used to calculate gradients outside -* \param handle handle -* \param out_len len of output result -* \param out_result used to set a pointer to array -* \return 0 when success, -1 when failure happens -*/ -DllExport int LGBM_BoosterGetTrainingScore(BoosterHandle handle, - int64_t* out_len, - const float** out_result); - /*! * \brief Get prediction for training data and validation data this can be used to support customized eval function @@ -319,21 +316,21 @@ DllExport int LGBM_BoosterGetPredict(BoosterHandle handle, /*! * \brief make prediction for file * \param handle handle +* \param data_filename filename of data file +* \param data_has_header data file has header or not * \param predict_type * 0:raw score * 1:with transform(if needed) * 2:leaf index -* \param n_used_trees number of used tree -* \param data_has_header data file has header or not -* \param data_filename filename of data file +* \param num_iteration number of iteration for prediction * \param result_filename filename of result file * \return 0 when success, -1 when failure happens */ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, - int predict_type, - int64_t n_used_trees, - int data_has_header, const char* data_filename, + int data_has_header, + int predict_type, + int64_t num_iteration, const char* result_filename); /*! @@ -351,7 +348,8 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, * 0:raw score * 1:with transform(if needed) * 2:leaf index -* \param n_used_trees number of used tree +* \param num_iteration number of iteration for prediction +* \param out_len len of output result * \param out_result used to set a pointer to array, should allocate memory before call this function * \return 0 when success, -1 when failure happens */ @@ -365,8 +363,9 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, int64_t nelem, int64_t num_col, int predict_type, - int64_t n_used_trees, - double* out_result); + int64_t num_iteration, + int64_t* out_len, + float* out_result); /*! 
* \brief make prediction for an new data set @@ -380,7 +379,8 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, * 0:raw score * 1:with transform(if needed) * 2:leaf index -* \param n_used_trees number of used tree +* \param num_iteration number of iteration for prediction +* \param out_len len of output result * \param out_result used to set a pointer to array, should allocate memory before call this function * \return 0 when success, -1 when failure happens */ @@ -391,18 +391,19 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, int32_t ncol, int is_row_major, int predict_type, - int64_t n_used_trees, - double* out_result); + int64_t num_iteration, + int64_t* out_len, + float* out_result); /*! * \brief save model into file * \param handle handle -* \param num_used_model +* \param num_iteration * \param filename file name * \return 0 when success, -1 when failure happens */ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle, - int num_used_model, + int num_iteration, const char* filename); diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index 0a5ddf3b7..ea968177f 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -97,7 +97,7 @@ public: std::string output_result = "LightGBM_predict_result.txt"; std::string input_model = ""; int verbosity = 1; - int num_model_predict = NO_LIMIT; + int num_iteration_predict = NO_LIMIT; bool is_pre_partition = false; bool is_enable_sparse = true; bool use_two_round_loading = false; diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index cc72fd400..1aef75fe0 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -6,6 +6,7 @@ import os import ctypes import collections import re +import tempfile import numpy as np import scipy.sparse @@ -111,7 +112,7 @@ def c_array(ctype, values): return (ctype * len(values))(*values) def dict_to_str(data): - if len(data) == 0: + if data is None or len(data) == 0: return "" pairs = [] for key in data: @@ -131,10 +132,10 @@ def c_float_array(data): data = np.array(data, copy=False) if is_numpy_1d_array(data): if data.dtype == np.float32: - ptr_data = c_array(ctypes.c_float, data) + ptr_data = data.ctypes.data_as(ctypes.c_float) type_data = C_API_DTYPE_FLOAT32 elif data.dtype == np.float64: - ptr_data = c_array(ctypes.c_double, data) + ptr_data = data.ctypes.data_as(ctypes.c_double) type_data = C_API_DTYPE_FLOAT64 else: raise TypeError("expected np.float32 or np.float64, met type({})".format(data.dtype)) @@ -148,10 +149,10 @@ def c_int_array(data): data = np.array(data, copy=False) if is_numpy_1d_array(data): if data.dtype == np.int32: - ptr_data = c_array(ctypes.c_int32, data) + ptr_data = data.ctypes.data_as(ctypes.c_int32) type_data = C_API_DTYPE_INT32 elif data.dtype == np.int64: - ptr_data = c_array(ctypes.c_int64, data) + ptr_data = data.ctypes.data_as(ctypes.c_int64) type_data = C_API_DTYPE_INT64 else: raise TypeError("expected np.int32 or np.int64, met type({})".format(data.dtype)) @@ -206,6 +207,7 @@ class Dataset(object): self.raw_data = data else: self.raw_data = None + self.data_has_header = False """process for args""" params = {} params["max_bin"] = max_bin @@ -223,6 +225,10 @@ class Dataset(object): raise TypeError('Reference dataset should be None or dataset instance') """start construct data""" if is_str(data): + """check data has header or not""" + if "has_header" in params or "header" in params: + if params["has_header"].lower() == "true" or params["header"].lower() == 
"true": + data_has_header = True self.handle = ctypes.c_void_p() _safe_call(_LIB.LGBM_CreateDatasetFromFile( c_str(data), @@ -230,17 +236,21 @@ class Dataset(object): ref_dataset, ctypes.byref(self.handle))) elif isinstance(data, scipy.sparse.csr_matrix): - self._init_from_csr(data, params_str, ref_dataset) - elif isinstance(data, scipy.sparse.csc_matrix): - self._init_from_csc(data, params_str, ref_dataset) + self.__init_from_csr(data, params_str, ref_dataset) elif isinstance(data, np.ndarray): - self._init_from_npy2d(data, params_str, ref_dataset) + self.__init_from_np2d(data, params_str, ref_dataset) else: try: csr = scipy.sparse.csr_matrix(data) - self._init_from_csr(csr) + if self.raw_data is not None: + self.raw_data = csr + self.__init_from_csr(csr) except: raise TypeError('can not initialize Dataset from {}'.format(type(data).__name__)) + self.__label = None + self.__weight = None + self.__init_score = None + self.__group = None if label is not None: self.set_label(label) if weight is not None: @@ -252,55 +262,7 @@ class Dataset(object): def free_raw_data(self): self.raw_data = None - def _init_from_csr(self, csr, params_str, ref_dataset): - """ - Initialize data from a CSR matrix. - """ - if len(csr.indices) != len(csr.data): - raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data))) - self.handle = ctypes.c_void_p() - - ptr_indptr, type_ptr_indptr = c_int_array(csr.indptr) - ptr_data, type_ptr_data = c_float_array(csr.data) - - _safe_call(_LIB.LGBM_CreateDatasetFromCSR( - ptr_indptr, - type_ptr_indptr, - c_array(ctypes.c_int32, csr.indices), - ptr_data, - type_ptr_data, - len(csr.indptr), - len(csr.data), - csr.shape[1], - c_str(params_str), - ref_dataset, - ctypes.byref(self.handle))) - - def _init_from_csc(self, csr, params_str, ref_dataset): - """ - Initialize data from a CSC matrix. - """ - if len(csc.indices) != len(csc.data): - raise ValueError('length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data))) - self.handle = ctypes.c_void_p() - - ptr_indptr, type_ptr_indptr = c_int_array(csc.indptr) - ptr_data, type_ptr_data = c_float_array(csc.data) - - _safe_call(_LIB.LGBM_CreateDatasetFromCSC( - ptr_indptr, - type_ptr_indptr, - c_array(ctypes.c_int32, csc.indices), - ptr_data, - type_ptr_data, - len(csc.indptr), - len(csc.data), - csc.shape[0], - c_str(params_str), - ref_dataset, - ctypes.byref(self.handle))) - - def _init_from_npy2d(self, mat, params_str, ref_dataset): + def __init_from_np2d(self, mat, params_str, ref_dataset): """ Initialize data from a 2-D numpy matrix. """ @@ -325,6 +287,30 @@ class Dataset(object): ref_dataset, ctypes.byref(self.handle))) + def __init_from_csr(self, csr, params_str, ref_dataset): + """ + Initialize data from a CSR matrix. 
+ """ + if len(csr.indices) != len(csr.data): + raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data))) + self.handle = ctypes.c_void_p() + + ptr_indptr, type_ptr_indptr = c_int_array(csr.indptr) + ptr_data, type_ptr_data = c_float_array(csr.data) + + _safe_call(_LIB.LGBM_CreateDatasetFromCSR( + ptr_indptr, + type_ptr_indptr, + csr.indices.ctypes.data_as(ctypes.c_int32), + ptr_data, + type_ptr_data, + len(csr.indptr), + len(csr.data), + csr.shape[1], + c_str(params_str), + ref_dataset, + ctypes.byref(self.handle))) + def __del__(self): _safe_call(_LIB.LGBM_DatasetFree(self.handle)) @@ -371,10 +357,10 @@ class Dataset(object): if not is_numpy_1d_array(data): raise TypeError("Unknow type({})".format(type(data).__name__)) if data.dtype == np.float32: - ptr_data = c_array(ctypes.c_float, data) + ptr_data = data.ctypes.data_as(ctypes.c_float) type_data = C_API_DTYPE_FLOAT32 elif data.dtype == np.int32: - ptr_data = c_array(ctypes.c_int32, data) + ptr_data = data.ctypes.data_as(ctypes.c_int32) type_data = C_API_DTYPE_INT32 else: raise TypeError("excepted np.float32 or np.int32, met type({})".format(data.dtype)) @@ -409,6 +395,7 @@ class Dataset(object): label = list_to_1d_numpy(label, np.float32) if label.dtype != np.float32: label = label.astype(np.float32, copy=False) + self.__label = label self.set_field('label', label) def set_weight(self, weight): @@ -422,6 +409,7 @@ class Dataset(object): weight = list_to_1d_numpy(weight, np.float32) if weight.dtype != np.float32: weight = weight.astype(np.float32, copy=False) + self.__weight = weight self.set_field('weight', weight) def set_init_score(self, score): @@ -434,6 +422,7 @@ class Dataset(object): score = list_to_1d_numpy(score, np.float32) if score.dtype != np.float32: score = score.astype(np.float32, copy=False) + self.__init_score = init_score self.set_field('init_score', score) def set_group(self, group): @@ -447,6 +436,7 @@ class Dataset(object): group = list_to_1d_numpy(group, np.int32) if group.dtype != np.int32: group = group.astype(np.int32, copy=False) + self.__group = group self.set_field('group', group) def set_group_id(self, group_id): @@ -470,7 +460,9 @@ class Dataset(object): ------- label : array """ - return self.get_field('label') + if self.__label is None: + self.__label = self.get_field('label') + return self.__label def get_weight(self): """Get the weight of the Dataset. @@ -479,7 +471,9 @@ class Dataset(object): ------- weight : array """ - return self.get_field('weight') + if self.__weight is None: + self.__weight = self.get_field('weight') + return self.__weight def get_init_score(self): """Get the initial score of the Dataset. @@ -488,7 +482,20 @@ class Dataset(object): ------- init_score : array """ - return self.get_field('init_score') + if self.__init_score is None: + self.__init_score = self.get_field('init_score') + return self.__init_score + + def get_group(self): + """Get the initial score of the Dataset. + + Returns + ------- + init_score : array + """ + if self.__group is None: + self.__group = self.get_field('group') + return self.__group def num_data(self): """Get the number of rows in the Dataset. @@ -553,6 +560,9 @@ class Dataset(object): else: self._feature_names = None +C_API_PREDICT_NORMAL =0 +C_API_PREDICT_RAW_SCORE =1 +C_API_PREDICT_LEAF_INDEX =2 class Booster(object): """"A Booster of of LightGBM. 
@@ -560,12 +570,9 @@ class Booster(object):
 
     feature_names = None
 
-    def __init__(self, params=None,
-                 train_set=None,
-                 valid_sets=None,
-                 name_valid_sets=None,
-                 model_file=None,
-                 fobj=None):
+    def __init__(self, params=None,
+                 train_set=None, valid_sets=None,
+                 name_valid_sets=None, model_file=None):
         # pylint: disable=invalid-name
         """Initialize the Booster.
 
@@ -580,15 +587,17 @@ class Booster(object):
         name_valid_sets : List of string
             name of validation datasets
         model_file : string
-            Path to the model file.
+            Path to the model file.
+            If train_set is not None, it is used for continued training;
+            otherwise it is used to load a model for prediction.
         """
         self.handle = ctypes.c_void_p()
         if train_set is not None:
+            """Training task"""
            if not isinstance(train_set, Dataset):
                raise TypeError('training data should be Dataset instance, met {}'.format(type(train_set).__name__))
            valid_handles = None
-            valid_cnames = None
            n_valid = 0
            if valid_sets is not None:
                for valid in valid_sets:
@@ -596,36 +605,364 @@ class Booster(object):
                        raise TypeError('valid data should be Dataset instance, met {}'.format(type(valid).__name__))
                valid_handles = c_array(ctypes.c_void_p, [valid.handle for valid in valid_sets])
                if name_valid_sets is None:
-                    name_valid_sets = ["valid_{}".format(x) for x in range(len(valid_sets))]
+                    name_valid_sets = ["valid_{}".format(x+1) for x in range(len(valid_sets))]
                if len(valid_sets) != len(name_valid_sets):
                    raise Exception('len of valid_sets should be equal to len of name_valid_sets')
-                valid_cnames = c_array(ctypes.c_char_p, [c_str(x) for x in name_valid_sets])
                n_valid = len(valid_sets)
            ref_input_model = None
            params_str = dict_to_str(params)
            if model_file is not None:
                ref_input_model = c_str(model_file)
            """construct booster object"""
-            _safe_call(LIB.LGBM_BoosterCreate(
+            _safe_call(_LIB.LGBM_BoosterCreate(
                train_set.handle,
                valid_handles,
-                valid_cnames,
                n_valid,
-                params_str,
+                c_str(params_str),
                ref_input_model,
                ctypes.byref(self.handle)))
            """if need to continue train"""
            if model_file is not None:
-                self.init_continue_train(train_set)
+                self.__init_continue_train(train_set)
                if valid_sets is not None:
                    for valid in valid_sets:
-                        self.init_continue_train(valid)
+                        self.__init_continue_train(valid)
+            """save reference to data"""
+            self.train_set = train_set
+            self.valid_sets = valid_sets
+            self.name_valid_sets = name_valid_sets
+            self.__num_dataset = 1 + n_valid
+            self.__training_score = None
+            out_len = ctypes.c_int64(0)
+            _safe_call(_LIB.LGBM_BoosterGetNumClasses(
+                self.handle,
+                ctypes.byref(out_len)))
+            self.__num_class = out_len.value
+            """buffer for inner predict"""
+            self.__inner_predict_buffer = [None for _ in range(self.__num_dataset)]
+            """Get num of inner evals"""
+            _safe_call(_LIB.LGBM_BoosterGetEvalCounts(
+                self.handle,
+                ctypes.byref(out_len)))
+            self.__num_inner_eval = out_len.value
+            if self.__num_inner_eval > 0:
+                """Get name of evals"""
+                string_buffers = [ctypes.create_string_buffer(255) for i in range(self.__num_inner_eval)]
+                ptr_string_buffers = (ctypes.c_char_p*self.__num_inner_eval)(*map(ctypes.addressof, string_buffers))
+                _safe_call(_LIB.LGBM_BoosterGetEvalNames(
+                    self.handle,
+                    ctypes.byref(out_len),
+                    ptr_string_buffers))
+                if self.__num_inner_eval != out_len.value:
+                    raise ValueError("size of eval names doesn't equal number of evals")
+                self.__name_inner_eval = []
+                for i in range(self.__num_inner_eval):
+                    self.__name_inner_eval.append(string_buffers[i].value.decode())
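Editor's note: the training branch above is enough for an end-to-end run. A hedged sketch (paths and parameter values are placeholders; update() and eval_valid() are defined further down in this patch, and the temporary test near the end of basic.py does the same thing):

train_data = Dataset('binary.train')
valid_data = Dataset('binary.test', reference=train_data)  # reuse train's bin mappers
bst = Booster(params={"metric": "auc"},
              train_set=train_data,
              valid_sets=[valid_data],
              name_valid_sets=["eval"])
for _ in range(10):
    if bst.update():
        break          # update() returns True once training is finished
    print(bst.eval_valid())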
         elif model_file is not None:
-            _safe_call(_LIB.LGBM_BoosterCreateFromModelfile(c_str(model_file), ctypes.byref(self.handle)))
+            """Prediction task"""
+            out_num_total_model = ctypes.c_int64(0)
+            _safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
+                c_str(model_file),
+                ctypes.byref(out_num_total_model),
+                ctypes.byref(self.handle)))
+            self.__num_total_model = out_num_total_model.value
+            out_len = ctypes.c_int64(0)
+            _safe_call(_LIB.LGBM_BoosterGetNumClasses(
+                self.handle,
+                ctypes.byref(out_len)))
+            self.__num_class = out_len.value
         else:
             raise TypeError('At least need training dataset or model file to create booster instance')
 
     def __del__(self):
-        _LIB.LGBM_BoosterFree(self.handle)
+        _safe_call(_LIB.LGBM_BoosterFree(self.handle))
 
+    def update(self, fobj=None):
+        """
+        Update for one iteration.
+        Note: for multi-class task, the score is grouped by class_id first, then by row_id.
+        To get the score of the i-th row for the j-th class, access score[j * num_data + i];
+        grad and hess should be grouped the same way.
+        Parameters
+        ----------
+        fobj : function
+            Customized objective function.
+
+        Returns
+        -------
+        is_finished, bool
+        """
+        is_finished = ctypes.c_int(0)
+        if fobj is None:
+            _safe_call(_LIB.LGBM_BoosterUpdateOneIter(
+                self.handle,
+                ctypes.byref(is_finished)))
+            return is_finished.value == 1
+        else:
+            grad, hess = fobj(self.__inner_predict(0), self.train_set)
+            return self.boost(grad, hess)
+
+    def boost(self, grad, hess):
+        """
+        Boost the booster for one iteration, with customized gradient statistics.
+        Note: for multi-class task, the score is grouped by class_id first, then by row_id.
+        To get the score of the i-th row for the j-th class, access score[j * num_data + i];
+        grad and hess should be grouped the same way.
+        Parameters
+        ----------
+        grad : 1d numpy array with dtype=np.float32
+            The first-order gradient.
+        hess : 1d numpy array with dtype=np.float32
+            The second-order gradient.
+
+        Returns
+        -------
+        is_finished, bool
+        """
+        if not is_numpy_1d_array(grad) or not is_numpy_1d_array(hess):
+            raise TypeError('type of grad / hess should be 1d numpy object')
+        if grad.dtype != np.float32 or hess.dtype != np.float32:
+            raise TypeError('type of grad / hess should be np.float32')
+        if len(grad) != len(hess):
+            raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess)))
+        is_finished = ctypes.c_int(0)
+        _safe_call(_LIB.LGBM_BoosterUpdateOneIterCustom(
+            self.handle,
+            grad.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
+            hess.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
+            ctypes.byref(is_finished)))
+        return is_finished.value == 1
+
+    def eval_train(self, feval=None):
+        """Evaluate for training data
+
+        Parameters
+        ----------
+        feval : function
+            Custom evaluation function.
+
+        Returns
+        -------
+        result: str
+            Evaluation result string.
+        """
+        return self.__inner_eval("training", 0, feval)
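Editor's note: boost() expects float32 first- and second-order gradients in the class-major grouping described above. A hedged sketch of a custom binary-logloss objective wired through update(); it assumes preds arrive as probabilities, since for the built-in binary objective GetPredictAt has already applied the sigmoid:

def logloss_obj(preds, train_set):
    labels = train_set.get_label()
    grad = (preds - labels).astype(np.float32)          # first-order gradient
    hess = (preds * (1.0 - preds)).astype(np.float32)   # second-order gradient
    return grad, hess

bst.update(fobj=logloss_obj)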
+ """ + ret = [] + for i in range(1, self.__num_dataset): + ret.append(self.__inner_eval(self.name_valid_sets[i-1], i, feval)) + return '\n'.join(ret) + + def save_model(self, filename, num_iteration=-1): + _safe_call(_LIB.LGBM_BoosterSaveModel( + self.handle, + num_iteration, + c_str(filename))) + + def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True): + if isinstance(data, Dataset): + raise TypeError("cannot use Dataset instance for prediction, please use raw data instead") + predict_type = C_API_PREDICT_NORMAL + if raw_score: + predict_type = cC_API_PREDICT_RAW_SCORE + if pred_leaf: + predict_type = C_API_PREDICT_LEAF_INDEX + int_data_has_header = 0 + if data_has_header: + int_data_has_header = 1 + if is_str(data): + tmp_pred_fname = tempfile.NamedTemporaryFile(prefix="lightgbm_tmp_pred_").name + _safe_call(_LIB.LGBM_BoosterPredictForFile( + self.handle, + c_str(data), + int_data_has_header, + predict_type, + num_iteration, + c_str(tmp_pred_fname))) + lines = open(tmp_pred_fname,"r").readlines() + nrow = len(lines) + preds = [] + for line in lines: + for token in line.split('\t'): + preds.append(float(token)) + preds = np.array(preds, copy=False) + os.remove(tmp_pred_fname) + elif isinstance(data, scipy.sparse.csr_matrix): + preds, nrow = self.__pred_for_csr(data, num_iteration, predict_type) + elif isinstance(data, np.ndarray): + preds, nrow = self.__pred_for_np2d(data, num_iteration, predict_type) + else: + try: + csr = scipy.sparse.csr_matrix(data) + res = self.__pred_for_csr(csr, num_iteration, predict_type) + except: + raise TypeError('can not predict data for type {}'.format(type(data).__name__)) + if pred_leaf: + preds = preds.astype(np.int32) + if preds.size != nrow and is_reshape: + if preds.size % nrow == 0: + ncol = int(preds.size / nrow) + preds = preds.reshape(nrow, ncol) + else: + raise ValueError('len of predict result(%d) cannot be divide nrow(%d)' %(preds.size, nrow) ) + return preds + + def __pred_for_np2d(self, mat, num_iteration, predict_type): + """ + Predict for a 2-D numpy matrix. 
+ """ + if len(mat.shape) != 2: + raise ValueError('Input numpy.ndarray must be 2 dimensional') + + if mat.dtype == np.float32 or mat.dtype == np.float64: + data = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False) + else: + """change non-float data to float data, need to copy""" + data = np.array(mat.reshape(mat.size), dtype=np.float32) + ptr_data, type_ptr_data = c_float_array(data) + n_preds = self.__num_class * mat.shape[0] + if predict_type == C_API_PREDICT_LEAF_INDEX: + if num_iteration > 0: + n_preds *= num_iteration + else: + used_iteration = self.__num_total_model / self.__num_class + n_preds *= used_iteration + preds = np.zeros(n_preds, dtype=np.float32) + out_num_preds = ctypes.c_int64(0) + _safe_call(LIB.LGBM_BoosterPredictForMat( + self.handle, + ptr_data, + type_ptr_data, + mat.shape[0], + mat.shape[1], + C_API_IS_ROW_MAJOR, + predict_type, + num_iteration, + ctypes.byref(out_num_preds), + preds.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) + )) + if n_preds != out_num_preds.value: + raise ValueError("incorrect number for predict result") + return preds, mat.shape[0] + + def __pred_for_csr(self, csr, num_iteration, predict_type): + """ + Predict for a csr data + """ + nrow = len(csr.indptr) - 1 + n_preds = self.__num_class * nrow + if predict_type == C_API_PREDICT_LEAF_INDEX: + if num_iteration > 0: + n_preds *= num_iteration + else: + used_iteration = self.__num_total_model / self.__num_class + n_preds *= used_iteration + preds = np.zeros(n_preds, dtype=np.float32) + out_num_preds = ctypes.c_int64(0) + + ptr_indptr, type_ptr_indptr = c_int_array(csr.indptr) + ptr_data, type_ptr_data = c_float_array(csr.data) + + _safe_call(LIB.LGBM_BoosterPredictForCSR( + self.handle, + ptr_indptr, + type_ptr_indptr, + csr.indices.ctypes.data_as(ctypes.c_int32), + ptr_data, + type_ptr_data, + len(csr.indptr), + len(csr.data), + csr.shape[1], + predict_type, + num_iteration, + ctypes.byref(out_num_preds), + preds.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) + )) + if n_preds != out_num_preds.value: + raise ValueError("incorrect number for predict result") + return preds, nrow + + def __inner_eval(self, data_name, data_idx, feval=None): + if data_idx >= self.__num_dataset: + raise ValueError("data_idx should be smaller than number of dataset") + ret = [] + if self.__num_inner_eval > 0: + result = np.array([0.0 for _ in range(self.__num_inner_eval)], dtype=np.float32) + out_len = ctypes.c_int64(0) + _safe_call(_LIB.LGBM_BoosterGetEval( + self.handle, + data_idx, + ctypes.byref(out_len), + result.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))) + if out_len.value != self.__num_inner_eval: + raise ValueError("incorrect number of eval results") + for i in range(self.__num_inner_eval): + ret.append('%s %s : %f' %(data_name, self.__name_inner_eval[i], result[i])) + if feval is not None: + if data_idx == 0: + cur_data = self.train_set + else: + cur_data = self.valid_sets[data_idx - 1] + feval_ret = feval(self.__inner_predict(data_idx), cur_data) + if isinstance(feval_ret, list): + for name, val in feval_ret: + ret.append('%s %s : %f' % (data_name, name, val)) + else: + name, val = feval_ret + ret.append('%s %s : %f' % (data_name, name, val)) + return '\t'.join(ret) + + def __inner_predict(self, data_idx): + if data_idx >= self.__num_dataset: + raise ValueError("data_idx should be smaller than number of dataset") + if self.__inner_predict_buffer[data_idx] is None: + if data_idx == 0: + num_data = self.train_set.num_data() * self.__num_class + else: + num_data = self.valid_sets[data_idx - 
+    def __inner_predict(self, data_idx):
+        if data_idx >= self.__num_dataset:
+            raise ValueError("data_idx should be smaller than number of dataset")
+        if self.__inner_predict_buffer[data_idx] is None:
+            if data_idx == 0:
+                num_data = self.train_set.num_data() * self.__num_class
+            else:
+                num_data = self.valid_sets[data_idx - 1].num_data() * self.__num_class
+            self.__inner_predict_buffer[data_idx] = \
+                np.zeros(num_data, dtype=np.float32)
+        out_len = ctypes.c_int64(0)
+        data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_float))
+        _safe_call(_LIB.LGBM_BoosterGetPredict(
+            self.handle,
+            data_idx,
+            ctypes.byref(out_len),
+            data_ptr))
+        if out_len.value != len(self.__inner_predict_buffer[data_idx]):
+            raise ValueError("incorrect number of predict results for data %d" % (data_idx))
+        return self.__inner_predict_buffer[data_idx]
+
+
+    def __init_continue_train(self, dataset):
+        if dataset.raw_data is None:
+            raise ValueError("raw data of the dataset is needed for continued training, but it has been freed")
+        init_score = self.predict(dataset.raw_data, raw_score=True, data_has_header=dataset.data_has_header, is_reshape=False)
+        dataset.set_init_score(init_score)
+        dataset.free_raw_data()
+
+
+#tmp test
+train_data = Dataset('../../examples/binary_classification/binary.train')
+test_data = Dataset('../../examples/binary_classification/binary.test', reference = train_data)
+param = {"metric":"l2,l1"}
+lgb = Booster(train_set=train_data, valid_sets=[test_data], params=param)
+for i in range(100):
+    lgb.update()
+    print(lgb.eval_valid())
+    print(lgb.eval_train())
+print(lgb.predict('../../examples/binary_classification/binary.train'))
\ No newline at end of file
diff --git a/src/application/application.cpp b/src/application/application.cpp
index ef53a7adb..3a00dc44f 100644
--- a/src/application/application.cpp
+++ b/src/application/application.cpp
@@ -108,7 +108,7 @@ void Application::LoadData() {
   // prediction is needed if using input initial model(continued train)
   PredictFunction predict_fun = nullptr;
   // need to continue training
-  if (boosting_->NumberOfSubModels() > 0) {
+  if (boosting_->NumberOfTotalModel() > 0) {
     Predictor predictor(boosting_.get(), true, false);
     predict_fun = predictor.GetPredictFunction();
   }
@@ -235,7 +235,7 @@ void Application::Train() {
 
 void Application::Predict() {
-  boosting_->SetNumUsedModel(config_.io_config.num_model_predict);
+  boosting_->SetNumIterationForPred(config_.io_config.num_iteration_predict);
   // create predictor
   Predictor predictor(boosting_.get(), config_.io_config.is_predict_raw_score,
                       config_.io_config.is_predict_leaf_index);
diff --git a/src/boosting/dart.hpp b/src/boosting/dart.hpp
index 4b2ef71c6..9df28dd6a 100644
--- a/src/boosting/dart.hpp
+++ b/src/boosting/dart.hpp
@@ -43,6 +43,7 @@ public:
   * \brief one training iteration
   */
   bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override {
+    is_update_score_cur_iter_ = false;
     GBDT::TrainOneIter(gradient, hessian, false);
     // normalize
     Normalize();
@@ -58,20 +59,24 @@ public:
   * \return training score
   */
   const score_t* GetTrainingScore(data_size_t* out_len) override {
-    DroppingTrees();
+    if (!is_update_score_cur_iter_) {
+      // only drop once per iteration
+      DroppingTrees();
+      is_update_score_cur_iter_ = true;
+    }
     *out_len = train_score_updater_->num_data() * num_class_;
     return train_score_updater_->score();
   }
 
  /*!
  * \brief save model to file
-  * \param num_used_model number of model that want to save, -1 means save all
+  * \param num_iteration number of iterations to save, -1 means save all
  * \param is_finish is training finished or not
  * \param filename filename that want to save to
  */
-  void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) override {
+  void SaveModelToFile(int num_iteration, bool is_finish, const char* filename) override {
    // only save model once when is_finish = true
    if (is_finish && saved_model_size_ < 0) {
-      GBDT::SaveModelToFile(num_used_model, is_finish, filename);
+      GBDT::SaveModelToFile(num_iteration, is_finish, filename);
    }
  }
  /*!
@@ -133,6 +138,8 @@ private:
  double drop_rate_;
  /*! \brief Random generator, used to select dropping trees */
  Random random_for_drop_;
+  /*! \brief Flag whether the score has been updated on the current iteration or not */
+  bool is_update_score_cur_iter_;
};

}  // namespace LightGBM
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index d34e045cb..2d7b5083c 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -16,7 +16,7 @@
 
 namespace LightGBM {
 
-GBDT::GBDT() : saved_model_size_(-1), num_used_model_(0) {
+GBDT::GBDT() : saved_model_size_(-1), num_iteration_for_pred_(0) {
 
 }
 
@@ -29,7 +29,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
   gbdt_config_ = config;
   iter_ = 0;
   saved_model_size_ = -1;
-  num_used_model_ = 0;
+  num_iteration_for_pred_ = 0;
   max_feature_idx_ = 0;
   early_stopping_round_ = gbdt_config_->early_stopping_round;
   shrinkage_rate_ = gbdt_config_->learning_rate;
@@ -296,24 +296,23 @@ const score_t* GBDT::GetTrainingScore(data_size_t* out_len) {
   return train_score_updater_->score();
 }
 
-void GBDT::GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) const {
+void GBDT::GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
   CHECK(data_idx >= 0 && data_idx <= static_cast(valid_metrics_.size()));
   std::vector ret;
   const score_t* raw_scores = nullptr;
   data_size_t num_data = 0;
   if (data_idx == 0) {
-    raw_scores = train_score_updater_->score();
+    raw_scores = GetTrainingScore(out_len);
     num_data = train_score_updater_->num_data();
   } else {
     auto used_idx = data_idx - 1;
     raw_scores = valid_score_updater_[used_idx]->score();
     num_data = valid_score_updater_[used_idx]->num_data();
+    *out_len = num_data * num_class_;
   }
-  *out_len = num_data * num_class_;
-
   if (num_class_ > 1) {
-#pragma omp parallel for schedule(guided)
+#pragma omp parallel for schedule(static)
     for (data_size_t i = 0; i < num_data; ++i) {
       std::vector tmp_result;
       for (int j = 0; j < num_class_; ++j) {
@@ -325,12 +324,12 @@ void GBDT::GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len)
       }
     }
   } else if(sigmoid_ > 0.0f){
-#pragma omp parallel for schedule(guided)
+#pragma omp parallel for schedule(static)
     for (data_size_t i = 0; i < num_data; ++i) {
       out_result[i] = static_cast(1.0f / (1.0f + std::exp(-2.0f * sigmoid_ * raw_scores[i])));
     }
   } else {
-#pragma omp parallel for schedule(guided)
+#pragma omp parallel for schedule(static)
     for (data_size_t i = 0; i < num_data; ++i) {
       out_result[i] = raw_scores[i];
     }
@@ -348,7 +347,7 @@ void GBDT::Boosting() {
   GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data());
 }
 
-void GBDT::SaveModelToFile(int num_used_model, bool is_finish, const char* filename) {
+void GBDT::SaveModelToFile(int num_iteration, bool is_finish, const char* filename) {
   // first time to this function, open file
   if (saved_model_size_ < 0) {
     model_output_file_.open(filename);
@@ -373,10 +372,11 @@ void GBDT::SaveModelToFile(int num_used_model, bool is_finish, const char* filen if (!model_output_file_.is_open()) { return; } - if (num_used_model == NO_LIMIT) { + int num_used_model = 0; + if (num_iteration == NO_LIMIT) { num_used_model = static_cast(models_.size()); } else { - num_used_model = num_used_model * num_class_; + num_used_model = num_iteration * num_class_; } int rest = num_used_model - early_stopping_round_ * num_class_; // output tree models @@ -452,7 +452,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) { } } Log::Info("Finished loading %d models", models_.size()); - num_used_model_ = static_cast(models_.size()) / num_class_; + num_iteration_for_pred_ = static_cast(models_.size()) / num_class_; } std::string GBDT::FeatureImportance() const { @@ -486,7 +486,7 @@ std::string GBDT::FeatureImportance() const { std::vector GBDT::PredictRaw(const double* value) const { std::vector ret(num_class_, 0.0f); - for (int i = 0; i < num_used_model_; ++i) { + for (int i = 0; i < num_iteration_for_pred_; ++i) { for (int j = 0; j < num_class_; ++j) { ret[j] += models_[i * num_class_ + j]->Predict(value); } @@ -496,7 +496,7 @@ std::vector GBDT::PredictRaw(const double* value) const { std::vector GBDT::Predict(const double* value) const { std::vector ret(num_class_, 0.0f); - for (int i = 0; i < num_used_model_; ++i) { + for (int i = 0; i < num_iteration_for_pred_; ++i) { for (int j = 0; j < num_class_; ++j) { ret[j] += models_[i * num_class_ + j]->Predict(value); } @@ -512,7 +512,7 @@ std::vector GBDT::Predict(const double* value) const { std::vector GBDT::PredictLeafIndex(const double* value) const { std::vector ret; - for (int i = 0; i < num_used_model_; ++i) { + for (int i = 0; i < num_iteration_for_pred_; ++i) { for (int j = 0; j < num_class_; ++j) { ret.push_back(models_[i * num_class_ + j]->PredictLeafIndex(value)); } diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index fbc7b6154..e7063c1b2 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -73,7 +73,7 @@ public: * \param result used to store prediction result, should allocate memory before call this function * \param out_len lenght of returned score */ - void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) const override; + void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) override; /*! * \brief Predtion for one record without sigmoid transformation @@ -98,11 +98,11 @@ public: /*! * \brief save model to file - * \param num_used_model number of model that want to save, -1 means save all + * \param num_iteration -1 means save all * \param is_finish is training finished or not * \param filename filename that want to save to */ - virtual void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) override; + virtual void SaveModelToFile(int num_iteration, bool is_finish, const char* filename) override; /*! * \brief Restore from a serialized string */ @@ -119,11 +119,12 @@ public: */ inline int LabelIdx() const override { return label_idx_; } + /*! * \brief Get number of weak sub-models * \return Number of weak sub-models */ - inline int NumberOfSubModels() const override { return static_cast(models_.size()); } + inline int NumberOfTotalModel() const override { return static_cast(models_.size()); } /*! * \brief Get number of classes @@ -132,11 +133,13 @@ public: inline int NumberOfClasses() const override { return num_class_; } /*! 
- * \brief Set number of used model for prediction + * \brief Set number of iterations for prediction */ - inline void SetNumUsedModel(int num_used_model) { - if (num_used_model >= 0) { - num_used_model_ = static_cast(num_used_model / num_class_); + inline void SetNumIterationForPred(int num_iteration) override { + if (num_iteration > 0) { + num_iteration_for_pred_ = num_iteration; + } else { + num_iteration_for_pred_ = static_cast(models_.size()) / num_class_; } } @@ -236,7 +239,7 @@ protected: /*! \brief File to write models */ std::ofstream model_output_file_; /*! \brief number of used model */ - int num_used_model_; + int num_iteration_for_pred_; /*! \brief Shrinkage rate for one iteration */ double shrinkage_rate_; }; diff --git a/src/c_api.cpp b/src/c_api.cpp index a060c03fd..f9f6dc500 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -95,8 +95,8 @@ public: return boosting_->TrainOneIter(gradients, hessians, false); } - void PrepareForPrediction(int num_used_model, int predict_type) { - boosting_->SetNumUsedModel(num_used_model); + void PrepareForPrediction(int num_iteration, int predict_type) { + boosting_->SetNumIterationForPred(num_iteration); bool is_predict_leaf = false; bool is_raw_score = false; if (predict_type == C_API_PREDICT_LEAF_INDEX) { @@ -109,6 +109,10 @@ public: predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf)); } + void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) { + boosting_->GetPredictAt(data_idx, out_result, out_len); + } + std::vector Predict(const std::vector>& features) { return predictor_->GetPredictFunction()(features); } @@ -117,8 +121,8 @@ public: predictor_->Predict(data_filename, result_filename, data_has_header); } - void SaveModelToFile(int num_used_model, const char* filename) { - boosting_->SaveModelToFile(num_used_model, true, filename); + void SaveModelToFile(int num_iteration, const char* filename) { + boosting_->SaveModelToFile(num_iteration, true, filename); } int GetEvalCounts() const { @@ -129,22 +133,25 @@ public: return ret; } - int GetEvalNames(const char*** out_strs) const { + int GetEvalNames(char** out_strs) const { int idx = 0; for (const auto& metric : train_metric_) { for (const auto& name : metric->GetName()) { - *(out_strs[idx++]) = name.c_str(); + int j = 0; + auto name_cstr = name.c_str(); + while (name_cstr[j] != '\0') { + out_strs[idx][j] = name_cstr[j]; + ++j; + } + out_strs[idx][j] = '\0'; + ++idx; } } return idx; } const Boosting* GetBoosting() const { return boosting_.get(); } - - const float* GetTrainingScore(int* out_len) const { return boosting_->GetTrainingScore(out_len); } - - const inline int NumberOfClasses() const { return boosting_->NumberOfClasses(); } - + private: std::unique_ptr boosting_; @@ -449,9 +456,12 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, DllExport int LGBM_BoosterCreateFromModelfile( const char* filename, + int64_t* num_total_model, BoosterHandle* out) { API_BEGIN(); - *out = new Booster(filename); + auto ret = std::unique_ptr(new Booster(filename)); + *num_total_model = static_cast(ret->GetBoosting()->NumberOfTotalModel()); + *out = ret.release(); API_END(); } @@ -461,6 +471,13 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) { API_END(); } +DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) { + API_BEGIN(); + Booster* ref_booster = reinterpret_cast(handle); + *out_len = ref_booster->GetBoosting()->NumberOfClasses(); + API_END(); +} + DllExport int 
LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); @@ -501,7 +518,7 @@ DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) * \brief Get number of eval * \return total number of eval result */ -DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, const char*** out_strs) { +DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); *out_len = ref_booster->GetEvalNames(out_strs); @@ -524,39 +541,27 @@ DllExport int LGBM_BoosterGetEval(BoosterHandle handle, API_END(); } -DllExport int LGBM_BoosterGetTrainingScore(BoosterHandle handle, - int64_t* out_len, - const float** out_result) { - API_BEGIN(); - Booster* ref_booster = reinterpret_cast(handle); - int len = 0; - *out_result = ref_booster->GetTrainingScore(&len); - *out_len = static_cast(len); - API_END(); -} - DllExport int LGBM_BoosterGetPredict(BoosterHandle handle, int data, int64_t* out_len, float* out_result) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); - auto boosting = ref_booster->GetBoosting(); int len = 0; - boosting->GetPredictAt(data, out_result, &len); + ref_booster->GetPredictAt(data, out_result, &len); *out_len = static_cast(len); API_END(); } DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, - int predict_type, - int64_t n_used_trees, - int data_has_header, const char* data_filename, + int data_has_header, + int predict_type, + int64_t num_iteration, const char* result_filename) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); - ref_booster->PrepareForPrediction(static_cast(n_used_trees), predict_type); + ref_booster->PrepareForPrediction(static_cast(num_iteration), predict_type); bool bool_data_has_header = data_has_header > 0 ? 
true : false; ref_booster->PredictForFile(data_filename, result_filename, bool_data_has_header); API_END(); @@ -572,23 +577,32 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, int64_t nelem, int64_t, int predict_type, - int64_t n_used_trees, - double* out_result) { + int64_t num_iteration, + int64_t* out_len, + float* out_result) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); - ref_booster->PrepareForPrediction(static_cast(n_used_trees), predict_type); + ref_booster->PrepareForPrediction(static_cast(num_iteration), predict_type); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); - int num_class = ref_booster->NumberOfClasses(); + int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses(); + if (predict_type == C_API_PREDICT_LEAF_INDEX) { + if (num_iteration > 0) { + num_preb_in_one_row *= static_cast(num_iteration); + } else { + num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row; + } + } int nrow = static_cast(nindptr - 1); #pragma omp parallel for schedule(guided) for (int i = 0; i < nrow; ++i) { auto one_row = get_row_fun(i); auto predicton_result = ref_booster->Predict(one_row); - for (int j = 0; j < num_class; ++j) { - out_result[i * num_class + j] = predicton_result[j]; + for (int j = 0; j < static_cast(predicton_result.size()); ++j) { + out_result[i * num_preb_in_one_row + j] = static_cast(predicton_result[j]); } } + *out_len = nrow * num_preb_in_one_row; API_END(); } @@ -599,31 +613,40 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, int32_t ncol, int is_row_major, int predict_type, - int64_t n_used_trees, - double* out_result) { + int64_t num_iteration, + int64_t* out_len, + float* out_result) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); - ref_booster->PrepareForPrediction(static_cast(n_used_trees), predict_type); + ref_booster->PrepareForPrediction(static_cast(num_iteration), predict_type); auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); - int num_class = ref_booster->NumberOfClasses(); + int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses(); + if (predict_type == C_API_PREDICT_LEAF_INDEX) { + if (num_iteration > 0) { + num_preb_in_one_row *= static_cast(num_iteration); + } else { + num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row; + } + } #pragma omp parallel for schedule(guided) for (int i = 0; i < nrow; ++i) { auto one_row = get_row_fun(i); auto predicton_result = ref_booster->Predict(one_row); - for (int j = 0; j < num_class; ++j) { - out_result[i * num_class + j] = predicton_result[j]; + for (int j = 0; j < static_cast(predicton_result.size()); ++j) { + out_result[i * num_preb_in_one_row + j] = static_cast(predicton_result[j]); } } + *out_len = nrow * num_preb_in_one_row; API_END(); } DllExport int LGBM_BoosterSaveModel(BoosterHandle handle, - int num_used_model, + int num_iteration, const char* filename) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); - ref_booster->SaveModelToFile(num_used_model, filename); + ref_booster->SaveModelToFile(num_iteration, filename); API_END(); } diff --git a/src/io/config.cpp b/src/io/config.cpp index 5d99be30a..ebe45eeb4 100644 --- a/src/io/config.cpp +++ b/src/io/config.cpp @@ -183,7 +183,7 @@ void IOConfig::Set(const std::unordered_map& params) { GetInt(params, "data_random_seed", &data_random_seed); GetString(params, "data", &data_filename); 
GetInt(params, "verbose", &verbosity); - GetInt(params, "num_model_predict", &num_model_predict); + GetInt(params, "num_iteration_predict", &num_iteration_predict); GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt); GetBool(params, "is_pre_partition", &is_pre_partition); GetBool(params, "is_enable_sparse", &is_enable_sparse); diff --git a/tests/c_api_test/test.py b/tests/c_api_test/test.py index b7db8b7fc..892ad3e6d 100644 --- a/tests/c_api_test/test.py +++ b/tests/c_api_test/test.py @@ -190,14 +190,16 @@ def test_booster(): test_free_dataset(train) test_free_dataset(test[0]) booster2 = ctypes.c_void_p() - LIB.LGBM_BoosterCreateFromModelfile(c_str('model.txt'), ctypes.byref(booster2)) + num_total_model = ctypes.c_long() + LIB.LGBM_BoosterCreateFromModelfile(c_str('model.txt'), ctypes.byref(num_total_model), ctypes.byref(booster2)) data = [] inp = open('../../examples/binary_classification/binary.test', 'r') for line in inp.readlines(): data.append( [float(x) for x in line.split('\t')[1:]] ) inp.close() mat = np.array(data) - preb = np.zeros(( mat.shape[0],1 ), dtype=np.float64) + preb = np.zeros(mat.shape[0], dtype=np.float32) + num_preb = ctypes.c_long() data = np.array(mat.reshape(mat.size), copy=False) LIB.LGBM_BoosterPredictForMat(booster2, data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)), @@ -207,8 +209,9 @@ def test_booster(): 1, 1, 50, + ctypes.byref(num_preb), preb.ctypes.data_as(ctypes.POINTER(ctypes.c_double))) - LIB.LGBM_BoosterPredictForFile(booster2, 1, 50, 0, c_str('../../examples/binary_classification/binary.test'), c_str('preb.txt')) + LIB.LGBM_BoosterPredictForFile(booster2,c_str('../../examples/binary_classification/binary.test'),0 , 0, 50, c_str('preb.txt')) LIB.LGBM_BoosterFree(booster2) test_dataset() From 133296828d3cc1a47d183c1ccb295a1e820851db Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Wed, 23 Nov 2016 23:28:09 +0800 Subject: [PATCH 09/60] support rollback iteration and reset config during training. --- include/LightGBM/boosting.h | 19 +++++++++++++++ include/LightGBM/c_api.h | 21 +++++++++++++++++ src/boosting/gbdt.cpp | 46 +++++++++++++++++++++++++++++++++++++ src/boosting/gbdt.h | 14 +++++++++++ src/c_api.cpp | 31 +++++++++++++++++++++++++ 5 files changed, 131 insertions(+) diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h index e325b789a..bcffcc5d6 100644 --- a/include/LightGBM/boosting.h +++ b/include/LightGBM/boosting.h @@ -35,6 +35,12 @@ public: const ObjectiveFunction* object_function, const std::vector& training_metrics) = 0; + /*! + * \brief Reset Config for current boosting + * \param config Configs for boosting + */ + virtual void ResetConfig(const BoostingConfig* config) = 0; + /*! * \brief Add a validation data * \param valid_data Validation data @@ -52,6 +58,19 @@ public: */ virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0; + /*! + * \brief Rollback one iteration + */ + virtual void RollbackOneIter() = 0; + + /*! + * \brief return current iteration + */ + virtual int GetCurrentIteration() const = 0; + + /*! + * \brief Eval metrics and check is met early stopping or not + */ virtual bool EvalAndCheckEarlyStopping() = 0; /*! * \brief Get evaluation result at data_idx data diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index a3aeb90a6..48316dc99 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -239,6 +239,7 @@ DllExport int LGBM_BoosterCreateFromModelfile( int64_t* out_num_total_model, BoosterHandle* out); + /*! 
* \brief free obj in handle * \param handle handle to be freed @@ -246,6 +247,13 @@ DllExport int LGBM_BoosterCreateFromModelfile( */ DllExport int LGBM_BoosterFree(BoosterHandle handle); +/*! +* \brief Reset config for current booster +* \param parameters format: 'key1=value1 key2=value2' +* \return 0 when success, -1 when failure happens +*/ +DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters); + /*! * \brief Get number of class * \return number of class @@ -274,6 +282,19 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, const float* hess, int* is_finished); +/*! +* \brief Rollback one iteration +* \param handle handle +* \return 0 when success, -1 when failure happens +*/ +DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle); + +/*! +* \brief Get iteration of current boosting rounds +* \return iteration of boosting rounds +*/ +DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration); + /*! * \brief Get number of eval * \return total number of eval result diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index 2d7b5083c..ebe251404 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -36,6 +36,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O train_data_ = train_data; num_class_ = config->num_class; // create tree learner + tree_learner_.clear(); for (int i = 0; i < num_class_; ++i) { auto new_tree_learner = std::unique_ptr(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config)); new_tree_learner->Init(train_data_); @@ -82,6 +83,32 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O } +void GBDT::ResetConfig(const BoostingConfig* config) { + gbdt_config_ = config; + early_stopping_round_ = gbdt_config_->early_stopping_round; + shrinkage_rate_ = gbdt_config_->learning_rate; + // create tree learner + tree_learner_.clear(); + for (int i = 0; i < num_class_; ++i) { + auto new_tree_learner = std::unique_ptr(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config)); + new_tree_learner->Init(train_data_); + // init tree learner + tree_learner_.push_back(std::move(new_tree_learner)); + } + tree_learner_.shrink_to_fit(); + // if need bagging, create buffer + if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) { + out_of_bag_data_indices_ = std::vector(num_data_); + bag_data_indices_ = std::vector(num_data_); + } else { + out_of_bag_data_cnt_ = 0; + out_of_bag_data_indices_.clear(); + bag_data_cnt_ = num_data_; + bag_data_indices_.clear(); + } + // initialize random generator + random_ = Random(gbdt_config_->bagging_seed); +} void GBDT::AddDataset(const Dataset* valid_data, const std::vector& valid_metrics) { if (iter_ > 0) { @@ -204,6 +231,25 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is } +void GBDT::RollbackOneIter() { + if (iter_ == 0) { return; } + int cur_iter = iter_ - 1; + // reset score + for (int curr_class = 0; curr_class < num_class_; ++curr_class) { + auto curr_tree = cur_iter * num_class_ + curr_class; + models_[curr_tree]->Shrinkage(-1.0); + train_score_updater_->AddScore(models_[curr_tree].get(), curr_class); + for (auto& score_updater : valid_score_updater_) { + score_updater->AddScore(models_[curr_tree].get(), curr_class); + } + } + // remove model + for (int curr_class = 0; curr_class < num_class_; ++curr_class) { + models_.pop_back(); + } + --iter_; +} + bool 
GBDT::EvalAndCheckEarlyStopping() { bool is_met_early_stopping = false; // print message for metric diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index e7063c1b2..57116cf18 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -35,6 +35,13 @@ public: void Init(const BoostingConfig* gbdt_config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector& training_metrics) override; + + /*! + * \brief Reset Config for current boosting + * \param config Configs for boosting + */ + void ResetConfig(const BoostingConfig* config) override; + /*! * \brief Adding a validation dataset * \param valid_data Validation dataset @@ -51,6 +58,13 @@ public: */ virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override; + /*! + * \brief Rollback one iteration + */ + void RollbackOneIter() override; + + int GetCurrentIteration() const override { return iter_; } + bool EvalAndCheckEarlyStopping() override; /*! diff --git a/src/c_api.cpp b/src/c_api.cpp index f9f6dc500..03df81d11 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -150,6 +150,17 @@ public: return idx; } + void ResetBoostingConfig(const char* parameters) { + OverallConfig new_config; + new_config.LoadFromString(parameters); + config_.boosting_config = new_config.boosting_config; + boosting_->ResetConfig(&config_.boosting_config); + } + + void RollbackOneIter() { + boosting_->RollbackOneIter(); + } + const Boosting* GetBoosting() const { return boosting_.get(); } private: @@ -471,6 +482,13 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) { API_END(); } +DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) { + API_BEGIN(); + Booster* ref_booster = reinterpret_cast(handle); + ref_booster->ResetBoostingConfig(parameters); + API_END(); +} + DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); @@ -503,6 +521,19 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, API_END(); } +DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) { + API_BEGIN(); + Booster* ref_booster = reinterpret_cast(handle); + ref_booster->RollbackOneIter(); + API_END(); +} + +DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) { + API_BEGIN(); + Booster* ref_booster = reinterpret_cast(handle); + *out_iteration = ref_booster->GetBoosting()->GetCurrentIteration(); + API_END(); +} /*! 
* \brief Get number of eval * \return total number of eval result From 14a67b7e8c2992322caea203c92c0b477adab4a2 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Thu, 24 Nov 2016 03:17:25 +0800 Subject: [PATCH 10/60] support dynamic change training data and add validation data --- include/LightGBM/bin.h | 12 ++++ include/LightGBM/boosting.h | 10 ++- include/LightGBM/c_api.h | 23 ++++-- include/LightGBM/dataset.h | 21 ++++++ include/LightGBM/feature.h | 7 ++ src/application/application.cpp | 2 +- src/boosting/gbdt.cpp | 124 +++++++++++++++++++------------- src/boosting/gbdt.h | 13 +++- src/c_api.cpp | 104 ++++++++++++++------------- tests/c_api_test/test.py | 8 +-- 10 files changed, 210 insertions(+), 114 deletions(-) diff --git a/include/LightGBM/bin.h b/include/LightGBM/bin.h index afb4b41f1..fd139d8ce 100644 --- a/include/LightGBM/bin.h +++ b/include/LightGBM/bin.h @@ -51,6 +51,18 @@ public: explicit BinMapper(const void* memory); ~BinMapper(); + bool CheckAlign(const BinMapper& other) const { + if (num_bin_ != other.num_bin_) { + return false; + } + for (int i = 0; i < num_bin_; ++i) { + if (bin_upper_bound_[i] != other.bin_upper_bound_[i]) { + return false; + } + } + return true; + } + /*! \brief Get number of bins */ inline int num_bin() const { return num_bin_; } /*! \brief True if bin is trival (contains only one bin) */ diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h index bcffcc5d6..419fabd2c 100644 --- a/include/LightGBM/boosting.h +++ b/include/LightGBM/boosting.h @@ -41,12 +41,20 @@ public: */ virtual void ResetConfig(const BoostingConfig* config) = 0; + /*! + * \brief Reset training data for current boosting + * \param train_data Training data + * \param object_function Training objective function + * \param training_metrics Training metric + */ + virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector& training_metrics) = 0; + /*! * \brief Add a validation data * \param valid_data Validation data * \param valid_metrics Metric for validation data */ - virtual void AddDataset(const Dataset* valid_data, + virtual void AddValidDataset(const Dataset* valid_data, const std::vector& valid_metrics) = 0; /*! diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 48316dc99..d93245c88 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -212,19 +212,12 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle, /*! * \brief create an new boosting learner * \param train_data training data set -* \param valid_datas validation data sets -* \param valid_names names of validation data sets -* \param n_valid_datas number of validation set * \param parameters format: 'key1=value1 key2=value2' -* \param init_model_filename filename of model * \prama out handle of created Booster * \return 0 when success, -1 when failure happens */ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, - const DatesetHandle valid_datas[], - int n_valid_datas, const char* parameters, - const char* init_model_filename, BoosterHandle* out); /*! @@ -247,6 +240,22 @@ DllExport int LGBM_BoosterCreateFromModelfile( */ DllExport int LGBM_BoosterFree(BoosterHandle handle); +/*! +* \brief Add new validation to booster +* \param valid_data validation data set +* \return 0 when success, -1 when failure happens +*/ +DllExport int LGBM_BoosterAddValidData(BoosterHandle handle, + const DatesetHandle valid_data); + +/*! 
+* \brief Reset training data for booster
+* \param train_data training data set
+* \return 0 when success, -1 when failure happens
+*/
+DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
+                                            const DatesetHandle train_data);
+
 /*!
 * \brief Reset config for current booster
 * \param parameters format: 'key1=value1 key2=value2'
diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h
index 53235c25a..959576e7c 100644
--- a/include/LightGBM/dataset.h
+++ b/include/LightGBM/dataset.h
@@ -277,6 +277,27 @@ public:
   /*! \brief Destructor */
   ~Dataset();
 
+  bool CheckAlign(const Dataset& other) const {
+    if (num_features_ != other.num_features_) {
+      return false;
+    }
+    if (num_total_features_ != other.num_total_features_) {
+      return false;
+    }
+    if (num_class_ != other.num_class_) {
+      return false;
+    }
+    if (label_idx_ != other.label_idx_) {
+      return false;
+    }
+    for (int i = 0; i < num_features_; ++i) {
+      if (!features_[i]->CheckAlign(*(other.features_[i].get()))) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   inline void PushOneRow(int tid, data_size_t row_idx, const std::vector& feature_values) {
     for (size_t i = 0; i < feature_values.size() && i < static_cast(num_total_features_); ++i) {
       int feature_idx = used_feature_map_[i];
diff --git a/include/LightGBM/feature.h b/include/LightGBM/feature.h
index 9ede59654..c3c8b8b28 100644
--- a/include/LightGBM/feature.h
+++ b/include/LightGBM/feature.h
@@ -63,6 +63,13 @@ public:
   ~Feature() {
   }
 
+  bool CheckAlign(const Feature& other) const {
+    if (feature_index_ != other.feature_index_) {
+      return false;
+    }
+    return bin_mapper_->CheckAlign(*(other.bin_mapper_.get()));
+  }
+
   /*!
   * \brief Push one record, will auto convert to bin and push to bin data
   * \param tid Thread id
diff --git a/src/application/application.cpp b/src/application/application.cpp
index 3a00dc44f..b922f9699 100644
--- a/src/application/application.cpp
+++ b/src/application/application.cpp
@@ -207,7 +207,7 @@ void Application::InitTrain() {
     Common::ConstPtrInVectorWrapper(train_metric_));
   // add validation data into boosting
   for (size_t i = 0; i < valid_datas_.size(); ++i) {
-    boosting_->AddDataset(valid_datas_[i].get(),
+    boosting_->AddValidDataset(valid_datas_[i].get(),
       Common::ConstPtrInVectorWrapper(valid_metrics_[i]));
   }
   Log::Info("Finished initializing training");
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index ebe251404..e93e980f5 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -16,7 +16,10 @@
 
 namespace LightGBM {
 
-GBDT::GBDT() : saved_model_size_(-1), num_iteration_for_pred_(0) {
+GBDT::GBDT()
+  :saved_model_size_(-1),
+  num_iteration_for_pred_(0),
+  num_init_iteration_(0) {
 
 }
 
@@ -33,51 +36,9 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
   max_feature_idx_ = 0;
   early_stopping_round_ = gbdt_config_->early_stopping_round;
   shrinkage_rate_ = gbdt_config_->learning_rate;
-  train_data_ = train_data;
   num_class_ = config->num_class;
-  // create tree learner
-  tree_learner_.clear();
-  for (int i = 0; i < num_class_; ++i) {
-    auto new_tree_learner = std::unique_ptr(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
-    new_tree_learner->Init(train_data_);
-    // init tree learner
-    tree_learner_.push_back(std::move(new_tree_learner));
-  }
-  tree_learner_.shrink_to_fit();
-  object_function_ = object_function;
-  // push training metrics
-  for (const auto& metric : training_metrics) {
-    training_metrics_.push_back(metric);
-  }
-  training_metrics_.shrink_to_fit();
-
// create score tracker
-  train_score_updater_.reset(new ScoreUpdater(train_data_, num_class_));
-  num_data_ = train_data_->num_data();
-  // create buffer for gradients and hessians
-  if (object_function_ != nullptr) {
-    gradients_ = std::vector<score_t>(num_data_ * num_class_);
-    hessians_ = std::vector<score_t>(num_data_ * num_class_);
-  }
-  sigmoid_ = -1.0f;
-  if (object_function_ != nullptr
-      && std::string(object_function_->GetName()) == std::string("binary")) {
-    // only binary classification need sigmoid transform
-    sigmoid_ = gbdt_config_->sigmoid;
-  }
-  // get max feature index
-  max_feature_idx_ = train_data_->num_total_features() - 1;
-  // get label index
-  label_idx_ = train_data_->label_idx();
-  // if need bagging, create buffer
-  if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
-    out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
-    bag_data_indices_ = std::vector<data_size_t>(num_data_);
-  } else {
-    out_of_bag_data_cnt_ = 0;
-    out_of_bag_data_indices_.clear();
-    bag_data_cnt_ = num_data_;
-    bag_data_indices_.clear();
-  }
+  train_data_ = nullptr;
+  ResetTrainingData(train_data, object_function, training_metrics);
   // initialize random generator
   random_ = Random(gbdt_config_->bagging_seed);
@@ -109,13 +70,79 @@ void GBDT::ResetConfig(const BoostingConfig* config) {
   // initialize random generator
   random_ = Random(gbdt_config_->bagging_seed);
 }
-void GBDT::AddDataset(const Dataset* valid_data,
+
+void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* object_function,
+                             const std::vector<const Metric*>& training_metrics) {
+  if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
+    Log::Fatal("cannot reset training data, since new training data has different bin mappers");
+  }
+  train_data_ = train_data;
+  // create tree learner
+  tree_learner_.clear();
+  for (int i = 0; i < num_class_; ++i) {
+    auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
+    new_tree_learner->Init(train_data_);
+    // init tree learner
+    tree_learner_.push_back(std::move(new_tree_learner));
+  }
+  tree_learner_.shrink_to_fit();
+  object_function_ = object_function;
+  // push training metrics
+  training_metrics_.clear();
+  for (const auto& metric : training_metrics) {
+    training_metrics_.push_back(metric);
+  }
+  training_metrics_.shrink_to_fit();
+  // create score tracker
+  train_score_updater_.reset(new ScoreUpdater(train_data_, num_class_));
+  num_data_ = train_data_->num_data();
+  // create buffer for gradients and hessians
+  if (object_function_ != nullptr) {
+    gradients_ = std::vector<score_t>(num_data_ * num_class_);
+    hessians_ = std::vector<score_t>(num_data_ * num_class_);
+  }
+  sigmoid_ = -1.0f;
+  if (object_function_ != nullptr
+      && std::string(object_function_->GetName()) == std::string("binary")) {
+    // only binary classification need sigmoid transform
+    sigmoid_ = gbdt_config_->sigmoid;
+  }
+  // get max feature index
+  max_feature_idx_ = train_data_->num_total_features() - 1;
+  // get label index
+  label_idx_ = train_data_->label_idx();
+  // if need bagging, create buffer
+  if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
+    out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
+    bag_data_indices_ = std::vector<data_size_t>(num_data_);
+  } else {
+    out_of_bag_data_cnt_ = 0;
+    out_of_bag_data_indices_.clear();
+    bag_data_cnt_ = num_data_;
+    bag_data_indices_.clear();
+  }
+  // update score
+  for (int i = 0; i < iter_; ++i) {
+    for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
+      auto curr_tree = i * num_class_ + curr_class;
+      train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
+    }
+  }
+}
+
+void GBDT::AddValidDataset(const Dataset* valid_data,
                            const std::vector<const Metric*>& valid_metrics) {
-  if (iter_ > 0) {
-    Log::Fatal("Cannot add validation data after training started");
+  if (!train_data_->CheckAlign(*valid_data)) {
+    Log::Fatal("cannot add validation data, since it has different bin mappers with training data");
   }
   // for a validation dataset, we need its score and metric
   auto new_score_updater = std::unique_ptr<ScoreUpdater>(new ScoreUpdater(valid_data, num_class_));
+  // update score
+  for (int i = 0; i < iter_; ++i) {
+    for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
+      auto curr_tree = i * num_class_ + curr_class;
+      new_score_updater->AddScore(models_[curr_tree].get(), curr_class);
+    }
+  }
   valid_score_updater_.push_back(std::move(new_score_updater));
   valid_metrics_.emplace_back();
   if (early_stopping_round_ > 0) {
@@ -499,6 +526,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
   }
   Log::Info("Finished loading %d models", models_.size());
   num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
+  num_init_iteration_ = num_iteration_for_pred_;
 }

 std::string GBDT::FeatureImportance() const {
diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h
index 57116cf18..7c1f456a9 100644
--- a/src/boosting/gbdt.h
+++ b/src/boosting/gbdt.h
@@ -42,12 +42,20 @@ public:
   */
   void ResetConfig(const BoostingConfig* config) override;

+  /*!
+  * \brief Reset training data for current boosting
+  * \param train_data Training data
+  * \param object_function Training objective function
+  * \param training_metrics Training metric
+  */
+  void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) override;
+
   /*!
   * \brief Adding a validation dataset
   * \param valid_data Validation dataset
   * \param valid_metrics Metrics for validation dataset
   */
-  void AddDataset(const Dataset* valid_data,
+  void AddValidDataset(const Dataset* valid_data,
                   const std::vector<const Metric*>& valid_metrics) override;
   /*!
   * \brief Training logic
@@ -63,7 +71,7 @@ public:
   */
   void RollbackOneIter() override;

-  int GetCurrentIteration() const override { return iter_; }
+  int GetCurrentIteration() const override { return iter_ + num_init_iteration_; }

   bool EvalAndCheckEarlyStopping() override;

@@ -256,6 +264,7 @@ protected:
   int num_iteration_for_pred_;
   /*! \brief Shrinkage rate for one iteration */
   double shrinkage_rate_;
+  int num_init_iteration_;
 };

 }  // namespace LightGBM
diff --git a/src/c_api.cpp b/src/c_api.cpp
index 03df81d11..c41d7ee2e 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -28,9 +28,7 @@ public:
   }

   Booster(const Dataset* train_data,
-    std::vector<const Dataset*> valid_data,
-    const char* parameters)
-    :train_data_(train_data), valid_datas_(valid_data) {
+    const char* parameters) {
     config_.LoadFromString(parameters);
     // create boosting
     if (config_.io_config.input_model.size() > 0) {
@@ -38,6 +36,17 @@ public:
         please use continued train with input score");
     }
     boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, ""));
+    ConstructObjectAndTrainingMetrics(train_data);
+    // initialize the boosting
+    boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
+      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
+  }
+
+  ~Booster() {
+
+  }
+
+  void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
     // create objective function
     objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
       config_.objective_config));
@@ -45,48 +54,39 @@ public:
       Log::Warning("Using self-defined objective functions");
     }
     // create training metric
+    train_metric_.clear();
     for (auto metric_type : config_.metric_types) {
       auto metric = std::unique_ptr<Metric>(
         Metric::CreateMetric(metric_type, config_.metric_config));
       if (metric == nullptr) { continue; }
-      metric->Init(train_data_->metadata(), train_data_->num_data());
+      metric->Init(train_data->metadata(), train_data->num_data());
       train_metric_.push_back(std::move(metric));
     }
     train_metric_.shrink_to_fit();
-    // add metric for validation data
-    for (size_t i = 0; i < valid_datas_.size(); ++i) {
-      valid_metrics_.emplace_back();
-      for (auto metric_type : config_.metric_types) {
-        auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
-        if (metric == nullptr) { continue; }
-        metric->Init(valid_datas_[i]->metadata(), valid_datas_[i]->num_data());
-        valid_metrics_.back().push_back(std::move(metric));
-      }
-      valid_metrics_.back().shrink_to_fit();
-    }
-    valid_metrics_.shrink_to_fit();
     // initialize the objective function
     if (objective_fun_ != nullptr) {
-      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
+      objective_fun_->Init(train_data->metadata(), train_data->num_data());
     }
+  }
+
+  void ResetTrainingData(const Dataset* train_data) {
+    ConstructObjectAndTrainingMetrics(train_data);
     // initialize the boosting
-    boosting_->Init(&config_.boosting_config, train_data_, objective_fun_.get(),
-      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
-    // add validation data into boosting
-    for (size_t i = 0; i < valid_datas_.size(); ++i) {
-      boosting_->AddDataset(valid_datas_[i],
-        Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
+    boosting_->ResetTrainingData(train_data, objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
+  }
+
+  void AddValidData(const Dataset* valid_data) {
+    valid_metrics_.emplace_back();
+    for (auto metric_type : config_.metric_types) {
+      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
+      if (metric == nullptr) { continue; }
+      metric->Init(valid_data->metadata(), valid_data->num_data());
+      valid_metrics_.back().push_back(std::move(metric));
     }
+    valid_metrics_.back().shrink_to_fit();
+    boosting_->AddValidDataset(valid_data,
+      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
   }
-
-  void LoadModelFromFile(const char* filename) {
-    Boosting::LoadFileToBoosting(boosting_.get(), filename);
-  }
-
-  ~Booster() {
-
-  }

   bool TrainOneIter() {
     return boosting_->TrainOneIter(nullptr, nullptr, false);
   }
@@ -151,9 +151,7 @@ public:
   }

   void ResetBoostingConfig(const char* parameters) {
-    OverallConfig new_config;
-    new_config.LoadFromString(parameters);
-    config_.boosting_config = new_config.boosting_config;
+    config_.LoadFromString(parameters);
     boosting_->ResetConfig(&config_.boosting_config);
   }
@@ -164,14 +162,9 @@ public:
   const Boosting* GetBoosting() const { return boosting_.get(); }

 private:
-
   std::unique_ptr<Boosting> boosting_;
   /*! \brief All configs */
   OverallConfig config_;
-  /*! \brief Training data */
-  const Dataset* train_data_;
-  /*! \brief Validation data */
-  std::vector<const Dataset*> valid_datas_;
   /*! \brief Metric for training data */
   std::vector<std::unique_ptr<Metric>> train_metric_;
   /*! \brief Metrics for validation data */
@@ -446,21 +439,11 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,

// ---- start of booster

DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
-  const DatesetHandle valid_datas[],
-  int n_valid_datas,
   const char* parameters,
-  const char* init_model_filename,
   BoosterHandle* out) {
   API_BEGIN();
   const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
-  std::vector<const Dataset*> p_valid_datas;
-  for (int i = 0; i < n_valid_datas; ++i) {
-    p_valid_datas.emplace_back(reinterpret_cast<const Dataset*>(valid_datas[i]));
-  }
-  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, p_valid_datas, parameters));
-  if (init_model_filename != nullptr) {
-    ret->LoadModelFromFile(init_model_filename);
-  }
+  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
   *out = ret.release();
   API_END();
}
@@ -482,6 +465,25 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) {
   API_END();
}

+DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
+  const DatesetHandle valid_data) {
+  API_BEGIN();
+  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
+  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
+  ref_booster->AddValidData(p_dataset);
+  API_END();
+}
+
+DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
+  const DatesetHandle train_data) {
+  API_BEGIN();
+  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
+  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
+  ref_booster->ResetTrainingData(p_dataset);
+  API_END();
+}
+
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
diff --git a/tests/c_api_test/test.py b/tests/c_api_test/test.py
index 892ad3e6d..b4690eaca 100644
--- a/tests/c_api_test/test.py
+++ b/tests/c_api_test/test.py
@@ -174,10 +174,10 @@ def test_dataset():
     test_free_dataset(train)

 def test_booster():
     train = test_load_from_mat('../../examples/binary_classification/binary.train', None)
-    test = [test_load_from_mat('../../examples/binary_classification/binary.test', train)]
+    test = test_load_from_mat('../../examples/binary_classification/binary.test', train)
     booster = ctypes.c_void_p()
-    LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test),
-        len(test), c_str("app=binary metric=auc num_leaves=31 verbose=0"),None, ctypes.byref(booster))
+    LIB.LGBM_BoosterCreate(train, c_str("app=binary metric=auc num_leaves=31 verbose=0"), ctypes.byref(booster))
+    LIB.LGBM_BoosterAddValidData(booster, test)
     is_finished = ctypes.c_int(0)
     for i in range(100):
         LIB.LGBM_BoosterUpdateOneIter(booster, ctypes.byref(is_finished))
@@ -188,7 +188,7 @@ def test_booster():
     LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt'))
     LIB.LGBM_BoosterFree(booster)
     test_free_dataset(train)
-    test_free_dataset(test[0])
+    test_free_dataset(test)
     booster2 = ctypes.c_void_p()
     num_total_model = ctypes.c_long()
     LIB.LGBM_BoosterCreateFromModelfile(c_str('model.txt'), ctypes.byref(num_total_model), ctypes.byref(booster2))
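The updated test drives LGBM_BoosterAddValidData but never exercises the new reset entry point; a minimal sketch in the same ctypes style (hypothetical continuation of test_booster(), placed before the LGBM_BoosterFree call above; train2 is a second dataset built with train as reference so its bin mappers align):

    # hypothetical: swap in a second, aligned training set and keep boosting
    train2 = test_load_from_mat('../../examples/binary_classification/binary.train', train)
    LIB.LGBM_BoosterResetTrainingData(booster, train2)  # fails if bin mappers differ (CheckAlign)
    for i in range(10):
        LIB.LGBM_BoosterUpdateOneIter(booster, ctypes.byref(is_finished))
    test_free_dataset(train2)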
From 4accb9d485caf428332a00127bda865be7c1b4f5 Mon Sep 17 00:00:00 2001
From: Guolin Ke
Date: Thu, 24 Nov 2016 03:35:09 +0800
Subject: [PATCH 11/60] support merging two boosters

---
 include/LightGBM/boosting.h |  6 ++++++
 include/LightGBM/c_api.h    | 15 ++++++++++++++-
 include/LightGBM/tree.h     |  4 ----
 src/boosting/gbdt.h         |  8 ++++++++
 src/c_api.cpp               | 12 ++++++++++++
 5 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h
index 419fabd2c..337291826 100644
--- a/include/LightGBM/boosting.h
+++ b/include/LightGBM/boosting.h
@@ -35,6 +35,11 @@ public:
     const ObjectiveFunction* object_function,
     const std::vector<const Metric*>& training_metrics) = 0;

+  /*!
+  * \brief Merge model from other boosting object
+  * \param other
+  */
+  virtual void MergeFrom(const Boosting* other) = 0;
   /*!
   * \brief Reset Config for current boosting
   * \param config Configs for boosting
@@ -179,6 +184,7 @@ public:
   Boosting(const Boosting&) = delete;

   static void LoadFileToBoosting(Boosting* boosting, const char* filename);
+
   /*!
   * \brief Create boosting object
   * \param type Type of boosting
diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index d93245c88..cff905a9b 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -240,8 +240,18 @@ DllExport int LGBM_BoosterCreateFromModelfile(
 */
DllExport int LGBM_BoosterFree(BoosterHandle handle);

+/*!
+* \brief Merge the models of two boosters into the first handle
+* \param handle Booster handle; the model of other_handle is merged into it
+* \param other_handle
+* \return 0 when success, -1 when failure happens
+*/
+DllExport int LGBM_BoosterMerge(BoosterHandle handle,
+  BoosterHandle other_handle);
+
 /*!
 * \brief Add new validation to booster
+* \param handle handle
 * \param valid_data validation data set
 * \return 0 when success, -1 when failure happens
 */
 DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
   const DatesetHandle valid_data);

 /*!
-* \brief Add new validation to booster
+* \brief Reset training data for booster
+* \param handle handle
 * \param train_data training data set
 * \return 0 when success, -1 when failure happens
 */
 DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
   const DatesetHandle train_data);

 /*!
 * \brief Reset config for current booster
+* \param handle handle
 * \param parameters format: 'key1=value1 key2=value2'
 * \return 0 when success, -1 when failure happens
 */
 DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters);

 /*!
 * \brief Get number of class
+* \param handle handle
 * \return number of class
 */
 DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len);
diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h
index 16320fbd0..712a907cd 100644
--- a/include/LightGBM/tree.h
+++ b/include/LightGBM/tree.h
@@ -101,10 +101,6 @@ public:
   /*! \brief Serialize this object by string*/
   std::string ToString();

-  /*! \brief Disable copy */
-  Tree& operator=(const Tree&) = delete;
-  /*! \brief Disable copy */
-  Tree(const Tree&) = delete;
 private:
   /*!
   * \brief Find leaf index of which record belongs by data
diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h
index 7c1f456a9..4aa8cff92 100644
--- a/src/boosting/gbdt.h
+++ b/src/boosting/gbdt.h
@@ -36,6 +36,14 @@ public:
     const std::vector<const Metric*>& training_metrics) override;

+  void MergeFrom(const Boosting* other) override {
+    auto other_gbdt = reinterpret_cast<const GBDT*>(other);
+    for (const auto& tree : other_gbdt->models_) {
+      auto new_tree = std::unique_ptr<Tree>(new Tree(*(tree.get())));
+      models_.push_back(std::move(new_tree));
+    }
+  }
+
   /*!
   * \brief Reset Config for current boosting
   * \param config Configs for boosting
diff --git a/src/c_api.cpp b/src/c_api.cpp
index c41d7ee2e..44d5e31d7 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -42,6 +42,10 @@ public:
       Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
   }

+  void MergeFrom(const Booster* other) {
+    boosting_->MergeFrom(other->boosting_.get());
+  }
+
   ~Booster() {

   }
@@ -465,6 +469,14 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) {
   API_END();
}

+DllExport int LGBM_BoosterMerge(BoosterHandle handle,
+  BoosterHandle other_handle) {
+  API_BEGIN();
+  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
+  Booster* ref_other_booster = reinterpret_cast<Booster*>(other_handle);
+  ref_booster->MergeFrom(ref_other_booster);
+  API_END();
+}

DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatesetHandle valid_data) {
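A rough companion sketch for the merge call, again in the style of tests/c_api_test/test.py (hypothetical handles; 'model.txt' is assumed to have been saved earlier, and at this point in the series MergeFrom appends the loaded trees to the live model):

    # hypothetical: graft the trees of a saved model onto a live booster
    booster2 = ctypes.c_void_p()
    num_total_model = ctypes.c_long()
    LIB.LGBM_BoosterCreateFromModelfile(c_str('model.txt'), ctypes.byref(num_total_model), ctypes.byref(booster2))
    LIB.LGBM_BoosterMerge(booster, booster2)
    LIB.LGBM_BoosterFree(booster2)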
From 5b4ee9db601e027af433a972bad6acfcd1c45306 Mon Sep 17 00:00:00 2001
From: Guolin Ke
Date: Thu, 24 Nov 2016 12:18:14 +0800
Subject: [PATCH 12/60] support set/get dataset field with nullptr

---
 src/c_api.cpp       |  3 ++-
 src/io/dataset.cpp  | 12 ++----------
 src/io/metadata.cpp | 28 ++++++++++++++++++++++++++++
 3 files changed, 32 insertions(+), 11 deletions(-)

diff --git a/src/c_api.cpp b/src/c_api.cpp
index 44d5e31d7..50af27e62 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -419,7 +419,8 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle,
     *out_type = C_API_DTYPE_INT32;
     is_success = true;
   }
-  if (!is_success) { throw std::runtime_error("Field not found or not exist"); }
+  if (!is_success) { throw std::runtime_error("Field not found"); }
+  if (*out_ptr == nullptr) { *out_len = 0; }
   API_END();
}

diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp
index 3024d1177..8391dbe8b 100644
--- a/src/io/dataset.cpp
+++ b/src/io/dataset.cpp
@@ -101,11 +101,7 @@ bool Dataset::GetFloatField(const char* field_name, int64_t* out_len, const float** out_ptr) {
   } else {
     return false;
   }
-  if (*out_ptr != nullptr) {
-    return true;
-  } else {
-    return false;
-  }
+  return true;
}

bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int** out_ptr) {
@@ -117,11 +113,7 @@
   } else {
     return false;
   }
-  if (*out_ptr != nullptr) {
-    return true;
-  } else {
-    return false;
-  }
+  return true;
}

void Dataset::SaveBinaryFile(const char* bin_filename) {
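At the Python layer this is what makes optional fields workable; a hedged sketch using the set_field/get_field wrappers that basic.py gains later in this series (names assumed from that later patch):

    # hypothetical: clear a previously set field, then read it back
    train_data.set_field('weight', None)           # forwards a null pointer down to LGBM_DatasetSetField
    assert train_data.get_field('weight') is None  # a zero-length field now reads back as None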
diff --git a/src/io/metadata.cpp b/src/io/metadata.cpp
index b61f3395f..9ceb5f26d 100644
--- a/src/io/metadata.cpp
+++ b/src/io/metadata.cpp
@@ -196,6 +196,12 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector
From: Guolin Ke
Date: Thu, 24 Nov 2016 12:58:46 +0800
Subject: [PATCH 13/60] more flexible reset config/training data logic for boosting

---
 include/LightGBM/boosting.h |  8 +---
 include/LightGBM/config.h   |  4 +-
 src/boosting/gbdt.cpp       | 42 ++++---------------
 src/boosting/gbdt.h         |  8 +---
 src/c_api.cpp               | 80 ++++++++++++++++++++++---------------
 src/io/config.cpp           |  6 +--
 6 files changed, 64 insertions(+), 84 deletions(-)

diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h
index 337291826..551085787 100644
--- a/include/LightGBM/boosting.h
+++ b/include/LightGBM/boosting.h
@@ -40,19 +40,15 @@ public:
   * \param other
   */
   virtual void MergeFrom(const Boosting* other) = 0;
-  /*!
-  * \brief Reset Config for current boosting
-  * \param config Configs for boosting
-  */
-  virtual void ResetConfig(const BoostingConfig* config) = 0;

   /*!
   * \brief Reset training data for current boosting
+  * \param config Configs for boosting
   * \param train_data Training data
   * \param object_function Training objective function
   * \param training_metrics Training metric
   */
-  virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;
+  virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;

   /*!
   * \brief Add a validation data
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index ea968177f..19e4b3364 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -72,6 +72,8 @@ public:
   inline bool GetBool(
     const std::unordered_map<std::string, std::string>& params,
     const std::string& name, bool* out);
+
+  static std::unordered_map<std::string, std::string> Str2Map(const char* parameters);
};

/*! \brief Types of boosting */
@@ -231,7 +233,7 @@ public:
   MetricConfig metric_config;

   void Set(const std::unordered_map<std::string, std::string>& params) override;
-  void LoadFromString(const char* str);
+
private:

   void GetBoostingType(const std::unordered_map<std::string, std::string>& params);
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index e93e980f5..1158e5394 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -29,52 +29,23 @@ GBDT::~GBDT() {

 void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
   const std::vector<const Metric*>& training_metrics) {
-  gbdt_config_ = config;
   iter_ = 0;
   saved_model_size_ = -1;
   num_iteration_for_pred_ = 0;
   max_feature_idx_ = 0;
-  early_stopping_round_ = gbdt_config_->early_stopping_round;
-  shrinkage_rate_ = gbdt_config_->learning_rate;
   num_class_ = config->num_class;
   train_data_ = nullptr;
-  ResetTrainingData(train_data, object_function, training_metrics);
-  // initialize random generator
-  random_ = Random(gbdt_config_->bagging_seed);
-
+  ResetTrainingData(config, train_data, object_function, training_metrics);
 }

-void GBDT::ResetConfig(const BoostingConfig* config) {
-  gbdt_config_ = config;
-  early_stopping_round_ = gbdt_config_->early_stopping_round;
-  shrinkage_rate_ = gbdt_config_->learning_rate;
-  // create tree learner
-  tree_learner_.clear();
-  for (int i = 0; i < num_class_; ++i) {
-    auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
-    new_tree_learner->Init(train_data_);
-    // init tree learner
-    tree_learner_.push_back(std::move(new_tree_learner));
-  }
-  tree_learner_.shrink_to_fit();
-  // if need bagging, create buffer
-  if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
-    out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
-    bag_data_indices_ = std::vector<data_size_t>(num_data_);
-  } else {
-    out_of_bag_data_cnt_ = 0;
-    out_of_bag_data_indices_.clear();
-    bag_data_cnt_ = num_data_;
-    bag_data_indices_.clear();
-  }
-  // initialize random generator
-  random_ = Random(gbdt_config_->bagging_seed);
-}
-
-void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
object_function, const std::vector& training_metrics) { +void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, + const std::vector& training_metrics) { if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) { Log::Fatal("cannot reset training data, since new training data has different bin mappers"); } + gbdt_config_ = config; + early_stopping_round_ = gbdt_config_->early_stopping_round; + shrinkage_rate_ = gbdt_config_->learning_rate; train_data_ = train_data; // create tree learner tree_learner_.clear(); @@ -120,6 +91,7 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* bag_data_cnt_ = num_data_; bag_data_indices_.clear(); } + random_ = Random(gbdt_config_->bagging_seed); // update score for (int i = 0; i < iter_; ++i) { for (int curr_class = 0; curr_class < num_class_; ++curr_class) { diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index 4aa8cff92..933ebc1da 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -44,19 +44,13 @@ public: } } - /*! - * \brief Reset Config for current boosting - * \param config Configs for boosting - */ - void ResetConfig(const BoostingConfig* config) override; - /*! * \brief Reset training data for current boosting * \param train_data Training data * \param object_function Training objective function * \param training_metrics Training metric */ - void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector& training_metrics) override; + void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector& training_metrics) override; /*! * \brief Adding a validation dataset diff --git a/src/c_api.cpp b/src/c_api.cpp index 50af27e62..8bf2d4b2e 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -29,7 +29,8 @@ public: Booster(const Dataset* train_data, const char* parameters) { - config_.LoadFromString(parameters); + auto param = ConfigBase::Str2Map(parameters); + config_.Set(param); // create boosting if (config_.io_config.input_model.size() > 0) { Log::Warning("continued train from model is not support for c_api, \ @@ -74,9 +75,23 @@ public: } void ResetTrainingData(const Dataset* train_data) { - ConstructObjectAndTrainingMetrics(train_data); + train_data_ = train_data; + ConstructObjectAndTrainingMetrics(train_data_); // initialize the boosting - boosting_->ResetTrainingData(train_data, objective_fun_.get(), Common::ConstPtrInVectorWrapper(train_metric_)); + boosting_->ResetTrainingData(&config_.boosting_config, train_data_, + objective_fun_.get(), Common::ConstPtrInVectorWrapper(train_metric_)); + } + + void ResetConfig(const char* parameters) { + auto param = ConfigBase::Str2Map(parameters); + if (param.count("num_class")) { + Log::Fatal("cannot change num class during training"); + } + if (param.count("boosting_type")) { + Log::Fatal("cannot change boosting_type during training"); + } + config_.Set(param); + ResetTrainingData(train_data_); } void AddValidData(const Dataset* valid_data) { @@ -154,10 +169,6 @@ public: return idx; } - void ResetBoostingConfig(const char* parameters) { - config_.LoadFromString(parameters); - boosting_->ResetConfig(&config_.boosting_config); - } void RollbackOneIter() { boosting_->RollbackOneIter(); @@ -166,6 +177,7 @@ public: const Boosting* GetBoosting() const { return boosting_.get(); } private: + const Dataset* train_data_; std::unique_ptr boosting_; /*! 
\brief All configs */ OverallConfig config_; @@ -193,9 +205,10 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename, const DatesetHandle* reference, DatesetHandle* out) { API_BEGIN(); - OverallConfig config; - config.LoadFromString(parameters); - DatasetLoader loader(config.io_config, nullptr); + auto param = ConfigBase::Str2Map(parameters); + IOConfig io_config; + io_config.Set(param); + DatasetLoader loader(io_config, nullptr); loader.SetHeader(filename); if (reference == nullptr) { *out = loader.LoadFromFile(filename); @@ -224,15 +237,16 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, const DatesetHandle* reference, DatesetHandle* out) { API_BEGIN(); - OverallConfig config; - config.LoadFromString(parameters); - DatasetLoader loader(config.io_config, nullptr); + auto param = ConfigBase::Str2Map(parameters); + IOConfig io_config; + io_config.Set(param); + DatasetLoader loader(io_config, nullptr); std::unique_ptr ret; auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); if (reference == nullptr) { // sample data first - Random rand(config.io_config.data_random_seed); - const int sample_cnt = static_cast(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt); + Random rand(io_config.data_random_seed); + const int sample_cnt = static_cast(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt); std::vector> sample_values(ncol); for (size_t i = 0; i < sample_indices.size(); ++i) { @@ -246,10 +260,10 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, } ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow)); } else { - ret.reset(new Dataset(nrow, config.io_config.num_class)); + ret.reset(new Dataset(nrow, io_config.num_class)); ret->CopyFeatureMapperFrom( reinterpret_cast(*reference), - config.io_config.is_enable_sparse); + io_config.is_enable_sparse); } #pragma omp parallel for schedule(guided) @@ -275,16 +289,17 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, const DatesetHandle* reference, DatesetHandle* out) { API_BEGIN(); - OverallConfig config; - config.LoadFromString(parameters); - DatasetLoader loader(config.io_config, nullptr); + auto param = ConfigBase::Str2Map(parameters); + IOConfig io_config; + io_config.Set(param); + DatasetLoader loader(io_config, nullptr); std::unique_ptr ret; auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); int32_t nrow = static_cast(nindptr - 1); if (reference == nullptr) { // sample data first - Random rand(config.io_config.data_random_seed); - const int sample_cnt = static_cast(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt); + Random rand(io_config.data_random_seed); + const int sample_cnt = static_cast(nrow < io_config.bin_construct_sample_cnt ? 
nrow : io_config.bin_construct_sample_cnt);
     auto sample_indices = rand.Sample(nrow, sample_cnt);
     std::vector<std::vector<double>> sample_values;
     for (size_t i = 0; i < sample_indices.size(); ++i) {
@@ -307,10 +322,10 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
     CHECK(num_col >= static_cast<int>(sample_values.size()));
     ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
   } else {
-    ret.reset(new Dataset(nrow, config.io_config.num_class));
+    ret.reset(new Dataset(nrow, io_config.num_class));
     ret->CopyFeatureMapperFrom(
       reinterpret_cast<const Dataset*>(*reference),
-      config.io_config.is_enable_sparse);
+      io_config.is_enable_sparse);
   }

#pragma omp parallel for schedule(guided)
@@ -336,17 +351,18 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
   const DatesetHandle* reference,
   DatesetHandle* out) {
   API_BEGIN();
-  OverallConfig config;
-  config.LoadFromString(parameters);
-  DatasetLoader loader(config.io_config, nullptr);
+  auto param = ConfigBase::Str2Map(parameters);
+  IOConfig io_config;
+  io_config.Set(param);
+  DatasetLoader loader(io_config, nullptr);
   std::unique_ptr<Dataset> ret;
   auto get_col_fun = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
   int32_t nrow = static_cast<int32_t>(num_row);
   if (reference == nullptr) {
     Log::Warning("Construct from CSC format is not efficient");
     // sample data first
-    Random rand(config.io_config.data_random_seed);
-    const int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
+    Random rand(io_config.data_random_seed);
+    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
     auto sample_indices = rand.Sample(nrow, sample_cnt);
     std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
#pragma omp parallel for schedule(guided)
@@ -356,10 +372,10 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
     }
     ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
   } else {
-    ret.reset(new Dataset(nrow, config.io_config.num_class));
+    ret.reset(new Dataset(nrow, io_config.num_class));
     ret->CopyFeatureMapperFrom(
       reinterpret_cast<const Dataset*>(*reference),
-      config.io_config.is_enable_sparse);
+      io_config.is_enable_sparse);
   }

#pragma omp parallel for schedule(guided)
@@ -500,7 +516,7 @@ DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,

DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
   API_BEGIN();
   Booster* ref_booster = reinterpret_cast<Booster*>(handle);
-  ref_booster->ResetBoostingConfig(parameters);
+  ref_booster->ResetConfig(parameters);
   API_END();
}

diff --git a/src/io/config.cpp b/src/io/config.cpp
index ebe45eeb4..114a4b004 100644
--- a/src/io/config.cpp
+++ b/src/io/config.cpp
@@ -10,9 +10,9 @@

namespace LightGBM {

-void OverallConfig::LoadFromString(const char* str) {
+std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* parameters) {
   std::unordered_map<std::string, std::string> params;
-  auto args = Common::Split(str, " \t\n\r");
+  auto args = Common::Split(parameters, " \t\n\r");
   for (auto arg : args) {
     std::vector<std::string> tmp_strs = Common::Split(arg.c_str(), '=');
     if (tmp_strs.size() == 2) {
@@ -27,7 +27,7 @@
     }
   }
   ParameterAlias::KeyAliasTransform(&params);
-  Set(params);
+  return params;
}

void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
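With parameters parsed once by Str2Map and re-applied through the booster's ResetConfig, settings can now be changed between iterations over the existing C API; a hedged ctypes sketch in the style of tests/c_api_test/test.py (hypothetical values; per the checks added in this patch, num_class and boosting_type cannot be changed this way):

    # hypothetical: lower the learning rate partway through training
    LIB.LGBM_BoosterResetParameter(booster, c_str("learning_rate=0.05"))
    LIB.LGBM_BoosterUpdateOneIter(booster, ctypes.byref(is_finished))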
From 629fc047e2559cab1508338b005009f1f85cb5fc Mon Sep 17 00:00:00 2001
From: Guolin Ke
Date: Thu, 24 Nov 2016 16:53:17 +0800
Subject: [PATCH 14/60] more flexible python basic object

---
 include/LightGBM/boosting.h      |   1 +
 python-package/lightgbm/basic.py | 737 ++++++++++++++++++-------------
 src/boosting/gbdt.cpp            |  67 +--
 src/boosting/gbdt.h              |  17 +
 src/c_api.cpp                    |  19 +-
 5 files changed, 483 insertions(+), 358 deletions(-)

diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h
index 551085787..31e0526b8 100644
--- a/include/LightGBM/boosting.h
+++ b/include/LightGBM/boosting.h
@@ -37,6 +37,7 @@ public:

   /*!
   * \brief Merge model from other boosting object
+  * Will insert at the front of the current boosting object
   * \param other
   */
   virtual void MergeFrom(const Boosting* other) = 0;
diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 1aef75fe0..b482ee296 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -126,16 +126,27 @@ C_API_DTYPE_INT64 =3
 """Matric is row major in python"""
 C_API_IS_ROW_MAJOR =1

+C_API_PREDICT_NORMAL =0
+C_API_PREDICT_RAW_SCORE =1
+C_API_PREDICT_LEAF_INDEX =2
+
+FIELD_TYPE_MAPPER = {"label":C_API_DTYPE_FLOAT32,
+"weight":C_API_DTYPE_FLOAT32,
+"init_score":C_API_DTYPE_FLOAT32,
+"group_id":C_API_DTYPE_INT32,
+"group":C_API_DTYPE_INT32,
+ }
+
 def c_float_array(data):
     """Convert numpy array / list to c float array."""
     if isinstance(data, list):
         data = np.array(data, copy=False)
     if is_numpy_1d_array(data):
         if data.dtype == np.float32:
-            ptr_data = data.ctypes.data_as(ctypes.c_float)
+            ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
             type_data = C_API_DTYPE_FLOAT32
         elif data.dtype == np.float64:
-            ptr_data = data.ctypes.data_as(ctypes.c_double)
+            ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
             type_data = C_API_DTYPE_FLOAT64
         else:
             raise TypeError("expected np.float32 or np.float64, met type({})".format(data.dtype))
@@ -149,10 +160,10 @@ def c_int_array(data):
         data = np.array(data, copy=False)
     if is_numpy_1d_array(data):
         if data.dtype == np.int32:
-            ptr_data = data.ctypes.data_as(ctypes.c_int32)
+            ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
             type_data = C_API_DTYPE_INT32
         elif data.dtype == np.int64:
-            ptr_data = data.ctypes.data_as(ctypes.c_int64)
+            ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int64))
             type_data = C_API_DTYPE_INT64
         else:
             raise TypeError("expected np.int32 or np.int64, met type({})".format(data.dtype))
@@ -160,19 +171,188 @@ def c_int_array(data):
         raise TypeError("Unknow type({})".format(type(data).__name__))
     return (ptr_data, type_data)

+class Predictor(object):
+    """A Predictor of LightGBM.
+    """
+    def __init__(self, model_file=None, params=None, booster_handle=None, is_manage_handle=True):
+        # pylint: disable=invalid-name
+        """Initialize the Predictor.
+
+        Parameters
+        ----------
+        model_file : string
+            Path to the model file.
+        params : dict
+            Parameters for boosters.
+        """
+        self.handle = ctypes.c_void_p()
+        self.__is_manage_handle = True
+        if model_file is not None:
+            """Prediction task"""
+            out_num_total_model = ctypes.c_int64(0)
+            _safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
+                c_str(model_file),
+                ctypes.byref(out_num_total_model),
+                ctypes.byref(self.handle)))
+            self.__num_total_model = out_num_total_model.value
+            tmp_out_len = ctypes.c_int64(0)
+            _safe_call(_LIB.LGBM_BoosterGetNumClasses(
+                self.handle,
+                ctypes.byref(tmp_out_len)))
+            self.num_class = tmp_out_len.value
+        elif booster_handle is not None:
+            self.__is_manage_handle = is_manage_handle
+            self.handle = booster_handle
+            tmp_out_len = ctypes.c_int64(0)
+            _safe_call(_LIB.LGBM_BoosterGetNumClasses(
+                self.handle,
+                ctypes.byref(tmp_out_len)))
+            self.num_class = tmp_out_len.value
+            _safe_call(_LIB.LGBM_BoosterGetCurrentIteration(
+                self.handle,
+                ctypes.byref(tmp_out_len)))
+            self.__num_total_model = self.num_class * tmp_out_len.value
+        else:
+            raise TypeError('Need model file or booster handle to create a predictor')
+
+    def __del__(self):
+        if self.__is_manage_handle:
+            _safe_call(_LIB.LGBM_BoosterFree(self.handle))
+
+    def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True):
+        if isinstance(data, Dataset):
+            raise TypeError("cannot use Dataset instance for prediction, please use raw data instead")
+        predict_type = C_API_PREDICT_NORMAL
+        if raw_score:
+            predict_type = C_API_PREDICT_RAW_SCORE
+        if pred_leaf:
+            predict_type = C_API_PREDICT_LEAF_INDEX
+        int_data_has_header = 0
+        if data_has_header:
+            int_data_has_header = 1
+        if is_str(data):
+            tmp_pred_fname = tempfile.NamedTemporaryFile(prefix="lightgbm_tmp_pred_").name
+            _safe_call(_LIB.LGBM_BoosterPredictForFile(
+                self.handle,
+                c_str(data),
+                int_data_has_header,
+                predict_type,
+                num_iteration,
+                c_str(tmp_pred_fname)))
+            lines = open(tmp_pred_fname, "r").readlines()
+            nrow = len(lines)
+            preds = []
+            for line in lines:
+                for token in line.split('\t'):
+                    preds.append(float(token))
+            preds = np.array(preds, copy=False)
+            os.remove(tmp_pred_fname)
+        elif isinstance(data, scipy.sparse.csr_matrix):
+            preds, nrow = self.__pred_for_csr(data, num_iteration, predict_type)
+        elif isinstance(data, np.ndarray):
+            preds, nrow = self.__pred_for_np2d(data, num_iteration, predict_type)
+        else:
+            try:
+                csr = scipy.sparse.csr_matrix(data)
+                preds, nrow = self.__pred_for_csr(csr, num_iteration, predict_type)
+            except:
+                raise TypeError('cannot predict data of type {}'.format(type(data).__name__))
+        if pred_leaf:
+            preds = preds.astype(np.int32)
+        if preds.size != nrow and is_reshape:
+            if preds.size % nrow == 0:
+                ncol = int(preds.size / nrow)
+                preds = preds.reshape(nrow, ncol)
+            else:
+                raise ValueError('length of predict result (%d) is not divisible by nrow (%d)' % (preds.size, nrow))
+        return preds

+    def __pred_for_np2d(self, mat, num_iteration, predict_type):
+        """
+        Predict for a 2-D numpy matrix.
+        """
+        if len(mat.shape) != 2:
+            raise ValueError('Input numpy.ndarray must be 2 dimensional')
+
+        if mat.dtype == np.float32 or mat.dtype == np.float64:
+            data = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False)
+        else:
+            """change non-float data to float data, need to copy"""
+            data = np.array(mat.reshape(mat.size), dtype=np.float32)
+        ptr_data, type_ptr_data = c_float_array(data)
+        n_preds = self.num_class * mat.shape[0]
+        if predict_type == C_API_PREDICT_LEAF_INDEX:
+            if num_iteration > 0:
+                n_preds *= num_iteration
+            else:
+                used_iteration = self.__num_total_model // self.num_class
+                n_preds *= used_iteration
+        preds = np.zeros(n_preds, dtype=np.float32)
+        out_num_preds = ctypes.c_int64(0)
+        _safe_call(_LIB.LGBM_BoosterPredictForMat(
+            self.handle,
+            ptr_data,
+            type_ptr_data,
+            mat.shape[0],
+            mat.shape[1],
+            C_API_IS_ROW_MAJOR,
+            predict_type,
+            num_iteration,
+            ctypes.byref(out_num_preds),
+            preds.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
+            ))
+        if n_preds != out_num_preds.value:
+            raise ValueError("incorrect number for predict result")
+        return preds, mat.shape[0]
+
+    def __pred_for_csr(self, csr, num_iteration, predict_type):
+        """
+        Predict for a csr data
+        """
+        nrow = len(csr.indptr) - 1
+        n_preds = self.num_class * nrow
+        if predict_type == C_API_PREDICT_LEAF_INDEX:
+            if num_iteration > 0:
+                n_preds *= num_iteration
+            else:
+                used_iteration = self.__num_total_model // self.num_class
+                n_preds *= used_iteration
+        preds = np.zeros(n_preds, dtype=np.float32)
+        out_num_preds = ctypes.c_int64(0)
+
+        ptr_indptr, type_ptr_indptr = c_int_array(csr.indptr)
+        ptr_data, type_ptr_data = c_float_array(csr.data)
+
+        _safe_call(_LIB.LGBM_BoosterPredictForCSR(
+            self.handle,
+            ptr_indptr,
+            type_ptr_indptr,
+            csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
+            ptr_data,
+            type_ptr_data,
+            len(csr.indptr),
+            len(csr.data),
+            csr.shape[1],
+            predict_type,
+            num_iteration,
+            ctypes.byref(out_num_preds),
+            preds.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
+            ))
+        if n_preds != out_num_preds.value:
+            raise ValueError("incorrect number for predict result")
+        return preds, nrow
+
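A quick usage sketch for the new class (hypothetical file names; predict carries the logic that previously lived on Booster):

    # hypothetical: standalone prediction from a saved model
    predictor = Predictor(model_file='model.txt')
    preds = predictor.predict('../../examples/binary_classification/binary.test')
    raw_preds = predictor.predict('../../examples/binary_classification/binary.test', raw_score=True)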
 class Dataset(object):
     """Dataset used in LightGBM.

     Dataset is a internal data structure that used by LightGBM
-    You can construct Dataset from numpy.arrays
     """
-    _feature_names = None
-
-    def __init__(self, data, max_bin=255, reference=None,
-                 label=None, weight=None, group_id=None,
-                 silent=False, feature_names=None,
-                 other_params=None, is_continue_train=False):
+    def __init__(self, data, label=None, max_bin=255, reference=None,
+                 weight=None, group_id=None, predictor=None,
+                 silent=False, params=None):
         """
         Dataset used in LightGBM.

@@ -181,41 +361,35 @@ class Dataset(object):
         data : string/numpy array/scipy.sparse
             Data source of Dataset.
             When data is string type, it represents the path of txt file,
+        label : list or numpy 1-D array, optional
+            Label of the data
         max_bin : int, required
             max number of discrete bin for features
         reference : Other Dataset, optional
             If this dataset validation, need to use training data as reference
-        label : list or numpy 1-D array, optional
-            Label of the training data.
         weight : list or numpy 1-D array , optional
             Weight for each instance.
         group_id : list or numpy 1-D array , optional
             group/query id for each instance. Note: if having group/query id, data should group by this id
         silent : boolean, optional
             Whether print messages during construction
-        feature_names : list, optional
-            Set names for features.
-        other_params: dict, optional
+        params: dict, optional
             other parameters
         """
         if data is None:
             self.handle = None
             return
-        """save raw data for continue train """
-        if is_continue_train:
-            self.raw_data = data
-        else:
-            self.raw_data = None
         self.data_has_header = False
         """process for args"""
-        params = {}
+        if params is None:
+            params = {}
+        self.max_bin = max_bin
+        self.predictor = predictor
         params["max_bin"] = max_bin
         if silent:
             params["verbose"] = 0
-        if other_params:
-            other_params.update(params)
-            params = other_params
+        else:
+            params["verbose"] = 1
         params_str = dict_to_str(params)
         """process for reference dataset"""
         ref_dataset = None
@@ -228,7 +402,7 @@ class Dataset(object):
             """check data has header or not"""
             if "has_header" in params or "header" in params:
                 if params["has_header"].lower() == "true" or params["header"].lower() == "true":
-                    data_has_header = True
+                    self.data_has_header = True
             self.handle = ctypes.c_void_p()
             _safe_call(_LIB.LGBM_CreateDatasetFromFile(
                 c_str(data),
@@ -242,8 +416,6 @@
         else:
             try:
                 csr = scipy.sparse.csr_matrix(data)
-                if self.raw_data is not None:
-                    self.raw_data = csr
                 self.__init_from_csr(csr)
             except:
                 raise TypeError('cannot initialize Dataset from {}'.format(type(data).__name__))
@@ -253,14 +425,52 @@
         self.__group = None
         if label is not None:
             self.set_label(label)
+        if self.get_label() is None:
+            raise ValueError("label should not be None")
         if weight is not None:
             self.set_weight(weight)
         if group_id is not None:
             self.set_group_id(group_id)
-        self.feature_names = feature_names
+        # load init score
+        if self.predictor is not None and isinstance(self.predictor, Predictor):
+            init_score = self.predictor.predict(data,
+                                                raw_score=True,
+                                                data_has_header=self.data_has_header,
+                                                is_reshape=False)
+            if self.predictor.num_class > 1:
+                # need to regroup init score
+                new_init_score = np.zeros(init_score.size, dtype=np.float32)
+                num_data = self.num_data()
+                for i in range(num_data):
+                    for j in range(self.predictor.num_class):
+                        new_init_score[j * num_data + i] = init_score[i * self.predictor.num_class + j]
+                init_score = new_init_score
+            self.set_init_score(init_score)

-    def free_raw_data(self):
-        self.raw_data = None
+    def new_valid_dataset(self, data, label=None, weight=None, group_id=None,
+                          silent=False, params=None):
+        """
+        Create validation data aligned with the current dataset
+
+        Parameters
+        ----------
+        data : string/numpy array/scipy.sparse
+            Data source of Dataset.
+            When data is string type, it represents the path of txt file,
+        label : list or numpy 1-D array, optional
+            Label of the training data.
+        weight : list or numpy 1-D array , optional
+            Weight for each instance.
+        group_id : list or numpy 1-D array , optional
+            group/query id for each instance.
Note: if having group/query id, data should group by this id + silent : boolean, optional + Whether print messages during construction + other_params: dict, optional + other parameters + """ + return Dataset(data, label=label, max_bin=self.max_bin, reference=self, + weight=weight, group_id=group_id, predictor=self.predictor, + silent=silent, params=params) def __init_from_np2d(self, mat, params_str, ref_dataset): """ @@ -301,7 +511,7 @@ class Dataset(object): _safe_call(_LIB.LGBM_CreateDatasetFromCSR( ptr_indptr, type_ptr_indptr, - csr.indices.ctypes.data_as(ctypes.c_int32), + csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), ptr_data, type_ptr_data, len(csr.indptr), @@ -327,19 +537,23 @@ class Dataset(object): info : array a numpy array of information of the data """ - out_len = ctypes.c_int32() + tmp_out_len = ctypes.c_int64() out_type = ctypes.c_int32() ret = ctypes.POINTER(ctypes.c_void_p)() _safe_call(_LIB.LGBM_DatasetGetField( self.handle, c_str(field_name), - ctypes.byref(out_len), + ctypes.byref(tmp_out_len), ctypes.byref(ret), ctypes.byref(out_type))) + if out_type.value != FIELD_TYPE_MAPPER[field_name]: + raise TypeError("Return type error for get_field") + if tmp_out_len.value == 0: + return None if out_type.value == C_API_DTYPE_INT32: - return cint32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(c_int32), out_len.value)) + return cint32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)), tmp_out_len.value) elif out_type.value == C_API_DTYPE_FLOAT32: - return cfloat32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(c_float), out_len.value)) + return cfloat32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)), tmp_out_len.value) else: raise TypeError("unknow type") @@ -351,19 +565,29 @@ class Dataset(object): field_name: str The field name of the information - data: numpy array or list + data: numpy array or list or None The array ofdata to be set """ + if data is None: + _safe_call(_LIB.LGBM_DatasetSetField( + self.handle, + c_str(field_name), + None, + 0, + FIELD_TYPE_MAPPER[field_name])) + return if not is_numpy_1d_array(data): raise TypeError("Unknow type({})".format(type(data).__name__)) if data.dtype == np.float32: - ptr_data = data.ctypes.data_as(ctypes.c_float) + ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) type_data = C_API_DTYPE_FLOAT32 elif data.dtype == np.int32: - ptr_data = data.ctypes.data_as(ctypes.c_int32) + ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)) type_data = C_API_DTYPE_INT32 else: raise TypeError("excepted np.float32 or np.int32, met type({})".format(data.dtype)) + if type_data != FIELD_TYPE_MAPPER[field_name]: + raise TypeError("type error for set_field") _safe_call(_LIB.LGBM_DatasetSetField( self.handle, c_str(field_name), @@ -406,9 +630,10 @@ class Dataset(object): weight : array like Weight for each data point """ - weight = list_to_1d_numpy(weight, np.float32) - if weight.dtype != np.float32: - weight = weight.astype(np.float32, copy=False) + if weight is not None: + weight = list_to_1d_numpy(weight, np.float32) + if weight.dtype != np.float32: + weight = weight.astype(np.float32, copy=False) self.__weight = weight self.set_field('weight', weight) @@ -419,10 +644,11 @@ class Dataset(object): score: array like """ - score = list_to_1d_numpy(score, np.float32) - if score.dtype != np.float32: - score = score.astype(np.float32, copy=False) - self.__init_score = init_score + if score is not None: + score = list_to_1d_numpy(score, np.float32) + if score.dtype != np.float32: + score = 
score.astype(np.float32, copy=False) + self.__init_score = score self.set_field('init_score', score) def set_group(self, group): @@ -433,9 +659,10 @@ class Dataset(object): group : array like Group size of each group """ - group = list_to_1d_numpy(group, np.int32) - if group.dtype != np.int32: - group = group.astype(np.int32, copy=False) + if group is not None: + group = list_to_1d_numpy(group, np.int32) + if group.dtype != np.int32: + group = group.astype(np.int32, copy=False) self.__group = group self.set_field('group', group) @@ -448,9 +675,10 @@ class Dataset(object): group : array like group_id of Dataset (used for ranking). """ - group_id = list_to_1d_numpy(group_id, np.int32) - if group_id.dtype != np.int32: - group_id = group_id.astype(np.int32, copy=False) + if group_id is not None: + group_id = list_to_1d_numpy(group_id, np.int32) + if group_id.dtype != np.int32: + group_id = group_id.astype(np.int32, copy=False) self.set_field('group_id', group_id) def get_label(self): @@ -462,6 +690,8 @@ class Dataset(object): """ if self.__label is None: self.__label = self.get_field('label') + if self.__label is None: + raise TypeError("label should not be None") return self.__label def get_weight(self): @@ -521,58 +751,11 @@ class Dataset(object): ctypes.byref(ret))) return ret.value - @property - def feature_names(self): - """Get feature names (column labels). - - Returns - ------- - feature_names : list - """ - if self._feature_names is None: - self._feature_names = ['Column_{0}'.format(i) for i in range(self.num_col())] - return self._feature_names - - @feature_names.setter - def feature_names(self, feature_names): - """Set feature names (column labels). - - Parameters - ---------- - feature_names : list - Labels for features - """ - if feature_names is not None: - # validate feature name - if not isinstance(feature_names, list): - feature_names = list(feature_names) - if len(feature_names) != len(set(feature_names)): - raise ValueError('feature_names must be unique') - if len(feature_names) != self.num_col(): - msg = 'feature_names must have the same length as data' - raise ValueError(msg) - # prohibit to use symbols may affect to parse. e.g. []< - if not all(isinstance(f, STRING_TYPES) and - not any(x in f for x in set(('[', ']', '<'))) - for f in feature_names): - raise ValueError('feature_names may not contain [, ] or <') - self._feature_names = feature_names - else: - self._feature_names = None - -C_API_PREDICT_NORMAL =0 -C_API_PREDICT_RAW_SCORE =1 -C_API_PREDICT_LEAF_INDEX =2 class Booster(object): """"A Booster of of LightGBM. """ - - feature_names = None - - def __init__(self,params=None, - train_set=None, valid_sets=None, - name_valid_sets=None, model_file=None): + def __init__(self, params=None, train_set=None, model_file=None, silent=False): # pylint: disable=invalid-name """Initialize the Booster. @@ -582,83 +765,46 @@ class Booster(object): Parameters for boosters. train_set : Dataset training dataset - valid_sets : List of Dataset or None - validation datasets - name_valid_sets : List of string - name of validation datasets model_file : string Path to the model file. - If tarin_set is not None, used for continued train. 
- else used for loading model prediction task """ self.handle = ctypes.c_void_p() + self.__need_reload_eval_info = True + self.__is_manage_handle = True + if params is None: + params = {} + if silent: + params["verbose"] = 0 + else: + params["verbose"] = 1 if train_set is not None: """Training task""" if not isinstance(train_set, Dataset): raise TypeError('training data should be Dataset instance, met{}'.format(type(train_set).__name__)) - - valid_handles = None - n_valid = 0 - if valid_sets is not None: - for valid in valid_sets: - if not isinstance(valid, Dataset): - raise TypeError('valid data should be Dataset instance, met{}'.format(type(valid).__name__)) - valid_handles = c_array(ctypes.c_void_p, [valid.handle for valid in valid_sets]) - if name_valid_sets is None: - name_valid_sets = ["valid_{}".format(x+1) for x in range(len(valid_sets)) ] - if len(valid_sets) != len(name_valid_sets): - raise Exception('len of valid_sets should be equal with len of name_valid_sets') - n_valid = len(valid_sets) - ref_input_model = None params_str = dict_to_str(params) - if model_file is not None: - ref_input_model = c_str(model_file) """construct booster object""" _safe_call(_LIB.LGBM_BoosterCreate( train_set.handle, - valid_handles, - n_valid, c_str(params_str), - ref_input_model, ctypes.byref(self.handle))) - """if need to continue train""" - if model_file is not None: - self.__init_continue_train(train_set) - if valid_sets is not None: - for valid in valid_sets: - self.__init_continue_train(valid) """save reference to data""" self.train_set = train_set - self.valid_sets = valid_sets - self.name_valid_sets = name_valid_sets - self.__num_dataset = 1 + n_valid - self.__training_score = None - out_len = ctypes.c_int64(0) + self.valid_sets = [] + self.name_valid_sets = [] + self.__num_dataset = 1 + self.init_predictor = train_set.predictor + if self.init_predictor is not None: + _safe_call(_LIB.LGBM_BoosterMerge( + self.handle, + self.init_predictor.handle)) + out_num_class = ctypes.c_int64(0) _safe_call(_LIB.LGBM_BoosterGetNumClasses( self.handle, - ctypes.byref(out_len))) - self.__num_class = out_len.value + ctypes.byref(out_num_class))) + self.__num_class = out_num_class.value """buffer for inner predict""" - self.__inner_predict_buffer = [None for _ in range(self.__num_dataset)] - """Get num of inner evals""" - _safe_call(_LIB.LGBM_BoosterGetEvalCounts( - self.handle, - ctypes.byref(out_len))) - self.__num_inner_eval = out_len.value - if self.__num_inner_eval > 0: - """Get name of evals""" - string_buffers = [ctypes.create_string_buffer(255) for i in range(self.__num_inner_eval)] - ptr_string_buffers = (ctypes.c_char_p*self.__num_inner_eval)(*map(ctypes.addressof, string_buffers)) - _safe_call(_LIB.LGBM_BoosterGetEvalNames( - self.handle, - ctypes.byref(out_len), - ptr_string_buffers)) - if self.__num_inner_eval != out_len.value: - raise ValueError("size of eval names doesn't equal with num_evals") - self.__name_inner_eval = [] - for i in range(self.__num_inner_eval): - self.__name_inner_eval.append(string_buffers[i].value.decode()) - + self.__inner_predict_buffer = [None] + self.__get_eval_info() elif model_file is not None: """Prediction task""" out_num_total_model = ctypes.c_int64(0) @@ -667,18 +813,40 @@ class Booster(object): ctypes.byref(out_num_total_model), ctypes.byref(self.handle))) self.__num_total_model = out_num_total_model.value - out_len = ctypes.c_int64(0) + out_num_class = ctypes.c_int64(0) _safe_call(_LIB.LGBM_BoosterGetNumClasses( self.handle, - ctypes.byref(out_len))) - 
self.__num_class = out_len.value
+            self.__num_class = out_num_class.value
         else:
             raise TypeError('At least need training dataset or model file to create booster instance')

     def __del__(self):
-        _safe_call(_LIB.LGBM_BoosterFree(self.handle))
+        if self.handle is not None and self.__is_manage_handle:
+            _safe_call(_LIB.LGBM_BoosterFree(self.handle))

-    def update(self, fobj=None):
+    def add_valid_data(self, data, name):
+        if data.predictor is not self.init_predictor:
+            raise Exception("Add validation data failed, you should use the same predictor for these data")
+        _safe_call(_LIB.LGBM_BoosterAddValidData(
+            self.handle,
+            data.handle))
+        self.valid_sets.append(data)
+        self.name_valid_sets.append(name)
+        self.__num_dataset += 1
+
+    def ResetParameter(self, params, silent=False):
+        self.__need_reload_eval_info = True
+        if silent:
+            params["verbose"] = 0
+        else:
+            params["verbose"] = 1
+        params_str = dict_to_str(params)
+        _safe_call(_LIB.LGBM_BoosterResetParameter(
+            self.handle,
+            c_str(params_str)))
+
+    def update(self, train_set=None, fobj=None):
         """
         Update for one iteration
         Note: for multi-class task, the score is grouped by class_id first, then by row_id
               and you should group grad and hess in this way as well
         Parameters
         ----------
+        train_set : training data, None means use last training data
         fobj : function
             Customized objective function.

         Returns
         -------
         is_finished, bool
         """
+        """need reset training data"""
+        if train_set is not None and train_set is not self.train_set:
+            if train_set.predictor is not self.init_predictor:
+                raise Exception("Replace training data failed, you should use the same predictor for these data")
+            self.train_set = train_set
+            _safe_call(_LIB.LGBM_BoosterResetTrainingData(
+                self.handle,
+                self.train_set.handle))
+            self.__inner_predict_buffer[0] = None
         is_finished = ctypes.c_int(0)
         if fobj is None:
             _safe_call(_LIB.LGBM_BoosterUpdateOneIter(
                 self.handle,
                 ctypes.byref(is_finished)))
             return is_finished.value == 1
         else:
             grad, hess = fobj(self.__inner_predict(0), self.train_set)
-            return self.boost(grad, hess)
+            return self.__boost(grad, hess)

-    def boost(self, grad, hess):
+    def __boost(self, grad, hess):
         """
         Boost the booster for one iteration, with customized gradient statistics.
         Note: for multi-class task, the score is grouped by class_id first, then by row_id
               and you should group grad and hess in this way as well
         Parameters
         ...
         """
         ...
         is_finished = ctypes.c_int(0)
         _safe_call(_LIB.LGBM_BoosterUpdateOneIterCustom(
             self.handle,
-            grad.ctypes.data_as(ctypes.c_float),
-            hess.ctypes.data_as(ctypes.c_float),
+            grad.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
+            hess.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
             ctypes.byref(is_finished)))
         return is_finished.value == 1

+    def rollback_one_iter(self):
+        _safe_call(_LIB.LGBM_BoosterRollbackOneIter(
+            self.handle))
+
+    def current_iteration(self):
+        out_cur_iter = ctypes.c_int64(0)
+        _safe_call(_LIB.LGBM_BoosterGetCurrentIteration(
+            self.handle,
+            ctypes.byref(out_cur_iter)))
+        return out_cur_iter.value
+
+    def eval(self, data, name, feval=None):
+        """Evaluate for data
+
+        Parameters
+        ----------
+        data : Dataset object
+        name : name of data
+        feval : function
+            Custom evaluation function.
+        Returns
+        -------
+        result: str
+            Evaluation result string.
+ """ + if not isinstance(data, Dataset): + raise TypeError("Can only eval for Dataset instance") + data_idx = -1 + if data is self.train_set: + data_idx = 0 + else: + for i in range(len(self.valid_sets)): + if data is self.valid_sets[i]: + data_idx = i + 1 + break + """need push new valid data""" + if data_idx == -1: + self.add_valid_data(data, name) + data_idx = self.__num_dataset - 1 + + return self.__inner_eval(name, data_idx, feval) + def eval_train(self, feval=None): """Evaluate for training data @@ -774,141 +994,28 @@ class Booster(object): c_str(filename))) def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True): - if isinstance(data, Dataset): - raise TypeError("cannot use Dataset instance for prediction, please use raw data instead") - predict_type = C_API_PREDICT_NORMAL - if raw_score: - predict_type = cC_API_PREDICT_RAW_SCORE - if pred_leaf: - predict_type = C_API_PREDICT_LEAF_INDEX - int_data_has_header = 0 - if data_has_header: - int_data_has_header = 1 - if is_str(data): - tmp_pred_fname = tempfile.NamedTemporaryFile(prefix="lightgbm_tmp_pred_").name - _safe_call(_LIB.LGBM_BoosterPredictForFile( - self.handle, - c_str(data), - int_data_has_header, - predict_type, - num_iteration, - c_str(tmp_pred_fname))) - lines = open(tmp_pred_fname,"r").readlines() - nrow = len(lines) - preds = [] - for line in lines: - for token in line.split('\t'): - preds.append(float(token)) - preds = np.array(preds, copy=False) - os.remove(tmp_pred_fname) - elif isinstance(data, scipy.sparse.csr_matrix): - preds, nrow = self.__pred_for_csr(data, num_iteration, predict_type) - elif isinstance(data, np.ndarray): - preds, nrow = self.__pred_for_np2d(data, num_iteration, predict_type) - else: - try: - csr = scipy.sparse.csr_matrix(data) - res = self.__pred_for_csr(csr, num_iteration, predict_type) - except: - raise TypeError('can not predict data for type {}'.format(type(data).__name__)) - if pred_leaf: - preds = preds.astype(np.int32) - if preds.size != nrow and is_reshape: - if preds.size % nrow == 0: - ncol = int(preds.size / nrow) - preds = preds.reshape(nrow, ncol) - else: - raise ValueError('len of predict result(%d) cannot be divide nrow(%d)' %(preds.size, nrow) ) - return preds + predictor = Predictor(booster_handle=self.handle, is_manage_handle=False) + return predictor.predict(data, num_iteration, raw_score, pred_leaf, data_has_header, is_reshape) - def __pred_for_np2d(self, mat, num_iteration, predict_type): - """ - Predict for a 2-D numpy matrix. 
- """ - if len(mat.shape) != 2: - raise ValueError('Input numpy.ndarray must be 2 dimensional') - - if mat.dtype == np.float32 or mat.dtype == np.float64: - data = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False) - else: - """change non-float data to float data, need to copy""" - data = np.array(mat.reshape(mat.size), dtype=np.float32) - ptr_data, type_ptr_data = c_float_array(data) - n_preds = self.__num_class * mat.shape[0] - if predict_type == C_API_PREDICT_LEAF_INDEX: - if num_iteration > 0: - n_preds *= num_iteration - else: - used_iteration = self.__num_total_model / self.__num_class - n_preds *= used_iteration - preds = np.zeros(n_preds, dtype=np.float32) - out_num_preds = ctypes.c_int64(0) - _safe_call(LIB.LGBM_BoosterPredictForMat( - self.handle, - ptr_data, - type_ptr_data, - mat.shape[0], - mat.shape[1], - C_API_IS_ROW_MAJOR, - predict_type, - num_iteration, - ctypes.byref(out_num_preds), - preds.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) - )) - if n_preds != out_num_preds.value: - raise ValueError("incorrect number for predict result") - return preds, mat.shape[0] - - def __pred_for_csr(self, csr, num_iteration, predict_type): - """ - Predict for a csr data - """ - nrow = len(csr.indptr) - 1 - n_preds = self.__num_class * nrow - if predict_type == C_API_PREDICT_LEAF_INDEX: - if num_iteration > 0: - n_preds *= num_iteration - else: - used_iteration = self.__num_total_model / self.__num_class - n_preds *= used_iteration - preds = np.zeros(n_preds, dtype=np.float32) - out_num_preds = ctypes.c_int64(0) - - ptr_indptr, type_ptr_indptr = c_int_array(csr.indptr) - ptr_data, type_ptr_data = c_float_array(csr.data) - - _safe_call(LIB.LGBM_BoosterPredictForCSR( - self.handle, - ptr_indptr, - type_ptr_indptr, - csr.indices.ctypes.data_as(ctypes.c_int32), - ptr_data, - type_ptr_data, - len(csr.indptr), - len(csr.data), - csr.shape[1], - predict_type, - num_iteration, - ctypes.byref(out_num_preds), - preds.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) - )) - if n_preds != out_num_preds.value: - raise ValueError("incorrect number for predict result") - return preds, nrow + def to_predictor(self): + predictor = Predictor(booster_handle=self.handle, is_manage_handle=True) + self.__is_manage_handle = False + return predictor def __inner_eval(self, data_name, data_idx, feval=None): if data_idx >= self.__num_dataset: raise ValueError("data_idx should be smaller than number of dataset") + self.__get_eval_info() ret = [] if self.__num_inner_eval > 0: result = np.array([0.0 for _ in range(self.__num_inner_eval)], dtype=np.float32) - out_len = ctypes.c_int64(0) + tmp_out_len = ctypes.c_int64(0) _safe_call(_LIB.LGBM_BoosterGetEval( self.handle, data_idx, - ctypes.byref(out_len), + ctypes.byref(tmp_out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))) - if out_len.value != self.__num_inner_eval: + if tmp_out_len.value != self.__num_inner_eval: raise ValueError("incorrect number of eval results") for i in range(self.__num_inner_eval): ret.append('%s %s : %f' %(data_name, self.__name_inner_eval[i], result[i])) @@ -936,33 +1043,37 @@ class Booster(object): num_data = self.valid_sets[data_idx - 1].num_data() * self.__num_class self.__inner_predict_buffer[data_idx] = \ np.array([0.0 for _ in range(num_data)], dtype=np.float32, copy=False) - out_len = ctypes.c_int64(0) + tmp_out_len = ctypes.c_int64(0) data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_float)) _safe_call(_LIB.LGBM_BoosterGetPredict( self.handle, data_idx, - 
ctypes.byref(out_len), + ctypes.byref(tmp_out_len), data_ptr)) - if out_len.value != len(self.__inner_predict_buffer[data_idx]): + if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]): raise ValueError("incorrect number of predict results for data %d" %(data_idx) ) return self.__inner_predict_buffer[data_idx] - - def __init_continue_train(self, dataset): - if dataset.raw_data is None: - raise ValueError("should set is_continue_train=True in dataset while need to continue train") - init_score = self.predict(dataset.raw_data, raw_score=True,data_has_header=dataset.data_has_header, is_reshape=False) - dataset.set_init_score(init_score) - dataset.free_raw_data() - - -#tmp test -train_data = Dataset('../../examples/binary_classification/binary.train') -test_data = Dataset('../../examples/binary_classification/binary.test', reference = train_data) -param = {"metric":"l2,l1"} -lgb = Booster(train_set=train_data, valid_sets=[test_data], params=param) -for i in range(100): - lgb.update() - print(lgb.eval_valid()) - print(lgb.eval_train()) -print(lgb.predict('../../examples/binary_classification/binary.train')) \ No newline at end of file + def __get_eval_info(self): + if self.__need_reload_eval_info: + self.__need_reload_eval_info = False + out_num_eval = ctypes.c_int64(0) + """Get num of inner evals""" + _safe_call(_LIB.LGBM_BoosterGetEvalCounts( + self.handle, + ctypes.byref(out_num_eval))) + self.__num_inner_eval = out_num_eval.value + if self.__num_inner_eval > 0: + """Get name of evals""" + tmp_out_len = ctypes.c_int64(0) + string_buffers = [ctypes.create_string_buffer(255) for i in range(self.__num_inner_eval)] + ptr_string_buffers = (ctypes.c_char_p*self.__num_inner_eval)(*map(ctypes.addressof, string_buffers)) + _safe_call(_LIB.LGBM_BoosterGetEvalNames( + self.handle, + ctypes.byref(tmp_out_len), + ptr_string_buffers)) + if self.__num_inner_eval != tmp_out_len.value: + raise ValueError("size of eval names doesn't equal with num_evals") + self.__name_inner_eval = [] + for i in range(self.__num_inner_eval): + self.__name_inner_eval.append(string_buffers[i].value.decode()) diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index 1158e5394..61155f08a 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -46,12 +46,12 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ gbdt_config_ = config; early_stopping_round_ = gbdt_config_->early_stopping_round; shrinkage_rate_ = gbdt_config_->learning_rate; - train_data_ = train_data; + random_ = Random(gbdt_config_->bagging_seed); // create tree learner tree_learner_.clear(); for (int i = 0; i < num_class_; ++i) { auto new_tree_learner = std::unique_ptr(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config)); - new_tree_learner->Init(train_data_); + new_tree_learner->Init(train_data); // init tree learner tree_learner_.push_back(std::move(new_tree_learner)); } @@ -63,42 +63,45 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ training_metrics_.push_back(metric); } training_metrics_.shrink_to_fit(); - // create score tracker - train_score_updater_.reset(new ScoreUpdater(train_data_, num_class_)); - num_data_ = train_data_->num_data(); - // create buffer for gradients and hessians - if (object_function_ != nullptr) { - gradients_ = std::vector(num_data_ * num_class_); - hessians_ = std::vector(num_data_ * num_class_); - } sigmoid_ = -1.0f; if (object_function_ != nullptr && std::string(object_function_->GetName()) == 
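
__get_eval_info above uses a ctypes idiom worth isolating: allocate fixed-size writable buffers, then hand the C side an array of their addresses. The 255-byte name limit is the patch's own choice. Minimal standalone form:

    import ctypes

    num_eval = 2
    string_buffers = [ctypes.create_string_buffer(255)
                      for _ in range(num_eval)]
    ptr_string_buffers = (ctypes.c_char_p * num_eval)(
        *map(ctypes.addressof, string_buffers))
    # a C callee can now write one name into each slot; afterwards:
    names = [buf.value.decode() for buf in string_buffers]
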
std::string("binary")) { // only binary classification need sigmoid transform sigmoid_ = gbdt_config_->sigmoid; } - // get max feature index - max_feature_idx_ = train_data_->num_total_features() - 1; - // get label index - label_idx_ = train_data_->label_idx(); - // if need bagging, create buffer - if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) { - out_of_bag_data_indices_ = std::vector(num_data_); - bag_data_indices_ = std::vector(num_data_); - } else { - out_of_bag_data_cnt_ = 0; - out_of_bag_data_indices_.clear(); - bag_data_cnt_ = num_data_; - bag_data_indices_.clear(); - } - random_ = Random(gbdt_config_->bagging_seed); - // update score - for (int i = 0; i < iter_; ++i) { - for (int curr_class = 0; curr_class < num_class_; ++curr_class) { - auto curr_tree = i * num_class_ + curr_class; - train_score_updater_->AddScore(models_[curr_tree].get(), curr_class); + if (train_data_ != train_data) { + // not same training data, need reset score and others + // create score tracker + train_score_updater_.reset(new ScoreUpdater(train_data, num_class_)); + // update score + for (int i = 0; i < iter_; ++i) { + for (int curr_class = 0; curr_class < num_class_; ++curr_class) { + auto curr_tree = (i + num_init_iteration_) * num_class_ + curr_class; + train_score_updater_->AddScore(models_[curr_tree].get(), curr_class); + } + } + num_data_ = train_data->num_data(); + // create buffer for gradients and hessians + if (object_function_ != nullptr) { + gradients_ = std::vector(num_data_ * num_class_); + hessians_ = std::vector(num_data_ * num_class_); + } + // get max feature index + max_feature_idx_ = train_data->num_total_features() - 1; + // get label index + label_idx_ = train_data->label_idx(); + // if need bagging, create buffer + if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) { + out_of_bag_data_indices_ = std::vector(num_data_); + bag_data_indices_ = std::vector(num_data_); + } else { + out_of_bag_data_cnt_ = 0; + out_of_bag_data_indices_.clear(); + bag_data_cnt_ = num_data_; + bag_data_indices_.clear(); } } + train_data_ = train_data; } void GBDT::AddValidDataset(const Dataset* valid_data, @@ -111,7 +114,7 @@ void GBDT::AddValidDataset(const Dataset* valid_data, // update score for (int i = 0; i < iter_; ++i) { for (int curr_class = 0; curr_class < num_class_; ++curr_class) { - auto curr_tree = i * num_class_ + curr_class; + auto curr_tree = (i + num_init_iteration_) * num_class_ + curr_class; new_score_updater->AddScore(models_[curr_tree].get(), curr_class); } } @@ -232,7 +235,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is void GBDT::RollbackOneIter() { if (iter_ == 0) { return; } - int cur_iter = iter_ - 1; + int cur_iter = iter_ + num_init_iteration_ - 1; // reset score for (int curr_class = 0; curr_class < num_class_; ++curr_class) { auto curr_tree = cur_iter * num_class_ + curr_class; diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index 933ebc1da..6e92f6dbd 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -36,12 +36,28 @@ public: const std::vector& training_metrics) override; + /*! 
+ * \brief Merge model from other boosting object + Will insert to the front of current boosting object + * \param other + */ void MergeFrom(const Boosting* other) override { auto other_gbdt = reinterpret_cast(other); + // tmp move to other vector + auto original_models = std::move(models_); + models_ = std::vector>(); + // push model from other first for (const auto& tree : other_gbdt->models_) { auto new_tree = std::unique_ptr(new Tree(*(tree.get()))); models_.push_back(std::move(new_tree)); } + num_init_iteration_ = static_cast(models_.size()) / num_class_; + // push model in current object + for (const auto& tree : original_models) { + auto new_tree = std::unique_ptr(new Tree(*(tree.get()))); + models_.push_back(std::move(new_tree)); + } + num_iteration_for_pred_ = static_cast(models_.size()) / num_class_; } /*! @@ -266,6 +282,7 @@ protected: int num_iteration_for_pred_; /*! \brief Shrinkage rate for one iteration */ double shrinkage_rate_; + /*! \brief Number of loaded initial models */ int num_init_iteration_; }; diff --git a/src/c_api.cpp b/src/c_api.cpp index 8bf2d4b2e..3ede904fd 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -36,7 +36,7 @@ public: Log::Warning("continued train from model is not support for c_api, \ please use continued train with input score"); } - boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, "")); + boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr)); ConstructObjectAndTrainingMetrics(train_data); // initialize the boosting boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(), @@ -114,6 +114,10 @@ public: return boosting_->TrainOneIter(gradients, hessians, false); } + void RollbackOneIter() { + boosting_->RollbackOneIter(); + } + void PrepareForPrediction(int num_iteration, int predict_type) { boosting_->SetNumIterationForPred(num_iteration); bool is_predict_leaf = false; @@ -156,24 +160,13 @@ public: int idx = 0; for (const auto& metric : train_metric_) { for (const auto& name : metric->GetName()) { - int j = 0; - auto name_cstr = name.c_str(); - while (name_cstr[j] != '\0') { - out_strs[idx][j] = name_cstr[j]; - ++j; - } - out_strs[idx][j] = '\0'; + std::strcpy(out_strs[idx], name.c_str()); ++idx; } } return idx; } - - void RollbackOneIter() { - boosting_->RollbackOneIter(); - } - const Boosting* GetBoosting() const { return boosting_.get(); } private: From 92351659a7f39bed7fe20b67ce4b27f26501cd77 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Thu, 24 Nov 2016 17:00:54 +0800 Subject: [PATCH 15/60] quick fix for bug of negative index --- src/io/sparse_bin.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/io/sparse_bin.hpp b/src/io/sparse_bin.hpp index f379a3292..68de393e2 100644 --- a/src/io/sparse_bin.hpp +++ b/src/io/sparse_bin.hpp @@ -279,7 +279,7 @@ inline VAL_T SparseBinIterator::InnerGet(data_size_t idx) { while (cur_pos_ < idx && i_delta_ < bin_data_->num_vals_) { bin_data_->NextNonzero(&i_delta_, &cur_pos_); } - if (cur_pos_ == idx && i_delta_ < bin_data_->num_vals_) { + if (cur_pos_ == idx && i_delta_ < bin_data_->num_vals_ && i_delta_ >= 0) { return bin_data_->vals_[i_delta_]; } else { return 0; From b8d9372efd2fbd3254d13fe02df589ad7b293262 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Thu, 24 Nov 2016 18:21:22 +0800 Subject: [PATCH 16/60] Add simple test --- .travis.yml | 4 +++- python-package/lightgbm/basic.py | 12 ++++++------ tests/python_package_test/test_basic.py | 24 ++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 7 deletions(-) create mode 
100644 tests/python_package_test/test_basic.py diff --git a/.travis.yml b/.travis.yml index 87add84d5..08b411101 100644 --- a/.travis.yml +++ b/.travis.yml @@ -21,9 +21,11 @@ script: - cd $TRAVIS_BUILD_DIR - mkdir build && cd build && cmake .. && make -j - cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py +- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py - cd $TRAVIS_BUILD_DIR - rm -rf build && mkdir build && cd build && cmake -DUSE_MPI=ON ..&& make -j -- cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py +- cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py +- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py notifications: email: false diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index b482ee296..d5524a86d 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -290,7 +290,7 @@ class Predictor(object): n_preds *= used_iteration preds = np.zeros(n_preds, dtype=np.float32) out_num_preds = ctypes.c_int64(0) - _safe_call(LIB.LGBM_BoosterPredictForMat( + _safe_call(_LIB.LGBM_BoosterPredictForMat( self.handle, ptr_data, type_ptr_data, @@ -324,7 +324,7 @@ class Predictor(object): ptr_indptr, type_ptr_indptr = c_int_array(csr.indptr) ptr_data, type_ptr_data = c_float_array(csr.data) - _safe_call(LIB.LGBM_BoosterPredictForCSR( + _safe_call(_LIB.LGBM_BoosterPredictForCSR( self.handle, ptr_indptr, type_ptr_indptr, @@ -447,7 +447,7 @@ class Dataset(object): init_score = new_init_score self.set_init_score(init_score) - def new_valid_dataset(self, data, label=None, weight=None, group_id=None, + def create_valid(self, data, label=None, weight=None, group_id=None, silent=False, params=None): """ Create validation data align with current dataset @@ -487,7 +487,7 @@ class Dataset(object): data = np.array(mat.reshape(mat.size), dtype=np.float32) ptr_data, type_ptr_data = c_float_array(data) - _safe_call(LIB.LGBM_CreateDatasetFromMat( + _safe_call(_LIB.LGBM_CreateDatasetFromMat( ptr_data, type_ptr_data, mat.shape[0], @@ -825,7 +825,7 @@ class Booster(object): if self.handle is not None and self.__is_manage_handle: _safe_call(_LIB.LGBM_BoosterFree(self.handle)) - def add_valid_data(self, data, name): + def add_valid(self, data, name): if data.predictor is not self.init_predictor: raise Exception("Add validation data failed, you should use same predictor for these data") _safe_call(_LIB.LGBM_BoosterAddValidData( @@ -835,7 +835,7 @@ class Booster(object): self.name_valid_sets.append(name) self.__num_dataset += 1 - def ResetParameter(self, params, silent=False): + def reset_parameter(self, params, silent=False): self.__need_reload_eval_info = True if silent: params["verbose"] = 0 diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py new file mode 100644 index 000000000..d35f70a80 --- /dev/null +++ b/tests/python_package_test/test_basic.py @@ -0,0 +1,24 @@ +import numpy as np +from sklearn import datasets, metrics, model_selection +import importlib.util +spec = importlib.util.spec_from_file_location("module.name", "../../python-package/lightgbm/basic.py") +lgb = importlib.util.module_from_spec(spec) +spec.loader.exec_module(lgb) + + +X, Y = datasets.make_classification(n_samples=100000, n_features=100) +x_train, x_test, y_train, y_test = model_selection.train_test_split(X, Y, test_size=0.1) + +train_data = lgb.Dataset(x_train, max_bin=255, label=y_train) +valid_data = train_data.create_valid(x_test, label=y_test) + 
+config={"objective":"binary","metric":"auc", "min_data":1, "num_leaves":15} +bst = lgb.Booster(params=config, train_set=train_data) +bst.add_valid(valid_data,"valid_1") + +for i in range(100): + bst.update() + print(bst.eval_train()) + print(bst.eval_valid()) +bst.save_model("model.txt") + From c04830a8fc1aa296835e41aebe6c139f55195233 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Thu, 24 Nov 2016 18:33:12 +0800 Subject: [PATCH 17/60] fix lib path --- python-package/lightgbm/basic.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index d5524a86d..36a226bb9 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -24,8 +24,9 @@ def find_lib_path(): """ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) dll_path = [curr_path, os.path.join(curr_path, '../../lib/'), - os.path.join(curr_path, './lib/'), - os.path.join(sys.prefix, 'lightgbm')] + os.path.join(curr_path, '../../'), + os.path.join(curr_path, './lib/'), + os.path.join(sys.prefix, 'lightgbm')] if os.name == 'nt': dll_path.append(os.path.join(curr_path, '../../windows/x64/Dll/')) dll_path.append(os.path.join(curr_path, './windows/x64/Dll/')) From a24b7fd4483baa62da99db0233750e63dda6f3bc Mon Sep 17 00:00:00 2001 From: Allardvm Date: Thu, 24 Nov 2016 11:34:54 +0100 Subject: [PATCH 18/60] Fixed prediction bug when num_used_model = NO_LIMIT / -1 --- src/boosting/gbdt.h | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index fbc7b6154..1ea6d71d1 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -76,26 +76,26 @@ public: void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) const override; /*! - * \brief Predtion for one record without sigmoid transformation + * \brief Prediction for one record without sigmoid transformation * \param feature_values Feature value on this record * \return Prediction result for this record */ std::vector PredictRaw(const double* feature_values) const override; /*! - * \brief Predtion for one record with sigmoid transformation if enabled + * \brief Prediction for one record with sigmoid transformation if enabled * \param feature_values Feature value on this record * \return Prediction result for this record */ std::vector Predict(const double* feature_values) const override; - + /*! - * \brief Predtion for one record with leaf index + * \brief Prediction for one record with leaf index * \param feature_values Feature value on this record * \return Predicted leaf index for this record */ std::vector PredictLeafIndex(const double* value) const override; - + /*! * \brief save model to file * \param num_used_model number of model that want to save, -1 means save all @@ -137,9 +137,11 @@ public: inline void SetNumUsedModel(int num_used_model) { if (num_used_model >= 0) { num_used_model_ = static_cast(num_used_model / num_class_); + } else { + num_used_model_ = static_cast(models_.size()) / num_class_; } } - + /*! * \brief Get Type name of this boosting object */ @@ -218,7 +220,7 @@ protected: std::vector bag_data_indices_; /*! \brief Number of in-bag data */ data_size_t bag_data_cnt_; - /*! \brief Number of traning data */ + /*! \brief Number of training data */ data_size_t num_data_; /*! \brief Number of classes */ int num_class_; @@ -226,7 +228,7 @@ protected: Random random_; /*! * \brief Sigmoid parameter, used for prediction. 
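
Behavior sketch of the SetNumUsedModel fix above (PATCH 18): previously a negative value, the NO_LIMIT convention, left num_used_model_ stale; now it falls back to every stored tree. Illustrative Python, with assumed names:

    def num_used_iterations(num_used_model, total_models, num_class):
        if num_used_model >= 0:
            return num_used_model // num_class
        return total_models // num_class   # NO_LIMIT: use everything

    assert num_used_iterations(6, 30, 3) == 2
    assert num_used_iterations(-1, 30, 3) == 10
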
- * if > 0 meas output score will transform by sigmoid function + * if > 0 means output score will transform by sigmoid function */ double sigmoid_; /*! \brief Index of label column */ From 65e711a285b212d4bfbdf3b0a3d1921b41df0f6e Mon Sep 17 00:00:00 2001 From: Allard van Mossel Date: Thu, 24 Nov 2016 12:23:05 +0100 Subject: [PATCH 19/60] Fixed inconsistencies and missing C-API documentation (#96) --- include/LightGBM/c_api.h | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 28c23cf0b..8368f1d56 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -7,9 +7,9 @@ /*! * To avoid type conversion on large data, most of our expose interface support both for float_32 and float_64. * Except following: -* 1. gradients and hessians. +* 1. gradients and hessians. * 2. Get current score for training data and validation -* The reason is because they are called frequently, the type-conversion on them maybe time cost. +* The reason is because they are called frequently, the type-conversion on them maybe time cost. */ #ifdef __cplusplus @@ -307,10 +307,10 @@ DllExport int LGBM_BoosterGetPredict(BoosterHandle handle, * \brief make prediction for file * \param handle handle * \param predict_type -* 0:raw score -* 1:with transform(if needed) +* 0:normal, with transform (if needed) +* 1:raw score * 2:leaf index -* \param n_used_trees number of used tree +* \param n_used_trees number of used tree, < 0 means no limit * \param data_has_header data file has header or not * \param data_filename filename of data file * \param result_filename filename of result file @@ -327,7 +327,7 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, * \brief make prediction for an new data set * \param handle handle * \param indptr pointer to row headers -* \param indptr_type +* \param indptr_type * \param indices findex * \param data fvalue * \param data_type @@ -335,10 +335,10 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, * \param nelem number of nonzero elements in the matrix * \param num_col number of columns; when it's set to 0, then guess from data * \param predict_type -* 0:raw score -* 1:with transform(if needed) +* 0:normal, with transform (if needed) +* 1:raw score * 2:leaf index -* \param n_used_trees number of used tree +* \param n_used_trees number of used tree, < 0 means no limit * \param out_result used to set a pointer to array, should allocate memory before call this function * \return 0 when success, -1 when failure happens */ @@ -364,10 +364,10 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, * \param ncol number columns * \param is_row_major 1 for row major, 0 for column major * \param predict_type -* 0:raw score -* 1:with transform(if needed) +* 0:normal, with transform (if needed) +* 1:raw score * 2:leaf index -* \param n_used_trees number of used tree +* \param n_used_trees number of used tree, < 0 means no limit * \param out_result used to set a pointer to array, should allocate memory before call this function * \return 0 when success, -1 when failure happens */ @@ -384,7 +384,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, /*! 
* \brief save model into file * \param handle handle -* \param num_used_model +* \param num_used_model, < 0 means no limit * \param filename file name * \return 0 when success, -1 when failure happens */ @@ -403,14 +403,14 @@ std::function>(int row_idx)> RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major); std::function>(int idx)> -RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, +RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem); std::function>(int idx)> -ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, +ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem); -std::vector +std::vector SampleFromOneColumn(const std::vector>& data, const std::vector& indices); @@ -437,6 +437,6 @@ inline int LGBM_APIHandleException(const std::string& ex) { catch(std::exception& ex) { return LGBM_APIHandleException(ex); } \ catch(std::string& ex) { return LGBM_APIHandleException(ex); } \ catch(...) { return LGBM_APIHandleException("unknown exception"); } \ -return 0; +return 0; #endif // LIGHTGBM_C_API_H_ From 3484e898e9522566e11c32e611a9abc4278578af Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sat, 26 Nov 2016 09:30:52 +0800 Subject: [PATCH 20/60] thread safe for booster --- include/LightGBM/boosting.h | 5 +-- include/LightGBM/c_api.h | 2 +- include/LightGBM/utils/log.h | 2 +- src/application/application.cpp | 3 +- src/boosting/dart.hpp | 13 +----- src/boosting/gbdt.cpp | 69 +++++++++++----------------- src/boosting/gbdt.h | 10 ++--- src/c_api.cpp | 79 ++++++++++++++++++++++----------- 8 files changed, 87 insertions(+), 96 deletions(-) diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h index 31e0526b8..f725d54e2 100644 --- a/include/LightGBM/boosting.h +++ b/include/LightGBM/boosting.h @@ -128,11 +128,10 @@ public: /*! * \brief save model to file - * \param num_used_model number of model that want to save, -1 means save all - * \param is_finish is training finished or not + * \param num_iterations Iterations that want to save, -1 means save all * \param filename filename that want to save to */ - virtual void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) = 0; + virtual void SaveModelToFile(int num_iterations, const char* filename) const = 0; /*! * \brief Restore from a serialized string diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index b1fed69fb..418667757 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -473,7 +473,7 @@ SampleFromOneColumn(const std::vector>& data, const std:: // exception handle and error msg -static std::string& LastErrorMsg() { static std::string err_msg("Everything is fine"); return err_msg; } +static std::string& LastErrorMsg() { static thread_local std::string err_msg("Everything is fine"); return err_msg; } inline void LGBM_SetLastError(const char* msg) { LastErrorMsg() = msg; diff --git a/include/LightGBM/utils/log.h b/include/LightGBM/utils/log.h index 05c565c17..eb2efc49e 100644 --- a/include/LightGBM/utils/log.h +++ b/include/LightGBM/utils/log.h @@ -89,7 +89,7 @@ private: // a trick to use static variable in header file. 
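
The thread_local change above gives each thread its own last-error slot, so concurrent Boosters cannot clobber one another's messages. The same idea in Python, with threading.local standing in for thread_local (an analogue, not the project's API):

    import threading

    _err = threading.local()

    def set_last_error(msg):
        _err.msg = msg

    def get_last_error():
        # default mirrors the C++ initializer above
        return getattr(_err, "msg", "Everything is fine")
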
// May be not good, but avoid to use an additional cpp file - static LogLevel& GetLevel() { static LogLevel level; return level; } + static LogLevel& GetLevel() { static thread_local LogLevel level = LogLevel::Info; return level; } }; diff --git a/src/application/application.cpp b/src/application/application.cpp index b922f9699..af0431a3d 100644 --- a/src/application/application.cpp +++ b/src/application/application.cpp @@ -225,11 +225,10 @@ void Application::Train() { // output used time per iteration Log::Info("%f seconds elapsed, finished iteration %d", std::chrono::duration(end_time - start_time) * 1e-3, iter + 1); - boosting_->SaveModelToFile(NO_LIMIT, is_finished, config_.io_config.output_model.c_str()); } is_finished = true; // save model to file - boosting_->SaveModelToFile(NO_LIMIT, is_finished, config_.io_config.output_model.c_str()); + boosting_->SaveModelToFile(NO_LIMIT, config_.io_config.output_model.c_str()); Log::Info("Finished training"); } diff --git a/src/boosting/dart.hpp b/src/boosting/dart.hpp index 9df28dd6a..4c54d1789 100644 --- a/src/boosting/dart.hpp +++ b/src/boosting/dart.hpp @@ -67,18 +67,7 @@ public: *out_len = train_score_updater_->num_data() * num_class_; return train_score_updater_->score(); } - /*! - * \brief save model to file - * \param num_iteration -1 means save all - * \param is_finish is training finished or not - * \param filename filename that want to save to - */ - void SaveModelToFile(int num_iteration, bool is_finish, const char* filename) override { - // only save model once when is_finish = true - if (is_finish && saved_model_size_ < 0) { - GBDT::SaveModelToFile(num_iteration, is_finish, filename); - } - } + /*! * \brief Get Type name of this boosting object */ diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index 61155f08a..0529ec5a9 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -17,8 +17,7 @@ namespace LightGBM { GBDT::GBDT() - :saved_model_size_(-1), - num_iteration_for_pred_(0), + :num_iteration_for_pred_(0), num_init_iteration_(0) { } @@ -30,7 +29,6 @@ GBDT::~GBDT() { void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector& training_metrics) { iter_ = 0; - saved_model_size_ = -1; num_iteration_for_pred_ = 0; max_feature_idx_ = 0; num_class_ = config->num_class; @@ -395,56 +393,41 @@ void GBDT::Boosting() { GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data()); } -void GBDT::SaveModelToFile(int num_iteration, bool is_finish, const char* filename) { - // first time to this function, open file - if (saved_model_size_ < 0) { - model_output_file_.open(filename); - // output model type - model_output_file_ << Name() << std::endl; - // output number of class - model_output_file_ << "num_class=" << num_class_ << std::endl; - // output label index - model_output_file_ << "label_index=" << label_idx_ << std::endl; - // output max_feature_idx - model_output_file_ << "max_feature_idx=" << max_feature_idx_ << std::endl; - // output objective name - if (object_function_ != nullptr) { - model_output_file_ << "objective=" << object_function_->GetName() << std::endl; - } - // output sigmoid parameter - model_output_file_ << "sigmoid=" << sigmoid_ << std::endl; - model_output_file_ << std::endl; - saved_model_size_ = 0; - } - // already saved - if (!model_output_file_.is_open()) { - return; +void GBDT::SaveModelToFile(int num_iteration, const char* filename) const { + /*! 
\brief File to write models */ + std::ofstream outpu_file; + outpu_file.open(filename); + // output model type + outpu_file << Name() << std::endl; + // output number of class + outpu_file << "num_class=" << num_class_ << std::endl; + // output label index + outpu_file << "label_index=" << label_idx_ << std::endl; + // output max_feature_idx + outpu_file << "max_feature_idx=" << max_feature_idx_ << std::endl; + // output objective name + if (object_function_ != nullptr) { + outpu_file << "objective=" << object_function_->GetName() << std::endl; } + // output sigmoid parameter + outpu_file << "sigmoid=" << sigmoid_ << std::endl; + outpu_file << std::endl; + int num_used_model = 0; if (num_iteration == NO_LIMIT) { num_used_model = static_cast(models_.size()); } else { num_used_model = num_iteration * num_class_; } - int rest = num_used_model - early_stopping_round_ * num_class_; + // output tree models - for (int i = saved_model_size_; i < rest; ++i) { - model_output_file_ << "Tree=" << i << std::endl; - model_output_file_ << models_[i]->ToString() << std::endl; + for (int i = 0; i < num_used_model; ++i) { + outpu_file << "Tree=" << i << std::endl; + outpu_file << models_[i]->ToString() << std::endl; } - saved_model_size_ = std::max(saved_model_size_, rest); - - model_output_file_.flush(); - // training finished, can close file - if (is_finish) { - for (int i = saved_model_size_; i < num_used_model; ++i) { - model_output_file_ << "Tree=" << i << std::endl; - model_output_file_ << models_[i]->ToString() << std::endl; - } - model_output_file_ << std::endl << FeatureImportance() << std::endl; - model_output_file_.close(); - } + outpu_file << std::endl << FeatureImportance() << std::endl; + outpu_file.close(); } void GBDT::LoadModelFromString(const std::string& model_str) { diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index 7aa5e14e6..14d3b3d8d 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -138,11 +138,11 @@ public: /*! * \brief save model to file - * \param num_iteration -1 means save all - * \param is_finish is training finished or not + * \param num_iterations Iterations that want to save, -1 means save all * \param filename filename that want to save to */ - virtual void SaveModelToFile(int num_iteration, bool is_finish, const char* filename) override; + virtual void SaveModelToFile(int num_iterations, const char* filename) const override ; + /*! * \brief Restore from a serialized string */ @@ -274,10 +274,6 @@ protected: double sigmoid_; /*! \brief Index of label column */ data_size_t label_idx_; - /*! \brief Saved number of models */ - int saved_model_size_; - /*! \brief File to write models */ - std::ofstream model_output_file_; /*! \brief number of used model */ int num_iteration_for_pred_; /*! 
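
The rewritten SaveModelToFile above emits the full header and all trees in one pass, with no incremental state left on the object. A Python sketch of the file layout it produces (field order taken from the diff; the formatting details are an assumption):

    def write_model_header(f, boosting_name, num_class, label_idx,
                           max_feature_idx, objective_name, sigmoid):
        f.write(boosting_name + "\n")
        f.write("num_class=%d\n" % num_class)
        f.write("label_index=%d\n" % label_idx)
        f.write("max_feature_idx=%d\n" % max_feature_idx)
        if objective_name is not None:   # objective line is optional
            f.write("objective=%s\n" % objective_name)
        f.write("sigmoid=%f\n" % sigmoid)
        f.write("\n")                    # blank line, then the trees
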
\brief Shrinkage rate for one iteration */ diff --git a/src/c_api.cpp b/src/c_api.cpp index 3ede904fd..814718931 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include "./application/predictor.hpp" @@ -29,6 +30,7 @@ public: Booster(const Dataset* train_data, const char* parameters) { + std::unique_lock lock(mutex_); auto param = ConfigBase::Str2Map(parameters); config_.Set(param); // create boosting @@ -41,48 +43,31 @@ public: // initialize the boosting boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(), Common::ConstPtrInVectorWrapper(train_metric_)); + lock.unlock(); } void MergeFrom(const Booster* other) { + std::unique_lock lock(mutex_); boosting_->MergeFrom(other->boosting_.get()); + lock.unlock(); } ~Booster() { } - void ConstructObjectAndTrainingMetrics(const Dataset* train_data) { - // create objective function - objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type, - config_.objective_config)); - if (objective_fun_ == nullptr) { - Log::Warning("Using self-defined objective functions"); - } - // create training metric - train_metric_.clear(); - for (auto metric_type : config_.metric_types) { - auto metric = std::unique_ptr( - Metric::CreateMetric(metric_type, config_.metric_config)); - if (metric == nullptr) { continue; } - metric->Init(train_data->metadata(), train_data->num_data()); - train_metric_.push_back(std::move(metric)); - } - train_metric_.shrink_to_fit(); - // initialize the objective function - if (objective_fun_ != nullptr) { - objective_fun_->Init(train_data->metadata(), train_data->num_data()); - } - } - void ResetTrainingData(const Dataset* train_data) { + std::unique_lock lock(mutex_); train_data_ = train_data; ConstructObjectAndTrainingMetrics(train_data_); // initialize the boosting boosting_->ResetTrainingData(&config_.boosting_config, train_data_, objective_fun_.get(), Common::ConstPtrInVectorWrapper(train_metric_)); + lock.unlock(); } void ResetConfig(const char* parameters) { + std::unique_lock lock(mutex_); auto param = ConfigBase::Str2Map(parameters); if (param.count("num_class")) { Log::Fatal("cannot change num class during training"); @@ -92,9 +77,11 @@ public: } config_.Set(param); ResetTrainingData(train_data_); + lock.unlock(); } void AddValidData(const Dataset* valid_data) { + std::unique_lock lock(mutex_); valid_metrics_.emplace_back(); for (auto metric_type : config_.metric_types) { auto metric = std::unique_ptr(Metric::CreateMetric(metric_type, config_.metric_config)); @@ -105,20 +92,30 @@ public: valid_metrics_.back().shrink_to_fit(); boosting_->AddValidDataset(valid_data, Common::ConstPtrInVectorWrapper(valid_metrics_.back())); + lock.unlock(); } bool TrainOneIter() { - return boosting_->TrainOneIter(nullptr, nullptr, false); + std::unique_lock lock(mutex_); + bool ret = boosting_->TrainOneIter(nullptr, nullptr, false); + lock.unlock(); + return ret; } bool TrainOneIter(const float* gradients, const float* hessians) { - return boosting_->TrainOneIter(gradients, hessians, false); + std::unique_lock lock(mutex_); + bool ret = boosting_->TrainOneIter(gradients, hessians, false); + lock.unlock(); + return ret; } void RollbackOneIter() { + std::unique_lock lock(mutex_); boosting_->RollbackOneIter(); + lock.unlock(); } void PrepareForPrediction(int num_iteration, int predict_type) { + std::unique_lock lock(mutex_); boosting_->SetNumIterationForPred(num_iteration); bool is_predict_leaf = false; bool is_raw_score = false; @@ -130,6 +127,7 @@ 
public: is_raw_score = false; } predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf)); + lock.unlock(); } void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) { @@ -145,7 +143,9 @@ public: } void SaveModelToFile(int num_iteration, const char* filename) { - boosting_->SaveModelToFile(num_iteration, true, filename); + std::unique_lock lock(mutex_); + boosting_->SaveModelToFile(num_iteration, filename); + lock.unlock(); } int GetEvalCounts() const { @@ -170,6 +170,30 @@ public: const Boosting* GetBoosting() const { return boosting_.get(); } private: + + void ConstructObjectAndTrainingMetrics(const Dataset* train_data) { + // create objective function + objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type, + config_.objective_config)); + if (objective_fun_ == nullptr) { + Log::Warning("Using self-defined objective functions"); + } + // create training metric + train_metric_.clear(); + for (auto metric_type : config_.metric_types) { + auto metric = std::unique_ptr( + Metric::CreateMetric(metric_type, config_.metric_config)); + if (metric == nullptr) { continue; } + metric->Init(train_data->metadata(), train_data->num_data()); + train_metric_.push_back(std::move(metric)); + } + train_metric_.shrink_to_fit(); + // initialize the objective function + if (objective_fun_ != nullptr) { + objective_fun_->Init(train_data->metadata(), train_data->num_data()); + } + } + const Dataset* train_data_; std::unique_ptr boosting_; /*! \brief All configs */ @@ -182,7 +206,8 @@ private: std::unique_ptr objective_fun_; /*! \brief Using predictor for prediction task */ std::unique_ptr predictor_; - + /*! \brief mutex for threading safe call */ + std::mutex mutex_; }; } From 962b7eb00f154d224aaa7b89ac1605dbd5b11950 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sat, 26 Nov 2016 09:35:11 +0800 Subject: [PATCH 21/60] change to std::lock_guard --- src/c_api.cpp | 35 +++++++++++------------------------ 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/src/c_api.cpp b/src/c_api.cpp index 814718931..71cd3329b 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -30,7 +30,7 @@ public: Booster(const Dataset* train_data, const char* parameters) { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); auto param = ConfigBase::Str2Map(parameters); config_.Set(param); // create boosting @@ -43,13 +43,11 @@ public: // initialize the boosting boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(), Common::ConstPtrInVectorWrapper(train_metric_)); - lock.unlock(); } void MergeFrom(const Booster* other) { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); boosting_->MergeFrom(other->boosting_.get()); - lock.unlock(); } ~Booster() { @@ -57,17 +55,16 @@ public: } void ResetTrainingData(const Dataset* train_data) { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); train_data_ = train_data; ConstructObjectAndTrainingMetrics(train_data_); // initialize the boosting boosting_->ResetTrainingData(&config_.boosting_config, train_data_, objective_fun_.get(), Common::ConstPtrInVectorWrapper(train_metric_)); - lock.unlock(); } void ResetConfig(const char* parameters) { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); auto param = ConfigBase::Str2Map(parameters); if (param.count("num_class")) { Log::Fatal("cannot change num class during training"); @@ -77,11 +74,10 @@ public: } config_.Set(param); ResetTrainingData(train_data_); - lock.unlock(); } void AddValidData(const Dataset* 
valid_data) { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); valid_metrics_.emplace_back(); for (auto metric_type : config_.metric_types) { auto metric = std::unique_ptr(Metric::CreateMetric(metric_type, config_.metric_config)); @@ -92,30 +88,24 @@ public: valid_metrics_.back().shrink_to_fit(); boosting_->AddValidDataset(valid_data, Common::ConstPtrInVectorWrapper(valid_metrics_.back())); - lock.unlock(); } bool TrainOneIter() { - std::unique_lock lock(mutex_); - bool ret = boosting_->TrainOneIter(nullptr, nullptr, false); - lock.unlock(); - return ret; + std::lock_guard lock(mutex_); + return boosting_->TrainOneIter(nullptr, nullptr, false); } bool TrainOneIter(const float* gradients, const float* hessians) { - std::unique_lock lock(mutex_); - bool ret = boosting_->TrainOneIter(gradients, hessians, false); - lock.unlock(); - return ret; + std::lock_guard lock(mutex_); + return boosting_->TrainOneIter(gradients, hessians, false); } void RollbackOneIter() { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); boosting_->RollbackOneIter(); - lock.unlock(); } void PrepareForPrediction(int num_iteration, int predict_type) { - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); boosting_->SetNumIterationForPred(num_iteration); bool is_predict_leaf = false; bool is_raw_score = false; @@ -127,7 +117,6 @@ public: is_raw_score = false; } predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf)); - lock.unlock(); } void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) { @@ -143,9 +132,7 @@ public: } void SaveModelToFile(int num_iteration, const char* filename) { - std::unique_lock lock(mutex_); boosting_->SaveModelToFile(num_iteration, filename); - lock.unlock(); } int GetEvalCounts() const { From 6e0b58bac38b84ecdc9b549c88b6efcc77ef0b19 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sat, 26 Nov 2016 09:45:41 +0800 Subject: [PATCH 22/60] thread-safe for set field of dataset --- include/LightGBM/dataset.h | 3 +++ src/c_api.cpp | 1 - src/io/metadata.cpp | 5 +++++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index 959576e7c..d5980dc5f 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -13,6 +13,7 @@ #include #include #include +#include namespace LightGBM { @@ -234,6 +235,8 @@ private: std::vector init_score_; /*! \brief Queries data */ std::vector queries_; + /*! 
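
PATCH 21/22 above settle on a guard-every-mutator shape: scope-bound locking (lock_guard) instead of manual unlock calls, and a mutex owned by the object it protects. The equivalent shape in Python, with threading.Lock standing in for std::mutex:

    import threading

    class Metadata:
        def __init__(self):
            self._lock = threading.Lock()   # one lock per instance
            self._weights = []

        def set_weights(self, weights):
            with self._lock:                # released on scope exit
                self._weights = [] if weights is None else list(weights)
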
\brief mutex for threading safe call */ + std::mutex mutex_; }; diff --git a/src/c_api.cpp b/src/c_api.cpp index 71cd3329b..60ca5293e 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -30,7 +30,6 @@ public: Booster(const Dataset* train_data, const char* parameters) { - std::lock_guard lock(mutex_); auto param = ConfigBase::Str2Map(parameters); config_.Set(param); // create boosting diff --git a/src/io/metadata.cpp b/src/io/metadata.cpp index 9ceb5f26d..14f9dc946 100644 --- a/src/io/metadata.cpp +++ b/src/io/metadata.cpp @@ -196,6 +196,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector lock(mutex_); // save to nullptr if (init_score == nullptr || len == 0) { init_score_.clear(); @@ -214,6 +215,7 @@ void Metadata::SetInitScore(const float* init_score, data_size_t len) { } void Metadata::SetLabel(const float* label, data_size_t len) { + std::lock_guard lock(mutex_); if (label == nullptr) { Log::Fatal("label cannot be nullptr"); } @@ -228,6 +230,7 @@ void Metadata::SetLabel(const float* label, data_size_t len) { } void Metadata::SetWeights(const float* weights, data_size_t len) { + std::lock_guard lock(mutex_); // save to nullptr if (weights == nullptr || len == 0) { weights_.clear(); @@ -247,6 +250,7 @@ void Metadata::SetWeights(const float* weights, data_size_t len) { } void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size_t len) { + std::lock_guard lock(mutex_); // save to nullptr if (query_boundaries == nullptr || len == 0) { query_boundaries_.clear(); @@ -270,6 +274,7 @@ void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size } void Metadata::SetQueryId(const data_size_t* query_id, data_size_t len) { + std::lock_guard lock(mutex_); // save to nullptr if (query_id == nullptr || len == 0) { query_boundaries_.clear(); From f059d0fe0f9616d9dc2f91f459880b01f8cbf7f9 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sat, 26 Nov 2016 10:10:59 +0800 Subject: [PATCH 23/60] fix some bugs in basic.py --- python-package/lightgbm/basic.py | 61 +++++++++++++++++++------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 36a226bb9..ebc433992 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -78,10 +78,21 @@ def is_numpy_1d_array(data): else: return False +def is_1d_list(data): + if not isinstance(data, list): + return False + if len(data) > 0: + if not isinstance(data[0], (int, str, bool) ): + return False + return True + def list_to_1d_numpy(data, dtype): if is_numpy_1d_array(data): - return data - elif isinstance(data, list): + if data.dtype == dtype: + return data + else: + return data.astype(dtype=dtype, copy=False) + elif is_1d_list(data): return np.array(data, dtype=dtype, copy=False) else: raise TypeError("Unknow type({})".format(type(data).__name__)) @@ -140,7 +151,7 @@ FIELD_TYPE_MAPPER = {"label":C_API_DTYPE_FLOAT32, def c_float_array(data): """Convert numpy array / list to c float array.""" - if isinstance(data, list): + if is_1d_list(data): data = np.array(data, copy=False) if is_numpy_1d_array(data): if data.dtype == np.float32: @@ -157,7 +168,7 @@ def c_float_array(data): def c_int_array(data): """Convert numpy array to c int array.""" - if isinstance(data, list): + if is_1d_list(data): data = np.array(data, copy=False) if is_numpy_1d_array(data): if data.dtype == np.int32: @@ -256,7 +267,7 @@ class Predictor(object): else: try: csr = scipy.sparse.csr_matrix(data) - res = 
self.__pred_for_csr(csr, num_iteration, predict_type) + preds, nrow = self.__pred_for_csr(csr, num_iteration, predict_type) except: raise TypeError('can not predict data for type {}'.format(type(data).__name__)) if pred_leaf: @@ -417,7 +428,7 @@ class Dataset(object): else: try: csr = scipy.sparse.csr_matrix(data) - self.__init_from_csr(csr) + self.__init_from_csr(csr, params_str, ref_dataset) except: raise TypeError('can not initialize Dataset from {}'.format(type(data).__name__)) self.__label = None @@ -618,8 +629,6 @@ class Dataset(object): The label information to be set into Dataset """ label = list_to_1d_numpy(label, np.float32) - if label.dtype != np.float32: - label = label.astype(np.float32, copy=False) self.__label = label self.set_field('label', label) @@ -633,8 +642,6 @@ class Dataset(object): """ if weight is not None: weight = list_to_1d_numpy(weight, np.float32) - if weight.dtype != np.float32: - weight = weight.astype(np.float32, copy=False) self.__weight = weight self.set_field('weight', weight) @@ -647,8 +654,6 @@ class Dataset(object): """ if score is not None: score = list_to_1d_numpy(score, np.float32) - if score.dtype != np.float32: - score = score.astype(np.float32, copy=False) self.__init_score = score self.set_field('init_score', score) @@ -662,8 +667,6 @@ class Dataset(object): """ if group is not None: group = list_to_1d_numpy(group, np.int32) - if group.dtype != np.int32: - group = group.astype(np.int32, copy=False) self.__group = group self.set_field('group', group) @@ -678,8 +681,6 @@ class Dataset(object): """ if group_id is not None: group_id = list_to_1d_numpy(group_id, np.int32) - if group_id.dtype != np.int32: - group_id = group_id.astype(np.int32, copy=False) self.set_field('group_id', group_id) def get_label(self): @@ -890,26 +891,36 @@ class Booster(object): and you should group grad and hess in this way as well Parameters ---------- - grad : 1d numpy with dtype=float32 + grad : 1d numpy or list The first order of gradient. - hess : 1d numpy with dtype=float32 + hess : 1d numpy or list The second order of gradient. 
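
The setter simplifications above lean on the list_to_1d_numpy rework from this patch: dtype coercion now happens in one place. Note that is_1d_list as written admits int, str, and bool elements only, so a plain float list still falls through to TypeError. A quick check, assuming basic.py is loaded as lgb the way the tests do:

    import numpy as np
    # lgb: basic.py loaded via importlib, as in test_basic.py
    arr = lgb.list_to_1d_numpy([1, 2, 3], np.float32)
    assert arr.dtype == np.float32
    arr = lgb.list_to_1d_numpy(np.array([1, 2], dtype=np.int64), np.float32)
    assert arr.dtype == np.float32
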
Returns ------- is_finished, bool """ - if not is_numpy_1d_array(grad) and not is_numpy_1d_array(hess): - raise TypeError('type of grad / hess should be 1d numpy object') - if not grad.dtype == np.float32 and not hess.dtype == np.float32: - raise TypeError('type of grad / hess should be np.float32') + if not is_numpy_1d_array(grad): + if is_1d_list(grad): + grad = np.array(grad, dtype=np.float32, copy=False) + else: + raise TypeError("grad should be numpy 1d array or 1d list") + if not is_numpy_1d_array(hess): + if is_1d_list(hess): + hess = np.array(hess, dtype=np.float32, copy=False) + else: + raise TypeError("hess should be numpy 1d array or 1d list") if len(grad) != len(hess): raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess))) + if grad.dtype != np.float32: + grad = grad.astype(np.float32, copy=False) + if hess.dtype != np.float32: + hess = hess.astype(np.float32, copy=False) is_finished = ctypes.c_int(0) _safe_call(_LIB.LGBM_BoosterUpdateOneIterCustom( self.handle, - grad.ctypes.data_as(ctypes.ctypes.POINTER(ctypes.c_float)), - hess.ctypes.data_as(ctypes.ctypes.POINTER(ctypes.c_float)), + grad.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + hess.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), ctypes.byref(is_finished))) return is_finished.value == 1 @@ -950,7 +961,7 @@ class Booster(object): break """need push new valid data""" if data_idx == -1: - self.add_valid_data(data, name) + self.add_valid(data, name) data_idx = self.__num_dataset - 1 return self.__inner_eval(name, data_idx, feval) From d8ecdaf5adf35e0da1323c5811cd890d73b4e535 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sat, 26 Nov 2016 11:54:11 +0800 Subject: [PATCH 24/60] change metric name. update some comments --- python-package/lightgbm/basic.py | 158 +++++++++++++++++++++++++------ src/metric/binary_metric.hpp | 6 +- src/metric/multiclass_metric.hpp | 4 +- src/metric/rank_metric.hpp | 2 +- src/metric/regression_metric.hpp | 4 +- 5 files changed, 136 insertions(+), 38 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index ebc433992..1efef3dca 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -187,8 +187,7 @@ class Predictor(object): """"A Predictor of LightGBM. """ def __init__(self,model_file=None, params=None, booster_handle=None, is_manage_handle=True): - # pylint: disable=invalid-name - """Initialize the Booster. + """Initialize the Predictor. Parameters ---------- @@ -233,6 +232,29 @@ class Predictor(object): def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True): + """ + Predict logic + + Parameters + ---------- + data : string/numpy array/scipy.sparse + Data source for prediction + When data is string type, it represents the path of txt file, + num_iteration : + used iteration for prediction + raw_score : bool + True for predict raw score + pred_leaf : bool + True for predict leaf index + data_has_header : bool + Used for txt data + is_reshape : bool + True for reshape to [nrow, ...] 
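
A minimal custom-objective sketch for the update(fobj=...) path above; logistic loss, with bst and train_data as in the earlier sketch. The preds handed to fobj are, as far as this series shows, the booster's raw training scores, so the objective applies its own sigmoid:

    import numpy as np

    def logregobj(preds, dataset):
        labels = dataset.get_label()
        p = 1.0 / (1.0 + np.exp(-preds))
        return p - labels, p * (1.0 - p)   # grad, hess

    bst.update(fobj=logregobj)
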
+ + Returns + ------- + Prediction result + """ if isinstance(data, Dataset): raise TypeError("cannot use Dataset instance for prediction, please use raw data instead") predict_type = C_API_PREDICT_NORMAL @@ -400,7 +422,7 @@ class Dataset(object): params["max_bin"] = max_bin if silent: params["verbose"] = 0 - else: + elif "verbose" not in params: params["verbose"] = 1 params_str = dict_to_str(params) """process for reference dataset""" @@ -477,7 +499,7 @@ class Dataset(object): group/query id for each instance. Note: if having group/query id, data should group by this id silent : boolean, optional Whether print messages during construction - other_params: dict, optional + params: dict, optional other parameters """ return Dataset(data, label=label, max_bin=self.max_bin, reference=self, @@ -758,7 +780,6 @@ class Booster(object): """"A Booster of of LightGBM. """ def __init__(self, params=None, train_set=None, model_file=None, silent=False): - # pylint: disable=invalid-name """Initialize the Booster. Parameters @@ -769,6 +790,8 @@ class Booster(object): training dataset model_file : string Path to the model file. + silent : boolean, optional + Whether print messages during construction """ self.handle = ctypes.c_void_p() self.__need_reload_eval_info = True @@ -777,7 +800,7 @@ class Booster(object): params = {} if silent: params["verbose"] = 0 - else: + elif "verbose" not in params: params["verbose"] = 1 if train_set is not None: """Training task""" @@ -806,6 +829,7 @@ class Booster(object): self.__num_class = out_num_class.value """buffer for inner predict""" self.__inner_predict_buffer = [None] + self.__is_predicted_cur_iter = [False] self.__get_eval_info() elif model_file is not None: """Prediction task""" @@ -828,6 +852,15 @@ class Booster(object): _safe_call(_LIB.LGBM_BoosterFree(self.handle)) def add_valid(self, data, name): + """Add an validation data + + Parameters + ---------- + data : Dataset + validation data + name : String + name of validation data + """ if data.predictor is not self.init_predictor: raise Exception("Add validation data failed, you should use same predictor for these data") _safe_call(_LIB.LGBM_BoosterAddValidData( @@ -836,12 +869,23 @@ class Booster(object): self.valid_sets.append(data) self.name_valid_sets.append(name) self.__num_dataset += 1 + self.__inner_predict_buffer.append(None) + self.__is_predicted_cur_iter.append(False) def reset_parameter(self, params, silent=False): + """Reset parameters for booster + + Parameters + ---------- + params : dict + params + silent : boolean, optional + Whether print messages during construction + """ self.__need_reload_eval_info = True if silent: params["verbose"] = 0 - else: + elif "verbose" not in params: params["verbose"] = 1 params_str = dict_to_str(params) _safe_call(_LIB.LGBM_BoosterResetParameter( @@ -864,6 +908,7 @@ class Booster(object): ------- is_finished, bool """ + """need reset training data""" if train_set is not None and train_set is not self.train_set: if train_set.predictor is not self.init_predictor: @@ -878,6 +923,7 @@ class Booster(object): _safe_call(_LIB.LGBM_BoosterUpdateOneIter( self.handle, ctypes.byref(is_finished))) + self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)] return is_finished.value == 1 else: grad, hess = fobj(self.__inner_predict(0), self.train_set) @@ -891,9 +937,9 @@ class Booster(object): and you should group grad and hess in this way as well Parameters ---------- - grad : 1d numpy or list + grad : 1d numpy or 1d list The first order of gradient. 
- hess : 1d numpy or list + hess : 1d numpy or 1d list The second order of gradient. Returns @@ -922,11 +968,16 @@ class Booster(object): grad.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), hess.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), ctypes.byref(is_finished))) + self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)] return is_finished.value == 1 def rollback_one_iter(self): + """ + Rollback one iteration + """ _safe_call(_LIB.LGBM_BoosterRollbackOneIter( self.handle)) + self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)] def current_iteration(self): out_cur_iter = ctypes.c_int64(0) @@ -946,8 +997,8 @@ class Booster(object): Custom evaluation function. Returns ------- - result: str - Evaluation result string. + result: list + Evaluation result list. """ if not isinstance(data, Dataset): raise TypeError("Can only eval for Dataset instance") @@ -977,7 +1028,7 @@ class Booster(object): Returns ------- result: str - Evaluation result string. + Evaluation result list. """ return self.__inner_eval("training", 0, feval) @@ -992,29 +1043,67 @@ class Booster(object): Returns ------- result: str - Evaluation result string. + Evaluation result list. """ ret = [] for i in range(1, self.__num_dataset): - ret.append(self.__inner_eval(self.name_valid_sets[i-1], i, feval)) - return '\n'.join(ret) + ret.extend(self.__inner_eval(self.name_valid_sets[i-1], i, feval)) + return ret def save_model(self, filename, num_iteration=-1): + """Save model of booster to file + + Parameters + ---------- + filename : str + filename to save + num_iteration: int + number of iteration that want to save. < 0 means save all + """ _safe_call(_LIB.LGBM_BoosterSaveModel( self.handle, num_iteration, c_str(filename))) def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True): + """ + Predict logic + + Parameters + ---------- + data : string/numpy array/scipy.sparse + Data source for prediction + When data is string type, it represents the path of txt file, + num_iteration : + used iteration for prediction + raw_score : bool + True for predict raw score + pred_leaf : bool + True for predict leaf index + data_has_header : bool + Used for txt data + is_reshape : bool + True for reshape to [nrow, ...] 
+ + Returns + ------- + Prediction result + """ predictor = Predictor(booster_handle=self.handle, is_manage_handle=False) return predictor.predict(data, num_iteration, raw_score, pred_leaf, data_has_header, is_reshape) def to_predictor(self): + """Convert to predictor + Note: Predictor will manage the handle after doing this + """ predictor = Predictor(booster_handle=self.handle, is_manage_handle=True) self.__is_manage_handle = False return predictor def __inner_eval(self, data_name, data_idx, feval=None): + """ + Evaulate traning or validation data + """ if data_idx >= self.__num_dataset: raise ValueError("data_idx should be smaller than number of dataset") self.__get_eval_info() @@ -1030,7 +1119,7 @@ class Booster(object): if tmp_out_len.value != self.__num_inner_eval: raise ValueError("incorrect number of eval results") for i in range(self.__num_inner_eval): - ret.append('%s %s : %f' %(data_name, self.__name_inner_eval[i], result[i])) + ret.append((data_name, self.__name_inner_eval[i], result[i])) if feval is not None: if data_idx == 0: cur_data = self.train_set @@ -1038,14 +1127,17 @@ class Booster(object): cur_data = self.valid_sets[data_idx - 1] feval_ret = feval(self.__inner_predict(data_idx), cur_data) if isinstance(feval_ret, list): - for name, val in feval_ret: - ret.append('%s %s : %f' % (data_name, name, val)) + for eval_name, val in feval_ret: + ret.append((data_name, eval_name, val)) else: - name, val = feval_ret - ret.append('%s %s : %f' % (data_name, name, val)) - return '\t'.join(ret) + eval_name, val = feval_ret + ret.append((data_name, eval_name, val)) + return ret def __inner_predict(self, data_idx): + """ + Predict for training and validation dataset + """ if data_idx >= self.__num_dataset: raise ValueError("data_idx should be smaller than number of dataset") if self.__inner_predict_buffer[data_idx] is None: @@ -1055,18 +1147,24 @@ class Booster(object): num_data = self.valid_sets[data_idx - 1].num_data() * self.__num_class self.__inner_predict_buffer[data_idx] = \ np.array([0.0 for _ in range(num_data)], dtype=np.float32, copy=False) - tmp_out_len = ctypes.c_int64(0) - data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_float)) - _safe_call(_LIB.LGBM_BoosterGetPredict( - self.handle, - data_idx, - ctypes.byref(tmp_out_len), - data_ptr)) - if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]): - raise ValueError("incorrect number of predict results for data %d" %(data_idx) ) + """avoid to predict many time in one iteration""" + if not self.__is_predicted_cur_iter[data_idx]: + tmp_out_len = ctypes.c_int64(0) + data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_float)) + _safe_call(_LIB.LGBM_BoosterGetPredict( + self.handle, + data_idx, + ctypes.byref(tmp_out_len), + data_ptr)) + if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]): + raise ValueError("incorrect number of predict results for data %d" %(data_idx) ) + self.__is_predicted_cur_iter[data_idx] = True return self.__inner_predict_buffer[data_idx] def __get_eval_info(self): + """ + Get inner evaluation count and names + """ if self.__need_reload_eval_info: self.__need_reload_eval_info = False out_num_eval = ctypes.c_int64(0) diff --git a/src/metric/binary_metric.hpp b/src/metric/binary_metric.hpp index c1f2c2982..734bdc2b3 100644 --- a/src/metric/binary_metric.hpp +++ b/src/metric/binary_metric.hpp @@ -116,7 +116,7 @@ public: } inline static const char* Name() { - return "log loss"; + return "logloss"; } }; /*! 
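
With __inner_eval returning (data_name, eval_name, value) tuples above, and eval_valid extending a flat list instead of joining strings, formatting now belongs to the caller:

    for data_name, eval_name, value in bst.eval_train() + bst.eval_valid():
        print("%s %s : %f" % (data_name, eval_name, value))
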
@@ -135,7 +135,7 @@ public: } inline static const char* Name() { - return "error rate"; + return "error"; } }; @@ -160,7 +160,7 @@ public: } void Init(const Metadata& metadata, data_size_t num_data) override { - name_.emplace_back("AUC"); + name_.emplace_back("auc"); num_data_ = num_data; // get label diff --git a/src/metric/multiclass_metric.hpp b/src/metric/multiclass_metric.hpp index 9b5c3c7b6..eb1deb56a 100644 --- a/src/metric/multiclass_metric.hpp +++ b/src/metric/multiclass_metric.hpp @@ -109,7 +109,7 @@ public: } inline static const char* Name() { - return "multi error"; + return "multi_error"; } }; @@ -129,7 +129,7 @@ public: } inline static const char* Name() { - return "multi logloss"; + return "multi_logloss"; } }; diff --git a/src/metric/rank_metric.hpp b/src/metric/rank_metric.hpp index bc5ae96c3..c2f60eca1 100644 --- a/src/metric/rank_metric.hpp +++ b/src/metric/rank_metric.hpp @@ -35,7 +35,7 @@ public: } void Init(const Metadata& metadata, data_size_t num_data) override { for (auto k : eval_at_) { - name_.emplace_back(std::string("NDCG@") + std::to_string(k)); + name_.emplace_back(std::string("ndcg@") + std::to_string(k)); } num_data_ = num_data; // get label diff --git a/src/metric/regression_metric.hpp b/src/metric/regression_metric.hpp index 7e7f21241..d07e0b369 100644 --- a/src/metric/regression_metric.hpp +++ b/src/metric/regression_metric.hpp @@ -101,7 +101,7 @@ public: } inline static const char* Name() { - return "l2 loss"; + return "l2"; } }; @@ -114,7 +114,7 @@ public: return std::fabs(score - label); } inline static const char* Name() { - return "l1 loss"; + return "l1"; } }; From 28e891ad54dfcb8c147ab5187e8bf28f1df0b79f Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sat, 26 Nov 2016 12:05:16 +0800 Subject: [PATCH 25/60] typo --- include/LightGBM/c_api.h | 64 ++++++++++++++++++++-------------------- src/boosting/gbdt.cpp | 26 ++++++++-------- src/c_api.cpp | 8 ++--- 3 files changed, 49 insertions(+), 49 deletions(-) diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 418667757..bcc9c37d6 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -38,7 +38,7 @@ typedef void* BoosterHandle; /*! 
* \brief get string message of the last error -* all function in this file will return 0 when success +* all function in this file will return 0 when succeed * and -1 when an error occured, * \return const char* error inforomation */ @@ -53,7 +53,7 @@ DllExport const char* LGBM_GetLastError(); * \param parameters additional parameters * \param reference used to align bin mapper with other dataset, nullptr means don't used * \param out a loaded dataset -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_CreateDatasetFromFile(const char* filename, const char* parameters, @@ -64,7 +64,7 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename, * \brief load data set from binary file like the command_line LightGBM do * \param filename the name of the file * \param out a loaded dataset -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename, DatesetHandle* out); @@ -82,7 +82,7 @@ DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename, * \param parameters additional parameters * \param reference used to align bin mapper with other dataset, nullptr means don't used * \param out created dataset -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, int indptr_type, @@ -109,7 +109,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, * \param parameters additional parameters * \param reference used to align bin mapper with other dataset, nullptr means don't used * \param out created dataset -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, int col_ptr_type, @@ -133,7 +133,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, * \param parameters additional parameters * \param reference used to align bin mapper with other dataset, nullptr means don't used * \param out created dataset -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_CreateDatasetFromMat(const void* data, int data_type, @@ -146,7 +146,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, /*! 
* \brief free space for dataset -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_DatasetFree(DatesetHandle handle); @@ -154,7 +154,7 @@ DllExport int LGBM_DatasetFree(DatesetHandle handle); * \brief save dateset to binary file * \param handle a instance of dataset * \param filename file name -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle, const char* filename); @@ -166,7 +166,7 @@ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle, * \param field_data pointer to vector * \param num_element number of element in field_data * \param type float32 or int32 -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_DatasetSetField(DatesetHandle handle, const char* field_name, @@ -181,7 +181,7 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle, * \param out_len used to set result length * \param out_ptr pointer to the result * \param out_type float32 or int32 -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_DatasetGetField(DatesetHandle handle, const char* field_name, @@ -193,7 +193,7 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle, * \brief get number of data. * \param handle the handle to the dataset * \param out The address to hold number of data -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_DatasetGetNumData(DatesetHandle handle, int64_t* out); @@ -202,7 +202,7 @@ DllExport int LGBM_DatasetGetNumData(DatesetHandle handle, * \brief get number of features * \param handle the handle to the dataset * \param out The output of number of features -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle, int64_t* out); @@ -214,7 +214,7 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle, * \param train_data training data set * \param parameters format: 'key1=value1 key2=value2' * \prama out handle of created Booster -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, const char* parameters, @@ -225,7 +225,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, * \param filename filename of model * \param out_num_total_model number of total models * \param out handle of created Booster -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterCreateFromModelfile( const char* filename, @@ -236,7 +236,7 @@ DllExport int LGBM_BoosterCreateFromModelfile( /*! 
* \brief free obj in handle * \param handle handle to be freed -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterFree(BoosterHandle handle); @@ -244,7 +244,7 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle); * \brief Merge model in two booster to first handle * \param handle handle, will merge other handle to this * \param other_handle -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterMerge(BoosterHandle handle, BoosterHandle other_handle); @@ -253,7 +253,7 @@ DllExport int LGBM_BoosterMerge(BoosterHandle handle, * \brief Add new validation to booster * \param handle handle * \param valid_data validation data set -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterAddValidData(BoosterHandle handle, const DatesetHandle valid_data); @@ -262,7 +262,7 @@ DllExport int LGBM_BoosterAddValidData(BoosterHandle handle, * \brief Reset training data for booster * \param handle handle * \param train_data training data set -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle, const DatesetHandle train_data); @@ -271,7 +271,7 @@ DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle, * \brief Reset config for current booster * \param handle handle * \param parameters format: 'key1=value1 key2=value2' -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters); @@ -286,7 +286,7 @@ DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len); * \brief update the model in one round * \param handle handle * \param is_finished 1 means finised(cannot split any more) -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished); @@ -297,7 +297,7 @@ DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished); * \param grad gradient statistics * \param hess second order gradient statistics * \param is_finished 1 means finised(cannot split any more) -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, const float* grad, @@ -307,7 +307,7 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, /*! * \brief Rollback one iteration * \param handle handle -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle); @@ -332,13 +332,13 @@ DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, c /*! * \brief get evaluation for training data and validation data * \param handle handle -* \param data 0:training data, 1: 1st valid data, 2:2nd valid data ... +* \param data_idx 0:training data, 1: 1st valid data, 2:2nd valid data ... 
* \param out_len len of output result * \param out_result the string containing evaluation statistics, should allocate memory before call this function -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterGetEval(BoosterHandle handle, - int data, + int data_idx, int64_t* out_len, float* out_results); @@ -346,13 +346,13 @@ DllExport int LGBM_BoosterGetEval(BoosterHandle handle, * \brief Get prediction for training data and validation data this can be used to support customized eval function * \param handle handle -* \param data 0:training data, 1: 1st valid data, 2:2nd valid data ... +* \param data_idx 0:training data, 1: 1st valid data, 2:2nd valid data ... * \param out_len len of output result * \param out_result used to set a pointer to array, should allocate memory before call this function -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterGetPredict(BoosterHandle handle, - int data, + int data_idx, int64_t* out_len, float* out_result); @@ -367,7 +367,7 @@ DllExport int LGBM_BoosterGetPredict(BoosterHandle handle, * 2:leaf index * \param num_iteration number of iteration for prediction, < 0 means no limit * \param result_filename filename of result file -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, const char* data_filename, @@ -394,7 +394,7 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, * \param num_iteration number of iteration for prediction, < 0 means no limit * \param out_len len of output result * \param out_result used to set a pointer to array, should allocate memory before call this function -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, const void* indptr, @@ -425,7 +425,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, * \param num_iteration number of iteration for prediction, < 0 means no limit * \param out_len len of output result * \param out_result used to set a pointer to array, should allocate memory before call this function -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, const void* data, @@ -443,7 +443,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, * \param handle handle * \param num_iteration, < 0 means no limit * \param filename file name -* \return 0 when success, -1 when failure happens +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle, int num_iteration, diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index 0529ec5a9..2ae0e280f 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -395,23 +395,23 @@ void GBDT::Boosting() { void GBDT::SaveModelToFile(int num_iteration, const char* filename) const { /*! 
\brief File to write models */ - std::ofstream outpu_file; - outpu_file.open(filename); + std::ofstream output_file; + output_file.open(filename); // output model type - outpu_file << Name() << std::endl; + output_file << Name() << std::endl; // output number of class - outpu_file << "num_class=" << num_class_ << std::endl; + output_file << "num_class=" << num_class_ << std::endl; // output label index - outpu_file << "label_index=" << label_idx_ << std::endl; + output_file << "label_index=" << label_idx_ << std::endl; // output max_feature_idx - outpu_file << "max_feature_idx=" << max_feature_idx_ << std::endl; + output_file << "max_feature_idx=" << max_feature_idx_ << std::endl; // output objective name if (object_function_ != nullptr) { - outpu_file << "objective=" << object_function_->GetName() << std::endl; + output_file << "objective=" << object_function_->GetName() << std::endl; } // output sigmoid parameter - outpu_file << "sigmoid=" << sigmoid_ << std::endl; - outpu_file << std::endl; + output_file << "sigmoid=" << sigmoid_ << std::endl; + output_file << std::endl; int num_used_model = 0; if (num_iteration == NO_LIMIT) { @@ -422,12 +422,12 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const { // output tree models for (int i = 0; i < num_used_model; ++i) { - outpu_file << "Tree=" << i << std::endl; - outpu_file << models_[i]->ToString() << std::endl; + output_file << "Tree=" << i << std::endl; + output_file << models_[i]->ToString() << std::endl; } - outpu_file << std::endl << FeatureImportance() << std::endl; - outpu_file.close(); + output_file << std::endl << FeatureImportance() << std::endl; + output_file.close(); } void GBDT::LoadModelFromString(const std::string& model_str) { diff --git a/src/c_api.cpp b/src/c_api.cpp index 60ca5293e..9776f4b57 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -593,13 +593,13 @@ DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, c DllExport int LGBM_BoosterGetEval(BoosterHandle handle, - int data, + int data_idx, int64_t* out_len, float* out_results) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); auto boosting = ref_booster->GetBoosting(); - auto result_buf = boosting->GetEvalAt(data); + auto result_buf = boosting->GetEvalAt(data_idx); *out_len = static_cast(result_buf.size()); for (size_t i = 0; i < result_buf.size(); ++i) { (out_results)[i] = static_cast(result_buf[i]); @@ -608,13 +608,13 @@ DllExport int LGBM_BoosterGetEval(BoosterHandle handle, } DllExport int LGBM_BoosterGetPredict(BoosterHandle handle, - int data, + int data_idx, int64_t* out_len, float* out_result) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); int len = 0; - ref_booster->GetPredictAt(data, out_result, &len); + ref_booster->GetPredictAt(data_idx, out_result, &len); *out_len = static_cast(len); API_END(); } From 7287af3e941009e8ba1ff0702490962591d41232 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sat, 26 Nov 2016 12:35:55 +0800 Subject: [PATCH 26/60] change no_limit to <=0. fix float in is_1d_list. 
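A short sketch of the new semantics (bst and is_1d_list refer to the Booster and
the helper in this package; the snippet is illustrative, not part of the diff):

    bst.save_model('model.txt', num_iteration=-1)  # any value <= 0 now saves all iterations
    assert is_1d_list([0.5, 1.0])                  # float elements were rejected before this fix
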
--- include/LightGBM/c_api.h | 8 ++++---- include/LightGBM/config.h | 6 +++--- include/LightGBM/meta.h | 1 - python-package/lightgbm/basic.py | 2 +- src/application/application.cpp | 3 +-- src/boosting/gbdt.cpp | 4 ++-- 6 files changed, 11 insertions(+), 13 deletions(-) diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index bcc9c37d6..37efd6378 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -365,7 +365,7 @@ DllExport int LGBM_BoosterGetPredict(BoosterHandle handle, * 0:normal, with transform (if needed) * 1:raw score * 2:leaf index -* \param num_iteration number of iteration for prediction, < 0 means no limit +* \param num_iteration number of iteration for prediction, <= 0 means no limit * \param result_filename filename of result file * \return 0 when succeed, -1 when failure happens */ @@ -391,7 +391,7 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, * 0:normal, with transform (if needed) * 1:raw score * 2:leaf index -* \param num_iteration number of iteration for prediction, < 0 means no limit +* \param num_iteration number of iteration for prediction, <= 0 means no limit * \param out_len len of output result * \param out_result used to set a pointer to array, should allocate memory before call this function * \return 0 when succeed, -1 when failure happens @@ -422,7 +422,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, * 0:normal, with transform (if needed) * 1:raw score * 2:leaf index -* \param num_iteration number of iteration for prediction, < 0 means no limit +* \param num_iteration number of iteration for prediction, <= 0 means no limit * \param out_len len of output result * \param out_result used to set a pointer to array, should allocate memory before call this function * \return 0 when succeed, -1 when failure happens @@ -441,7 +441,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle, /*! * \brief save model into file * \param handle handle -* \param num_iteration, < 0 means no limit +* \param num_iteration, <= 0 means save all * \param filename file name * \return 0 when succeed, -1 when failure happens */ diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index 19e4b3364..5ec22cb63 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -99,7 +99,7 @@ public: std::string output_result = "LightGBM_predict_result.txt"; std::string input_model = ""; int verbosity = 1; - int num_iteration_predict = NO_LIMIT; + int num_iteration_predict = -1; bool is_pre_partition = false; bool is_enable_sparse = true; bool use_two_round_loading = false; @@ -166,12 +166,12 @@ public: int feature_fraction_seed = 2; double feature_fraction = 1.0f; // max cache size(unit:MB) for historical histogram. < 0 means not limit - double histogram_pool_size = NO_LIMIT; + double histogram_pool_size = -1.0f; // max depth of tree model. 
// Still grow tree by leaf-wise, but limit the max depth to avoid over-fitting // And the max leaves will be min(num_leaves, pow(2, max_depth - 1)) // max_depth < 0 means not limit - int max_depth = NO_LIMIT; + int max_depth = -1; void Set(const std::unordered_map& params) override; }; diff --git a/include/LightGBM/meta.h b/include/LightGBM/meta.h index 033e492b6..2689e2ede 100644 --- a/include/LightGBM/meta.h +++ b/include/LightGBM/meta.h @@ -24,7 +24,6 @@ using ReduceFunction = std::function; using PredictFunction = std::function(const std::vector>&)>; -#define NO_LIMIT (-1) #define NO_SPECIFIC (-1) } // namespace LightGBM diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 1efef3dca..d09ac1add 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -82,7 +82,7 @@ def is_1d_list(data): if not isinstance(data, list): return False if len(data) > 0: - if not isinstance(data[0], (int, str, bool) ): + if not isinstance(data[0], (int, float, bool) ): return False return True diff --git a/src/application/application.cpp b/src/application/application.cpp index af0431a3d..3501f7299 100644 --- a/src/application/application.cpp +++ b/src/application/application.cpp @@ -226,9 +226,8 @@ void Application::Train() { Log::Info("%f seconds elapsed, finished iteration %d", std::chrono::duration(end_time - start_time) * 1e-3, iter + 1); } - is_finished = true; // save model to file - boosting_->SaveModelToFile(NO_LIMIT, config_.io_config.output_model.c_str()); + boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str()); Log::Info("Finished training"); } diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index 2ae0e280f..615125bcd 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -414,12 +414,12 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const { output_file << std::endl; int num_used_model = 0; - if (num_iteration == NO_LIMIT) { + if (num_iteration <= 0) { num_used_model = static_cast(models_.size()); } else { num_used_model = num_iteration * num_class_; } - + num_used_model = std::min(num_used_model, static_cast(models_.size())); // output tree models for (int i = 0; i < num_used_model; ++i) { output_file << "Tree=" << i << std::endl; From 0ae51f146a33df6b64d3758d2cc9ed2eccf19ca0 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sat, 26 Nov 2016 12:54:54 +0800 Subject: [PATCH 27/60] change static std::string to static char[]. (refer to https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables) --- include/LightGBM/c_api.h | 7 ++++--- src/c_api.cpp | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 37efd6378..4ad36dc71 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -3,7 +3,9 @@ #include #include #include +#include #include + /*! * To avoid type conversion on large data, most of our expose interface support both for float_32 and float_64. 
* Except following:
@@ -472,11 +474,10 @@ SampleFromOneColumn(const std::vector>& data, const std::
 
 // exception handle and error msg
-
-static std::string& LastErrorMsg() { static thread_local std::string err_msg("Everything is fine"); return err_msg; }
+static char* LastErrorMsg() { static thread_local char err_msg[512] = "Everything is fine"; return err_msg; }
 
 inline void LGBM_SetLastError(const char* msg) {
-  LastErrorMsg() = msg;
+  std::strcpy(LastErrorMsg(), msg);
 }
 
 inline int LGBM_APIHandleException(const std::exception& ex) {
diff --git a/src/c_api.cpp b/src/c_api.cpp
index 9776f4b57..71daa2787 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -201,7 +201,7 @@ private:
 using namespace LightGBM;
 
 DllExport const char* LGBM_GetLastError() {
-  return LastErrorMsg().c_str();
+  return LastErrorMsg();
 }
 
 DllExport int LGBM_CreateDatasetFromFile(const char* filename,

From 522e99932dcd4754607ee698dd0c0626ea18cf4d Mon Sep 17 00:00:00 2001
From: Guolin Ke 
Date: Sat, 26 Nov 2016 13:42:41 +0800
Subject: [PATCH 28/60] support identifying bin file from file content

---
 include/LightGBM/c_api.h          |  9 ------
 include/LightGBM/dataset.h        |  2 ++
 include/LightGBM/dataset_loader.h |  2 +-
 src/c_api.cpp                     |  9 ------
 src/io/dataset.cpp                |  4 ++-
 src/io/dataset_loader.cpp         | 52 +++++++++++++++++++++++--------
 6 files changed, 45 insertions(+), 33 deletions(-)

diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index 4ad36dc71..fe4d25e38 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -62,15 +62,6 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename,
 const DatesetHandle* reference,
 DatesetHandle* out);

-/*!
-* \brief load data set from binary file like the command_line LightGBM do
-* \param filename the name of the file
-* \param out a loaded dataset
-* \return 0 when succeed, -1 when failure happens
-*/
-DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
-  DatesetHandle* out);
-
/*!
* \brief create a dataset from CSR format
* \param indptr pointer to row headers
diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h
index d5980dc5f..d3bc8b894 100644
--- a/include/LightGBM/dataset.h
+++ b/include/LightGBM/dataset.h
@@ -402,6 +402,8 @@ private:
 int label_idx_ = 0;
 /*! \brief store feature names */
 std::vector feature_names_;
+  /*! \brief token used to identify binary dataset files */
+  static const char* binary_file_token;
};

} // namespace LightGBM
diff --git a/include/LightGBM/dataset_loader.h b/include/LightGBM/dataset_loader.h
index 44e63d867..a9d897089 100644
--- a/include/LightGBM/dataset_loader.h
+++ b/include/LightGBM/dataset_loader.h
@@ -49,7 +49,7 @@ private:
 void ExtractFeaturesFromFile(const char* filename, const Parser* parser,
   const std::vector& used_data_indices, Dataset* dataset);
 /*! \brief Check can load from binary file */
-  bool CheckCanLoadFromBin(const char* filename);
+  std::string CheckCanLoadFromBin(const char* filename);

 const IOConfig& io_config_;
 /*!
\brief Random generator*/ diff --git a/src/c_api.cpp b/src/c_api.cpp index 71daa2787..589e0262d 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -223,15 +223,6 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename, API_END(); } -DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename, - DatesetHandle* out) { - API_BEGIN(); - OverallConfig config; - DatasetLoader loader(config.io_config, nullptr); - *out = loader.LoadFromBinFile(filename, 0, 1); - API_END(); -} - DllExport int LGBM_CreateDatasetFromMat(const void* data, int data_type, int32_t nrow, diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp index 8391dbe8b..a8814dc5b 100644 --- a/src/io/dataset.cpp +++ b/src/io/dataset.cpp @@ -14,6 +14,7 @@ namespace LightGBM { +const char* Dataset::binary_file_token = "______LightGBM_Binary_File_Token______\n"; Dataset::Dataset() { num_class_ = 1; @@ -135,7 +136,8 @@ void Dataset::SaveBinaryFile(const char* bin_filename) { Log::Fatal("Cannot write binary data to %s ", bin_filename); } Log::Info("Saving data to binary file %s", bin_filename); - + size_t size_of_token = std::strlen(binary_file_token); + fwrite(binary_file_token, sizeof(char), size_of_token, file); // get size of header size_t size_of_header = sizeof(num_data_) + sizeof(num_class_) + sizeof(num_features_) + sizeof(num_total_features_) + sizeof(size_t) + sizeof(int) * used_feature_map_.size(); diff --git a/src/io/dataset_loader.cpp b/src/io/dataset_loader.cpp index 5215180a7..c073d7620 100644 --- a/src/io/dataset_loader.cpp +++ b/src/io/dataset_loader.cpp @@ -152,8 +152,8 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac dataset->data_filename_ = filename; dataset->num_class_ = io_config_.num_class; dataset->metadata_.Init(filename, dataset->num_class_); - bool is_loading_from_binfile = CheckCanLoadFromBin(filename); - if (!is_loading_from_binfile) { + auto bin_filename = CheckCanLoadFromBin(filename); + if (bin_filename.size() == 0) { if (!io_config_.use_two_round_loading) { // read data to memory auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines,&num_global_data, &used_data_indices); @@ -185,8 +185,6 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac } } else { // load data from binary file - std::string bin_filename(filename); - bin_filename.append(".bin"); dataset.reset(LoadFromBinFile(bin_filename.c_str(), rank, num_machines)); } // check meta data @@ -209,8 +207,8 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, dataset->data_filename_ = filename; dataset->num_class_ = io_config_.num_class; dataset->metadata_.Init(filename, dataset->num_class_); - bool is_loading_from_binfile = CheckCanLoadFromBin(filename); - if (!is_loading_from_binfile) { + auto bin_filename = CheckCanLoadFromBin(filename); + if (bin_filename.size() == 0) { if (!io_config_.use_two_round_loading) { // read data in memory auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices); @@ -234,8 +232,6 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, } } else { // load data from binary file - std::string bin_filename(filename); - bin_filename.append(".bin"); dataset.reset(LoadFromBinFile(bin_filename.c_str(), 0, 1)); } // not need to check validation data @@ -260,9 +256,19 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* bin_filename, int rank, int // buffer to read binary file size_t buffer_size = 16 * 1024 * 
1024; auto buffer = std::vector(buffer_size); + + // check token + size_t size_of_token = std::strlen(Dataset::binary_file_token); + size_t read_cnt = fread(buffer.data(), sizeof(char), size_of_token, file); + if (read_cnt != size_of_token) { + Log::Fatal("Binary file error: token has the wrong size"); + } + if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) { + Log::Fatal("input file is not LightGBM binary file"); + } // read size of header - size_t read_cnt = fread(buffer.data(), sizeof(size_t), 1, file); + read_cnt = fread(buffer.data(), sizeof(size_t), 1, file); if (read_cnt != 1) { Log::Fatal("Binary file error: header has the wrong size"); @@ -849,7 +855,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* } /*! \brief Check can load from binary file */ -bool DatasetLoader::CheckCanLoadFromBin(const char* filename) { +std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) { std::string bin_filename(filename); bin_filename.append(".bin"); @@ -860,12 +866,32 @@ bool DatasetLoader::CheckCanLoadFromBin(const char* filename) { #else file = fopen(bin_filename.c_str(), "rb"); #endif + if (file == NULL) { - return false; + bin_filename = std::string(filename); +#ifdef _MSC_VER + fopen_s(&file, bin_filename.c_str(), "rb"); +#else + file = fopen(bin_filename.c_str(), "rb"); +#endif + if (file == NULL) { + Log::Fatal("cannot open data file %s", bin_filename.c_str()); + } + } + + size_t buffer_size = 256; + auto buffer = std::vector(buffer_size); + // read size of token + size_t size_of_token = std::strlen(Dataset::binary_file_token); + size_t read_cnt = fread(buffer.data(), sizeof(char), size_of_token, file); + fclose(file); + if (read_cnt == size_of_token + && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) { + return bin_filename; } else { - fclose(file); - return true; + return std::string(); } + } } \ No newline at end of file From 0612dcc08d8bff4c02ab2cdce779823f9803ce74 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sat, 26 Nov 2016 14:44:48 +0800 Subject: [PATCH 29/60] support get subset of dataset --- include/LightGBM/c_api.h | 24 ++++++++++++++++++++---- include/LightGBM/dataset.h | 4 ++-- include/LightGBM/feature.h | 3 +++ python-package/lightgbm/basic.py | 6 +++--- src/c_api.cpp | 28 ++++++++++++++++++++++++---- src/io/dataset.cpp | 30 ++++++++++++++++++++++++++---- src/io/dataset_loader.cpp | 31 +++++++++++++++---------------- tests/c_api_test/test.py | 22 ++++++++-------------- 8 files changed, 101 insertions(+), 47 deletions(-) diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index fe4d25e38..a1a3269ea 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -57,7 +57,7 @@ DllExport const char* LGBM_GetLastError(); * \param out a loaded dataset * \return 0 when succeed, -1 when failure happens */ -DllExport int LGBM_CreateDatasetFromFile(const char* filename, +DllExport int LGBM_DatasetCreateFromFile(const char* filename, const char* parameters, const DatesetHandle* reference, DatesetHandle* out); @@ -77,7 +77,7 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename, * \param out created dataset * \return 0 when succeed, -1 when failure happens */ -DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, +DllExport int LGBM_DatasetCreateFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, @@ -104,7 +104,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, * \param out created dataset * \return 0 when 
succeed, -1 when failure happens */ -DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, +DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, @@ -128,7 +128,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, * \param out created dataset * \return 0 when succeed, -1 when failure happens */ -DllExport int LGBM_CreateDatasetFromMat(const void* data, +DllExport int LGBM_DatasetCreateFromMat(const void* data, int data_type, int32_t nrow, int32_t ncol, @@ -137,6 +137,22 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, const DatesetHandle* reference, DatesetHandle* out); +/*! +* \brief Create subset of a data +* \param full_data the full dataset +* \param used_row_indices Indices used in subset +* \param num_used_row_indices len of used_row_indices +* \param parameters additional parameters +* \param out subset of data +* \return 0 when succeed, -1 when failure happens +*/ +DllExport int LGBM_DatasetGetSubset( + const DatesetHandle* full_data, + const int32_t* used_row_indices, + const int32_t num_used_row_indices, + const char* parameters, + DatesetHandle* out); + /*! * \brief free space for dataset * \return 0 when succeed, -1 when failure happens diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index d3bc8b894..356c1adfb 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -330,6 +330,8 @@ public: } } + Dataset* Subset(const data_size_t* used_indices, data_size_t num_used_indices, bool is_enable_sparse) const; + void FinishLoad(); bool SetFloatField(const char* field_name, const float* field_data, data_size_t num_element); @@ -396,8 +398,6 @@ private: int num_class_; /*! \brief Store some label level data*/ Metadata metadata_; - /*! \brief True if dataset is loaded from binary file */ - bool is_loading_from_binfile_; /*! \brief index of label column */ int label_idx_ = 0; /*! \brief store feature names */ diff --git a/include/LightGBM/feature.h b/include/LightGBM/feature.h index c3c8b8b28..1794e3d6f 100644 --- a/include/LightGBM/feature.h +++ b/include/LightGBM/feature.h @@ -80,6 +80,9 @@ public: unsigned int bin = bin_mapper_->ValueToBin(value); bin_data_->Push(tid, line_idx, bin); } + inline void PushBin(int tid, data_size_t line_idx, unsigned int bin) { + bin_data_->Push(tid, line_idx, bin); + } inline void FinishLoad() { bin_data_->FinishLoad(); } /*! 
\brief Index of this feature */ inline int feature_index() const { return feature_index_; } diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index d09ac1add..e844cadba 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -438,7 +438,7 @@ class Dataset(object): if params["has_header"].lower() == "true" or params["header"].lower() == "true": self.data_has_header = True self.handle = ctypes.c_void_p() - _safe_call(_LIB.LGBM_CreateDatasetFromFile( + _safe_call(_LIB.LGBM_DatasetCreateFromFile( c_str(data), c_str(params_str), ref_dataset, @@ -521,7 +521,7 @@ class Dataset(object): data = np.array(mat.reshape(mat.size), dtype=np.float32) ptr_data, type_ptr_data = c_float_array(data) - _safe_call(_LIB.LGBM_CreateDatasetFromMat( + _safe_call(_LIB.LGBM_DatasetCreateFromMat( ptr_data, type_ptr_data, mat.shape[0], @@ -542,7 +542,7 @@ class Dataset(object): ptr_indptr, type_ptr_indptr = c_int_array(csr.indptr) ptr_data, type_ptr_data = c_float_array(csr.data) - _safe_call(_LIB.LGBM_CreateDatasetFromCSR( + _safe_call(_LIB.LGBM_DatasetCreateFromCSR( ptr_indptr, type_ptr_indptr, csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), diff --git a/src/c_api.cpp b/src/c_api.cpp index 589e0262d..4e8955c79 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -204,7 +204,7 @@ DllExport const char* LGBM_GetLastError() { return LastErrorMsg(); } -DllExport int LGBM_CreateDatasetFromFile(const char* filename, +DllExport int LGBM_DatasetCreateFromFile(const char* filename, const char* parameters, const DatesetHandle* reference, DatesetHandle* out) { @@ -223,7 +223,7 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename, API_END(); } -DllExport int LGBM_CreateDatasetFromMat(const void* data, +DllExport int LGBM_DatasetCreateFromMat(const void* data, int data_type, int32_t nrow, int32_t ncol, @@ -272,7 +272,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, API_END(); } -DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, +DllExport int LGBM_DatasetCreateFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, @@ -334,7 +334,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, API_END(); } -DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, +DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, @@ -384,6 +384,26 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, API_END(); } +DllExport int LGBM_DatasetGetSubset( + const DatesetHandle* full_data, + const int32_t* used_row_indices, + const int32_t num_used_row_indices, + const char* parameters, + DatesetHandle* out) { + API_BEGIN(); + auto param = ConfigBase::Str2Map(parameters); + IOConfig io_config; + io_config.Set(param); + auto full_dataset = reinterpret_cast(*full_data); + auto ret = std::unique_ptr( + full_dataset->Subset(used_row_indices, + num_used_row_indices, + io_config.is_enable_sparse)); + ret->FinishLoad(); + *out = ret.release(); + API_END(); +} + DllExport int LGBM_DatasetFree(DatesetHandle handle) { API_BEGIN(); delete reinterpret_cast(handle); diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp index a8814dc5b..0eb6f99b0 100644 --- a/src/io/dataset.cpp +++ b/src/io/dataset.cpp @@ -19,13 +19,11 @@ const char* Dataset::binary_file_token = "______LightGBM_Binary_File_Token______ Dataset::Dataset() { num_class_ = 1; num_data_ = 0; - is_loading_from_binfile_ = false; } Dataset::Dataset(data_size_t num_data, int num_class) { 
num_class_ = num_class; num_data_ = num_data; - is_loading_from_binfile_ = false; metadata_.Init(num_data_, num_class_, -1, -1); } @@ -59,6 +57,18 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset, bool is_enable_spars feature_names_ = dataset->feature_names_; } +Dataset* Dataset::Subset(const data_size_t* used_indices, data_size_t num_used_indices, bool is_enable_sparse) const { + auto ret = std::unique_ptr(new Dataset(num_used_indices, num_class_)); + ret->CopyFeatureMapperFrom(this, is_enable_sparse); +#pragma omp parallel for schedule(guided) + for (int fidx = 0; fidx < num_features_; ++fidx) { + auto iterator = features_[fidx]->bin_data()->GetIterator(0); + for (data_size_t i = 0; i < num_used_indices; ++i) { + ret->features_[fidx]->PushBin(0, i, iterator->Get(used_indices[i])); + } + } +} + bool Dataset::SetFloatField(const char* field_name, const float* field_data, data_size_t num_element) { std::string name(field_name); name = Common::Trim(name); @@ -118,15 +128,27 @@ bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int** } void Dataset::SaveBinaryFile(const char* bin_filename) { + bool is_file_existed = false; + FILE* file; +#ifdef _MSC_VER + fopen_s(&file, bin_filename, "rb"); +#else + file = fopen(bin_filename, "rb"); +#endif - if (!is_loading_from_binfile_) { + if (file != NULL) { + is_file_existed = true; + Log::Warning("File %s existed, cannot save binary to it", bin_filename); + fclose(file); + } + + if (!is_file_existed) { std::string bin_filename_str(data_filename_); // if not pass a filename, just append ".bin" of original file if (bin_filename == nullptr || bin_filename[0] == '\0') { bin_filename_str.append(".bin"); bin_filename = bin_filename_str.c_str(); } - FILE* file; #ifdef _MSC_VER fopen_s(&file, bin_filename, "wb"); #else diff --git a/src/io/dataset_loader.cpp b/src/io/dataset_loader.cpp index c073d7620..a4ca396ce 100644 --- a/src/io/dataset_loader.cpp +++ b/src/io/dataset_loader.cpp @@ -142,18 +142,18 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac Please use an additional query file or pre-partition the data"); } } - auto parser = std::unique_ptr(Parser::CreateParser(filename, io_config_.has_header, 0, label_idx_)); - if (parser == nullptr) { - Log::Fatal("Could not recognize data format of %s", filename); - } + auto dataset = std::unique_ptr(new Dataset()); data_size_t num_global_data = 0; std::vector used_data_indices; - auto dataset = std::unique_ptr(new Dataset()); - dataset->data_filename_ = filename; - dataset->num_class_ = io_config_.num_class; - dataset->metadata_.Init(filename, dataset->num_class_); auto bin_filename = CheckCanLoadFromBin(filename); if (bin_filename.size() == 0) { + auto parser = std::unique_ptr(Parser::CreateParser(filename, io_config_.has_header, 0, label_idx_)); + if (parser == nullptr) { + Log::Fatal("Could not recognize data format of %s", filename); + } + dataset->data_filename_ = filename; + dataset->num_class_ = io_config_.num_class; + dataset->metadata_.Init(filename, dataset->num_class_); if (!io_config_.use_two_round_loading) { // read data to memory auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines,&num_global_data, &used_data_indices); @@ -197,18 +197,18 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) { - auto parser = std::unique_ptr(Parser::CreateParser(filename, 
io_config_.has_header, 0, label_idx_)); - if (parser == nullptr) { - Log::Fatal("Could not recognize data format of %s", filename); - } data_size_t num_global_data = 0; std::vector used_data_indices; auto dataset = std::unique_ptr(new Dataset()); - dataset->data_filename_ = filename; - dataset->num_class_ = io_config_.num_class; - dataset->metadata_.Init(filename, dataset->num_class_); auto bin_filename = CheckCanLoadFromBin(filename); if (bin_filename.size() == 0) { + auto parser = std::unique_ptr(Parser::CreateParser(filename, io_config_.has_header, 0, label_idx_)); + if (parser == nullptr) { + Log::Fatal("Could not recognize data format of %s", filename); + } + dataset->data_filename_ = filename; + dataset->num_class_ = io_config_.num_class; + dataset->metadata_.Init(filename, dataset->num_class_); if (!io_config_.use_two_round_loading) { // read data in memory auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices); @@ -407,7 +407,6 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* bin_filename, int rank, int } dataset->features_.shrink_to_fit(); fclose(file); - dataset->is_loading_from_binfile_ = true; return dataset.release(); } diff --git a/tests/c_api_test/test.py b/tests/c_api_test/test.py index b4690eaca..27feaa350 100644 --- a/tests/c_api_test/test.py +++ b/tests/c_api_test/test.py @@ -16,6 +16,8 @@ def LoadDll(): LIB = LoadDll() +LIB.LGBM_GetLastError.restype = ctypes.c_char_p + dtype_float32 = 0 dtype_float64 = 1 dtype_int32 = 2 @@ -33,9 +35,10 @@ def test_load_from_file(filename, reference): if reference != None: ref = ctypes.byref(reference) handle = ctypes.c_void_p() - LIB.LGBM_CreateDatasetFromFile(c_str(filename), + LIB.LGBM_DatasetCreateFromFile(c_str(filename), c_str('max_bin=15'), ref, ctypes.byref(handle) ) + print(LIB.LGBM_GetLastError()) num_data = ctypes.c_long() LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data) ) num_feature = ctypes.c_long() @@ -46,15 +49,6 @@ def test_load_from_file(filename, reference): def test_save_to_binary(handle, filename): LIB.LGBM_DatasetSaveBinary(handle, c_str(filename)) -def test_load_from_binary(filename): - handle = ctypes.c_void_p() - LIB.LGBM_CreateDatasetFromBinaryFile(c_str(filename), ctypes.byref(handle) ) - num_data = ctypes.c_long() - LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data) ) - num_feature = ctypes.c_long() - LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature) ) - print ('#data:%d #feature:%d' %(num_data.value, num_feature.value) ) - return handle def test_load_from_csr(filename, reference): data = [] @@ -72,7 +66,7 @@ def test_load_from_csr(filename, reference): if reference != None: ref = ctypes.byref(reference) - LIB.LGBM_CreateDatasetFromCSR(c_array(ctypes.c_int, csr.indptr), + LIB.LGBM_DatasetCreateFromCSR(c_array(ctypes.c_int, csr.indptr), dtype_int32, c_array(ctypes.c_int, csr.indices), csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)), @@ -107,7 +101,7 @@ def test_load_from_csc(filename, reference): if reference != None: ref = ctypes.byref(reference) - LIB.LGBM_CreateDatasetFromCSC(c_array(ctypes.c_int, csr.indptr), + LIB.LGBM_DatasetCreateFromCSC(c_array(ctypes.c_int, csr.indptr), dtype_int32, c_array(ctypes.c_int, csr.indices), csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)), @@ -142,7 +136,7 @@ def test_load_from_mat(filename, reference): if reference != None: ref = ctypes.byref(reference) - LIB.LGBM_CreateDatasetFromMat(data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)), + 
LIB.LGBM_DatasetCreateFromMat(data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
        dtype_float64,
        mat.shape[0],
        mat.shape[1],
        1,
        c_str('max_bin=15'),
        ref,
        ctypes.byref(handle) )
@@ -170,7 +164,7 @@ def test_dataset():
     test_free_dataset(test)
     test_save_to_binary(train, 'train.binary.bin')
     test_free_dataset(train)
-    train = test_load_from_binary('train.binary.bin')
+    train = test_load_from_file('train.binary.bin', None)
     test_free_dataset(train)
 
 def test_booster():
     train = test_load_from_mat('../../examples/binary_classification/binary.train', None)

From 83a141744035a4892e3e292ea9af488fc549ceea Mon Sep 17 00:00:00 2001
From: Guolin Ke 
Date: Sat, 26 Nov 2016 15:18:52 +0800
Subject: [PATCH 30/60] some bugs fixed

---
 include/LightGBM/c_api.h         |  2 +-
 include/LightGBM/dataset.h       |  8 +++-
 python-package/lightgbm/basic.py | 26 ++++++++++---
 src/c_api.cpp                    |  2 +-
 src/io/dataset.cpp               |  3 ++
 src/io/metadata.cpp              | 63 ++++++++++++++++++++++++++++++++
 6 files changed, 96 insertions(+), 8 deletions(-)

diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index a1a3269ea..5bc5cd008 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -149,7 +149,7 @@ DllExport int LGBM_DatasetGetSubset(
 const DatesetHandle* full_data,
 const int32_t* used_row_indices,
-  const int32_t num_used_row_indices,
+  int32_t num_used_row_indices,
 const char* parameters,
 DatesetHandle* out);

diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h
index 356c1adfb..1bc3b9379 100644
--- a/include/LightGBM/dataset.h
+++ b/include/LightGBM/dataset.h
@@ -47,6 +47,13 @@ public:
 */
 void Init(const char* data_filename, const int num_class);
 /*!
+  * \brief init as subset
+  * \param metadata metadata of the full dataset
+  * \param used_indices
+  * \param num_used_indices
+  */
+  void Init(const Metadata& metadata, const data_size_t* used_indices, data_size_t num_used_indices);
+  /*!
* \brief Initial with binary memory
* \param memory Pointer to memory
*/
@@ -77,7 +84,6 @@ public:
 void CheckOrPartition(data_size_t num_all_data,
   const std::vector& used_data_indices);
 
-
 void SetLabel(const float* label, data_size_t len);
 
 void SetWeights(const float* weights, data_size_t len);
diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index e844cadba..edd682b2b 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -410,6 +410,10 @@ class Dataset(object):
         params: dict, optional
             other parameters
         """
+        self.__label = None
+        self.__weight = None
+        self.__init_score = None
+        self.__group = None
         if data is None:
             self.handle = None
             return
@@ -453,10 +457,6 @@ class Dataset(object):
                 self.__init_from_csr(csr, params_str, ref_dataset)
             except:
                 raise TypeError('can not initialize Dataset from {}'.format(type(data).__name__))
-        self.__label = None
-        self.__weight = None
-        self.__init_score = None
-        self.__group = None
         if label is not None:
             self.set_label(label)
         if self.get_label() is None:
@@ -505,6 +505,22 @@ class Dataset(object):
         return Dataset(data, label=label, max_bin=self.max_bin, reference=self,
                        weight=weight, group_id=group_id, predictor=self.predictor,
                        silent=silent, params=params)
 
+    def subset(self, used_indices, params=None):
+        used_indices = list_to_1d_numpy(used_indices, np.int32)
+        ret = Dataset(None)
+        ret.handle = ctypes.c_void_p()
+        params_str = dict_to_str(params)
+        _safe_call(_LIB.LGBM_DatasetGetSubset(
+            ctypes.byref(self.handle),
+            used_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
+            used_indices.shape[0],
+            c_str(params_str),
+            ctypes.byref(ret.handle)))
+        ret.max_bin = self.max_bin
+        ret.predictor = self.predictor
+        if ret.get_label() is None:
+            raise ValueError("label should not be None")
+        return ret
 
     def __init_from_np2d(self, mat, params_str, ref_dataset):
         """
@@ -1102,7 +1118,7 @@ class Booster(object):
     def __inner_eval(self, data_name, data_idx, feval=None):
         """
-        Evaulate traning or validation data
+        Evaulate training or validation data
         """
         if data_idx >= self.__num_dataset:
             raise ValueError("data_idx should be smaller than number of dataset")
diff --git a/src/c_api.cpp b/src/c_api.cpp
index 4e8955c79..fc672bc13 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -387,7 +387,7 @@ DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
 DllExport int LGBM_DatasetGetSubset(
   const DatesetHandle* full_data,
   const int32_t* used_row_indices,
-  const int32_t num_used_row_indices,
+  int32_t num_used_row_indices,
   const char* parameters,
   DatesetHandle* out) {
   API_BEGIN();
diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp
index 0eb6f99b0..df289faf1 100644
--- a/src/io/dataset.cpp
+++ b/src/io/dataset.cpp
@@ -55,6 +55,7 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset, bool is_enable_spars
   num_features_ = static_cast(features_.size());
   num_total_features_ = dataset->num_total_features_;
   feature_names_ = dataset->feature_names_;
+  label_idx_ = dataset->label_idx_;
 }
 
 Dataset* Dataset::Subset(const data_size_t* used_indices, data_size_t num_used_indices, bool is_enable_sparse) const {
@@ -67,6 +68,8 @@
       ret->features_[fidx]->PushBin(0, i, iterator->Get(used_indices[i]));
     }
   }
+  ret->metadata_.Init(metadata_, used_indices, num_used_indices);
+  return ret.release();
 }
 
 bool Dataset::SetFloatField(const char* field_name, const float* field_data, data_size_t num_element) {
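(Usage sketch for the subset API above; the file name and indices are
illustrative, and Dataset/num_data follow this package's current draft:)

    full = Dataset('binary.train', max_bin=255)
    sub = full.subset(list(range(100)))  # first 100 rows, sharing the full set's bin mappers
    print(sub.num_data())

diff --git a/src/io/metadata.cpp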
b/src/io/metadata.cpp index 14f9dc946..1906c5bb2 100644 --- a/src/io/metadata.cpp +++ b/src/io/metadata.cpp @@ -50,6 +50,69 @@ void Metadata::Init(data_size_t num_data, int num_class, int weight_idx, int que } } +void Metadata::Init(const Metadata& fullset, const data_size_t* used_indices, data_size_t num_used_indices) { + num_data_ = num_used_indices; + num_class_ = fullset.num_class_; + + label_ = std::vector(num_used_indices); + for (data_size_t i = 0; i < num_used_indices; i++) { + label_[i] = fullset.label_[used_indices[i]]; + } + + if (fullset.weights_.size() > 0) { + weights_ = std::vector(num_used_indices); + num_weights_ = num_used_indices; + for (data_size_t i = 0; i < num_used_indices; i++) { + weights_[i] = fullset.weights_[used_indices[i]]; + } + } else { + num_weights_ = 0; + } + + if (fullset.init_score_.size() > 0) { + init_score_ = std::vector(num_used_indices); + num_init_score_ = num_used_indices; + for (data_size_t i = 0; i < num_used_indices; i++) { + init_score_[i] = fullset.init_score_[used_indices[i]]; + } + } else { + num_init_score_ = 0; + } + + if (fullset.query_boundaries_.size() > 0) { + std::vector used_query; + data_size_t data_idx = 0; + for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_indices; ++qid) { + data_size_t start = fullset.query_boundaries_[qid]; + data_size_t end = fullset.query_boundaries_[qid + 1]; + data_size_t len = end - start; + if (used_indices[data_idx] > start) { + continue; + } else if (used_indices[data_idx] == start) { + if (num_used_indices >= data_idx + len && used_indices[data_idx + len - 1] == end - 1) { + used_query.push_back(qid); + data_idx += len; + } else { + Log::Fatal("Data partition error, data didn't match queries"); + } + } else { + Log::Fatal("Data partition error, data didn't match queries"); + } + } + query_boundaries_ = std::vector(used_query.size() + 1); + num_queries_ = static_cast(used_query.size()); + query_boundaries_[0] = 0; + for (data_size_t i = 0; i < num_queries_; ++i) { + data_size_t qid = used_query[i]; + data_size_t len = fullset.query_boundaries_[qid + 1] - fullset.query_boundaries_[qid]; + query_boundaries_[i + 1] = query_boundaries_[i] + len; + } + } else { + num_queries_ = 0; + } + +} + void Metadata::PartitionLabel(const std::vector& used_indices) { if (used_indices.size() <= 0) { return; From 67ca60915033e01b8f9fd664d4089b245f810fd6 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sat, 26 Nov 2016 15:24:37 +0800 Subject: [PATCH 31/60] explicit close file in python --- python-package/lightgbm/basic.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index edd682b2b..09a24a301 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -274,7 +274,9 @@ class Predictor(object): predict_type, num_iteration, c_str(tmp_pred_fname))) - lines = open(tmp_pred_fname,"r").readlines() + tmp_file = open(tmp_pred_fname,"r") + lines = tmp_file.readlines() + tmp_file.close() nrow = len(lines) preds = [] for line in lines: @@ -505,7 +507,11 @@ class Dataset(object): return Dataset(data, label=label, max_bin=self.max_bin, reference=self, weight=weight, group_id=group_id, predictor=self.predictor, silent=silent, params=params) + def subset(self, used_indices, params=None): + """ + Get subset of current dataset + """ used_indices = list_to_1d_numpy(used_indices, np.int32) ret = Dataset(None) ret.handle = ctypes.c_void_p() From 83007b1cfe978d8180cd6a3aa4f818ff67ba20a6 Mon Sep 17 
00:00:00 2001 From: Guolin Ke Date: Sun, 27 Nov 2016 10:19:47 +0800 Subject: [PATCH 32/60] update some comments --- include/LightGBM/c_api.h | 84 +++++++++++++++++++------------- python-package/lightgbm/basic.py | 32 ++++++------ src/c_api.cpp | 9 ++-- src/io/config.cpp | 13 +++-- 4 files changed, 76 insertions(+), 62 deletions(-) diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 5bc5cd008..5176d2bd4 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -65,13 +65,13 @@ DllExport int LGBM_DatasetCreateFromFile(const char* filename, /*! * \brief create a dataset from CSR format * \param indptr pointer to row headers -* \param indptr_type +* \param indptr_type type of indptr, can be C_API_DTYPE_INT32 or C_API_DTYPE_INT64 * \param indices findex * \param data fvalue -* \param data_type +* \param data_type type of data pointer, can be C_API_DTYPE_FLOAT32 or C_API_DTYPE_FLOAT64 * \param nindptr number of rows in the matrix + 1 * \param nelem number of nonzero elements in the matrix -* \param num_col number of columns; when it's set to 0, then guess from data +* \param num_col number of columns * \param parameters additional parameters * \param reference used to align bin mapper with other dataset, nullptr means don't used * \param out created dataset @@ -92,13 +92,13 @@ DllExport int LGBM_DatasetCreateFromCSR(const void* indptr, /*! * \brief create a dataset from CSC format * \param col_ptr pointer to col headers -* \param col_ptr_type +* \param col_ptr_type type of col_ptr, can be C_API_DTYPE_INT32 or C_API_DTYPE_INT64 * \param indices findex * \param data fvalue -* \param data_type -* \param ncol_ptr number of rows in the matrix + 1 +* \param data_type type of data pointer, can be C_API_DTYPE_FLOAT32 or C_API_DTYPE_FLOAT64 +* \param ncol_ptr number of cols in the matrix + 1 * \param nelem number of nonzero elements in the matrix -* \param num_row number of rows; when it's set to 0, then guess from data +* \param num_row number of rows * \param parameters additional parameters * \param reference used to align bin mapper with other dataset, nullptr means don't used * \param out created dataset @@ -119,7 +119,7 @@ DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr, /*! * \brief create dataset from dense matrix * \param data pointer to the data space -* \param data_type 0 +* \param data_type type of data pointer, can be C_API_DTYPE_FLOAT32 or C_API_DTYPE_FLOAT64 * \param nrow number of rows * \param ncol number columns * \param is_row_major 1 for row major, 0 for column major @@ -139,7 +139,7 @@ DllExport int LGBM_DatasetCreateFromMat(const void* data, /*! * \brief Create subset of a data -* \param full_data the full dataset +* \param handle handle of full dataset * \param used_row_indices Indices used in subset * \param num_used_row_indices len of used_row_indices * \param parameters additional parameters @@ -147,7 +147,7 @@ DllExport int LGBM_DatasetCreateFromMat(const void* data, * \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_DatasetGetSubset( - const DatesetHandle* full_data, + const DatesetHandle* handle, const int32_t* used_row_indices, int32_t num_used_row_indices, const char* parameters, @@ -170,11 +170,13 @@ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle, /*! 
* \brief set vector to a content in info +* Note: group and group only work for C_API_DTYPE_INT32 +* label and weight only work for C_API_DTYPE_FLOAT32 * \param handle a instance of dataset -* \param field_name field name, can be label, weight, group +* \param field_name field name, can be label, weight, group, group_id * \param field_data pointer to vector * \param num_element number of element in field_data -* \param type float32 or int32 +* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32 * \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_DatasetSetField(DatesetHandle handle, @@ -189,7 +191,7 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle, * \param field_name field name * \param out_len used to set result length * \param out_ptr pointer to the result -* \param out_type float32 or int32 +* \param out_type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32 * \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_DatasetGetField(DatesetHandle handle, @@ -232,13 +234,13 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, /*! * \brief load an existing boosting from model file * \param filename filename of model -* \param out_num_total_model number of total models +* \param out_num_iterations number of iterations of this booster * \param out handle of created Booster * \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterCreateFromModelfile( const char* filename, - int64_t* out_num_total_model, + int64_t* out_num_iterations, BoosterHandle* out); @@ -287,7 +289,8 @@ DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* param /*! * \brief Get number of class * \param handle handle -* \return number of class +* \param out_len number of class +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len); @@ -322,28 +325,34 @@ DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle); /*! * \brief Get iteration of current boosting rounds -* \return iteration of boosting rounds +* \param out_iteration iteration of boosting rounds +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration); /*! * \brief Get number of eval -* \return total number of eval result +* \param out_len total number of eval results +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len); /*! -* \brief Get number of eval -* \return total number of eval result +* \brief Get Name of eval +* \param out_len total number of eval results +* \param out_strs names of eval result +* \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs); /*! * \brief get evaluation for training data and validation data + Note: 1. you should call LGBM_BoosterGetEvalNames first to get the name of evaluation results + 2. should pre-allocate memory for out_results, you can get its length by LGBM_BoosterGetEvalCounts * \param handle handle * \param data_idx 0:training data, 1: 1st valid data, 2:2nd valid data ... 
* \param out_len len of output result -* \param out_result the string containing evaluation statistics, should allocate memory before call this function +* \param out_result float array containing the results * \return 0 when succeed, -1 when failure happens */ DllExport int LGBM_BoosterGetEval(BoosterHandle handle, @@ -353,7 +362,8 @@ DllExport int LGBM_BoosterGetEval(BoosterHandle handle, /*! * \brief Get prediction for training data and validation data -this can be used to support customized eval function + this can be used to support customized eval function + Note: should pre-allocate memory for out_result, its length is equal to num_class * num_data * \param handle handle * \param data_idx 0:training data, 1: 1st valid data, 2:2nd valid data ... * \param out_len len of output result @@ -371,9 +381,9 @@ DllExport int LGBM_BoosterGetPredict(BoosterHandle handle, * \param data_filename filename of data file * \param data_has_header data file has header or not * \param predict_type -* 0:normal, with transform (if needed) -* 1:raw score -* 2:leaf index +* C_API_PREDICT_NORMAL: normal prediction, with transform (if needed) +* C_API_PREDICT_RAW_SCORE: raw score +* C_API_PREDICT_LEAF_INDEX: leaf index * \param num_iteration number of iteration for prediction, <= 0 means no limit * \param result_filename filename of result file * \return 0 when succeed, -1 when failure happens @@ -387,19 +397,22 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle, /*! * \brief make prediction for a new data set +* Note: should pre-allocate memory for out_result, +* for normal and raw score: its length is equal to num_class * num_data +* for leaf index, its length is equal to num_class * num_data * num_iteration * \param handle handle * \param indptr pointer to row headers -* \param indptr_type +* \param indptr_type type of indptr, can be C_API_DTYPE_INT32 or C_API_DTYPE_INT64 * \param indices findex * \param data fvalue -* \param data_type +* \param data_type type of data pointer, can be C_API_DTYPE_FLOAT32 or C_API_DTYPE_FLOAT64 * \param nindptr number of rows in the matrix + 1 * \param nelem number of nonzero elements in the matrix * \param num_col number of columns; when it's set to 0, then guess from data * \param predict_type -* 0:normal, with transform (if needed) -* 1:raw score -* 2:leaf index +* C_API_PREDICT_NORMAL: normal prediction, with transform (if needed) +* C_API_PREDICT_RAW_SCORE: raw score +* C_API_PREDICT_LEAF_INDEX: leaf index * \param num_iteration number of iteration for prediction, <= 0 means no limit * \param out_len len of output result * \param out_result used to set a pointer to array, should allocate memory before calling this function @@ -421,16 +434,19 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle, /*! * \brief make prediction for a new data set +* Note: should pre-allocate memory for out_result, +* for normal and raw score: its length is equal to num_class * num_data +* for leaf index, its length is equal to num_class * num_data * num_iteration * \param handle handle * \param data pointer to the data space -* \param data_type +* \param data_type type of data pointer, can be C_API_DTYPE_FLOAT32 or C_API_DTYPE_FLOAT64 * \param nrow number of rows * \param ncol number columns * \param is_row_major 1 for row major, 0 for column major * \param predict_type -* 0:normal, with transform (if needed) -* 1:raw score -* 2:leaf index +* C_API_PREDICT_NORMAL: normal prediction, with transform (if needed) +* C_API_PREDICT_RAW_SCORE: raw score +* C_API_PREDICT_LEAF_INDEX: leaf index * \param num_iteration number of iteration for prediction, <= 0 means no limit * \param out_len len of output result * \param out_result used to set a pointer to array, should allocate memory before calling this function diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 09a24a301..917c3ae11 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -186,43 +186,42 @@ def c_int_array(data): class Predictor(object): """"A Predictor of LightGBM. """ - def __init__(self,model_file=None, params=None, booster_handle=None, is_manage_handle=True): + def __init__(self,model_file=None, booster_handle=None, is_manage_handle=True): """Initialize the Predictor. Parameters ---------- model_file : string Path to the model file. - params : dict - Parameters for boosters. """ self.handle = ctypes.c_void_p() self.__is_manage_handle = True if model_file is not None: """Prediction task""" - out_num_total_model = ctypes.c_int64(0) + out_num_iterations = ctypes.c_int64(0) _safe_call(_LIB.LGBM_BoosterCreateFromModelfile( c_str(model_file), - ctypes.byref(out_num_total_model), + ctypes.byref(out_num_iterations), ctypes.byref(self.handle))) - self.__num_total_model = out_num_total_model.value - tmp_out_len = ctypes.c_int64(0) + out_num_class = ctypes.c_int64(0) _safe_call(_LIB.LGBM_BoosterGetNumClasses( self.handle, - ctypes.byref(tmp_out_len))) - self.num_class = tmp_out_len.value + ctypes.byref(out_num_class))) + self.num_class = out_num_class.value + self.__num_total_model = out_num_iterations.value * self.num_class elif booster_handle is not None: self.__is_manage_handle = is_manage_handle self.handle = booster_handle - tmp_out_len = ctypes.c_int64(0) + out_num_class = ctypes.c_int64(0) _safe_call(_LIB.LGBM_BoosterGetNumClasses( self.handle, - ctypes.byref(tmp_out_len))) - self.num_class = tmp_out_len.value + ctypes.byref(out_num_class))) + self.num_class = out_num_class.value + out_num_iterations = ctypes.c_int64(0) _safe_call(_LIB.LGBM_BoosterGetCurrentIteration( self.handle, - ctypes.byref(tmp_out_len))) - self.__num_total_model = self.num_class * tmp_out_len.value + ctypes.byref(out_num_iterations))) + self.__num_total_model = out_num_iterations.value * self.num_class else: raise TypeError('Need Model file to create a booster') @@ -855,12 +854,11 @@ class Booster(object): self.__get_eval_info() elif model_file is not None: """Prediction task""" - out_num_total_model = ctypes.c_int64(0) + out_num_iterations = ctypes.c_int64(0) _safe_call(_LIB.LGBM_BoosterCreateFromModelfile( c_str(model_file), - ctypes.byref(out_num_total_model), + ctypes.byref(out_num_iterations), ctypes.byref(self.handle))) - self.__num_total_model = out_num_total_model.value out_num_class = ctypes.c_int64(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses( self.handle, diff --git a/src/c_api.cpp b/src/c_api.cpp index fc672bc13..680acbcc3 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -385,7 +385,7 @@ DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr, } DllExport int LGBM_DatasetGetSubset( - const DatesetHandle* full_data, + const DatesetHandle* handle, const int32_t* used_row_indices, int32_t num_used_row_indices, const char* parameters, @@ -394,7 +394,7 @@ DllExport int LGBM_DatasetGetSubset( auto param = ConfigBase::Str2Map(parameters); IOConfig io_config; io_config.Set(param); - auto full_dataset = reinterpret_cast<const Dataset*>(*full_data); + auto full_dataset = reinterpret_cast<const Dataset*>(*handle); auto ret = std::unique_ptr<Dataset>( full_dataset->Subset(used_row_indices, num_used_row_indices, @@ -486,11 +486,12 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, DllExport int LGBM_BoosterCreateFromModelfile( const char* filename, - int64_t* num_total_model, + int64_t* out_num_iterations, BoosterHandle* out) { API_BEGIN(); auto ret = std::unique_ptr<Booster>(new Booster(filename)); - *num_total_model = static_cast<int64_t>(ret->GetBoosting()->NumberOfTotalModel()); + *out_num_iterations = static_cast<int64_t>(ret->GetBoosting()->NumberOfTotalModel() + / ret->GetBoosting()->NumberOfClasses()); *out = ret.release(); API_END(); } diff --git a/src/io/config.cpp b/src/io/config.cpp index 114a4b004..24ba003d0 100644 --- a/src/io/config.cpp +++ b/src/io/config.cpp @@ -5,7 +5,7 @@ #include #include -#include <unordered_map> +#include <unordered_set> #include namespace LightGBM { @@ -95,16 +95,15 @@ void OverallConfig::GetMetricType(const std::unordered_map<std::string, std::string>& params) { std::vector<std::string> metrics = Common::Split(value.c_str(), ','); // remove duplicates - std::unordered_map<std::string, int> metric_maps; + std::unordered_set<std::string> metric_sets; for (auto& metric : metrics) { std::transform(metric.begin(), metric.end(), metric.begin(), Common::tolower); - if (metric_maps.count(metric) <= 0) { - metric_maps[metric] = 1; + if (metric_sets.count(metric) <= 0) { + metric_sets.insert(metric); } } - for (auto& pair : metric_maps) { - std::string sub_metric_str = pair.first; - metric_types.push_back(sub_metric_str); + for (auto& metric : metric_sets) { + metric_types.push_back(metric); } metric_types.shrink_to_fit(); } From 19512d828d79aab32a970346fea6ae85b78c0653 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sun, 27 Nov 2016 15:20:08 +0800 Subject: [PATCH 33/60] remove set_group_id. fixed bug in set num_pred_iterations.
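A note on the diff below: this patch replaces per-instance query ids (the removed set_group_id) with per-query sizes passed to set_group. As an illustrative sketch only (not part of the patch; the helper name is invented), query ids that already arrive grouped can be converted into the sizes set_group expects:

    import numpy as np

    def group_ids_to_sizes(group_id):
        # e.g. [0, 0, 0, 1, 1, 2] -> [3, 2, 1]; assumes rows are already
        # grouped by query id, as the old set_group_id required
        group_id = np.asarray(group_id)
        boundaries = np.flatnonzero(np.diff(group_id)) + 1
        splits = np.concatenate(([0], boundaries, [len(group_id)]))
        return np.diff(splits).astype(np.int32)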
--- python-package/lightgbm/basic.py | 98 ++++++++++++++------------- src/boosting/gbdt.h | 2 + 2 files changed, 45 insertions(+), 55 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 917c3ae11..c42164112 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -123,12 +123,17 @@ def c_array(ctype, values): """Convert a python array to c array.""" return (ctype * len(values))(*values) -def dict_to_str(data): +def param_dict_to_str(data): if data is None or len(data) == 0: return "" pairs = [] - for key in data: - pairs.append(str(key)+'='+str(data[key])) + for key, val in data.items(): + if isinstance(val, list): + pairs.append(str(key)+'='+','.join(val)) + elif isinstance(val, (int, float, str, bool)): + pairs.append(str(key)+'='+str(val)) + else: + raise TypeError('unknow type of parameter:%s , got:%s' %(key, type(val).__name__)) return ' '.join(pairs) """macro definition of data type in c_api of LightGBM""" C_API_DTYPE_FLOAT32 =0 @@ -145,7 +150,6 @@ C_API_PREDICT_LEAF_INDEX =2 FIELD_TYPE_MAPPER = {"label":C_API_DTYPE_FLOAT32, "weight":C_API_DTYPE_FLOAT32, "init_score":C_API_DTYPE_FLOAT32, -"group_id":C_API_DTYPE_INT32, "group":C_API_DTYPE_INT32, } @@ -208,7 +212,7 @@ class Predictor(object): self.handle, ctypes.byref(out_num_class))) self.num_class = out_num_class.value - self.__num_total_model = out_num_iterations.value * self.num_class + self.__num_total_iteration = out_num_iterations.value elif booster_handle is not None: self.__is_manage_handle = is_manage_handle self.handle = booster_handle @@ -221,7 +225,7 @@ class Predictor(object): _safe_call(_LIB.LGBM_BoosterGetCurrentIteration( self.handle, ctypes.byref(out_num_iterations))) - self.__num_total_model = out_num_iterations.value * self.num_class + self.__num_total_iteration = out_num_iterations.value else: raise TypeError('Need Model file to create a booster') @@ -261,9 +265,9 @@ class Predictor(object): predict_type = C_API_PREDICT_RAW_SCORE if pred_leaf: predict_type = C_API_PREDICT_LEAF_INDEX - int_data_has_header = 0 - if data_has_header: - int_data_has_header = 1 + int_data_has_header = 1 if data_has_header else 0 + if num_iteration > self.__num_total_iteration: + num_iteration = self.__num_total_iteration if is_str(data): tmp_pred_fname = tempfile.NamedTemporaryFile(prefix="lightgbm_tmp_pred_").name _safe_call(_LIB.LGBM_BoosterPredictForFile( @@ -303,6 +307,15 @@ class Predictor(object): raise ValueError('len of predict result(%d) cannot be divided by nrow(%d)' %(preds.size, nrow) ) return preds + def __get_num_preds(self, num_iteration, nrow, predict_type): + n_preds = self.num_class * nrow + if predict_type == C_API_PREDICT_LEAF_INDEX: + if num_iteration > 0: + n_preds *= min(num_iteration, self.__num_total_iteration) + else: + n_preds *= self.__num_total_iteration + return n_preds + def __pred_for_np2d(self, mat, num_iteration, predict_type): """ Predict for a 2-D numpy matrix.
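The sizing helper above determines how large the flat prediction buffer must be: normal and raw-score predictions hold num_class values per row, while leaf-index predictions additionally scale with the number of iterations actually used. A standalone restatement of that rule (illustrative only; the function name is invented):

    def expected_num_preds(num_class, nrow, total_iteration, num_iteration=0, pred_leaf=False):
        # mirrors Predictor.__get_num_preds: leaf-index output has one
        # entry per tree used, i.e. per class per iteration
        n_preds = num_class * nrow
        if pred_leaf:
            n_preds *= min(num_iteration, total_iteration) if num_iteration > 0 else total_iteration
        return n_preds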
@@ -316,13 +329,7 @@ class Predictor(object): """change non-float data to float data, need to copy""" data = np.array(mat.reshape(mat.size), dtype=np.float32) ptr_data, type_ptr_data = c_float_array(data) - n_preds = self.num_class * mat.shape[0] - if predict_type == C_API_PREDICT_LEAF_INDEX: - if num_iteration > 0: - n_preds *= num_iteration - else: - used_iteration = self.__num_total_model / self.num_class - n_preds *= used_iteration + n_preds = self.__get_num_preds(num_iteration, mat.shape[0], predict_type) preds = np.zeros(n_preds, dtype=np.float32) out_num_preds = ctypes.c_int64(0) _safe_call(_LIB.LGBM_BoosterPredictForMat( @@ -346,13 +353,7 @@ class Predictor(object): Predict for a csr data """ nrow = len(csr.indptr) - 1 - n_preds = self.num_class * nrow - if predict_type == C_API_PREDICT_LEAF_INDEX: - if num_iteration > 0: - n_preds *= num_iteration - else: - used_iteration = self.__num_total_model / self.num_class - n_preds *= used_iteration + n_preds = self.__get_num_preds(num_iteration, nrow, predict_type) preds = np.zeros(n_preds, dtype=np.float32) out_num_preds = ctypes.c_int64(0) @@ -386,7 +387,7 @@ class Dataset(object): """ def __init__(self, data, label=None, max_bin=255, reference=None, - weight=None, group_id=None, predictor=None, + weight=None, group=None, predictor=None, silent=False, params=None): """ Dataset used in LightGBM. @@ -404,8 +405,8 @@ class Dataset(object): If this dataset validation, need to use training data as reference weight : list or numpy 1-D array , optional Weight for each instance. - group_id : list or numpy 1-D array , optional - group/query id for each instance. Note: if having group/query id, data should group by this id + group : list or numpy 1-D array , optional + group/query size for dataset silent : boolean, optional Whether print messages during construction params: dict, optional @@ -420,8 +421,7 @@ class Dataset(object): return self.data_has_header = False """process for args""" - if params is None: - params = {} + params = {} if params is None else params self.max_bin = max_bin self.predictor = predictor params["max_bin"] = max_bin @@ -429,7 +429,7 @@ class Dataset(object): params["verbose"] = 0 elif "verbose" not in params: params["verbose"] = 1 - params_str = dict_to_str(params) + params_str = param_dict_to_str(params) """process for reference dataset""" ref_dataset = None if isinstance(reference, Dataset): @@ -464,8 +464,8 @@ class Dataset(object): raise ValueError("label should not be None") if weight is not None: self.set_weight(weight) - if group_id is not None: - self.set_group_id(group_id) + if group is not None: + self.set_group(group) # load init score if self.predictor is not None and isinstance(self.predictor, Predictor): init_score = self.predictor.predict(data, @@ -482,7 +482,7 @@ class Dataset(object): init_score = new_init_score self.set_init_score(init_score) - def create_valid(self, data, label=None, weight=None, group_id=None, + def create_valid(self, data, label=None, weight=None, group=None, silent=False, params=None): """ Create validation data align with current dataset @@ -496,15 +496,15 @@ class Dataset(object): Label of the training data. weight : list or numpy 1-D array , optional Weight for each instance. - group_id : list or numpy 1-D array , optional - group/query id for each instance. 
Note: if having group/query id, data should group by this id + group : list or numpy 1-D array , optional + group/query size for dataset silent : boolean, optional Whether print messages during construction params: dict, optional other parameters """ return Dataset(data, label=label, max_bin=self.max_bin, reference=self, - weight=weight, group_id=group_id, predictor=self.predictor, + weight=weight, group=group, predictor=self.predictor, silent=silent, params=params) def subset(self, used_indices, params=None): @@ -514,7 +514,7 @@ class Dataset(object): used_indices = list_to_1d_numpy(used_indices, np.int32) ret = Dataset(None) ret.handle = ctypes.c_void_p() - params_str = dict_to_str(params) + params_str = param_dict_to_str(params) _safe_call(_LIB.LGBM_DatasetGetSubset( ctypes.byref(self.handle), used_indices.data_as(ctypes.POINTER(ctypes.c_int32)), @@ -624,6 +624,7 @@ class Dataset(object): The array ofdata to be set """ if data is None: + """set to None""" _safe_call(_LIB.LGBM_DatasetSetField( self.handle, c_str(field_name), @@ -713,18 +714,6 @@ class Dataset(object): self.__group = group self.set_field('group', group) - def set_group_id(self, group_id): - - """Set group_id of Dataset (used for ranking). - - Parameters - ---------- - group : array like - group_id of Dataset (used for ranking). - """ - if group_id is not None: - group_id = list_to_1d_numpy(group_id, np.int32) - self.set_field('group_id', group_id) def get_label(self): """Get the label of the Dataset. @@ -817,8 +806,7 @@ class Booster(object): self.handle = ctypes.c_void_p() self.__need_reload_eval_info = True self.__is_manage_handle = True - if params is None: - params = {} + params = {} if params is None else params if silent: params["verbose"] = 0 elif "verbose" not in params: @@ -827,7 +815,7 @@ class Booster(object): """Training task""" if not isinstance(train_set, Dataset): raise TypeError('training data should be Dataset instance, met{}'.format(type(train_set).__name__)) - params_str = dict_to_str(params) + params_str = param_dict_to_str(params) """construct booster object""" _safe_call(_LIB.LGBM_BoosterCreate( train_set.handle, @@ -907,7 +895,7 @@ class Booster(object): params["verbose"] = 0 elif "verbose" not in params: params["verbose"] = 1 - params_str = dict_to_str(params) + params_str = param_dict_to_str(params) _safe_call(_LIB.LGBM_BoosterResetParameter( self.handle, c_str(params_str))) @@ -1162,11 +1150,11 @@ class Booster(object): raise ValueError("data_idx should be smaller than number of dataset") if self.__inner_predict_buffer[data_idx] is None: if data_idx == 0: - num_data = self.train_set.num_data() * self.__num_class + n_preds = self.train_set.num_data() * self.__num_class else: - num_data = self.valid_sets[data_idx - 1].num_data() * self.__num_class + n_preds = self.valid_sets[data_idx - 1].num_data() * self.__num_class self.__inner_predict_buffer[data_idx] = \ - np.array([0.0 for _ in range(num_data)], dtype=np.float32, copy=False) + np.array([0.0 for _ in range(n_preds)], dtype=np.float32, copy=False) """avoid to predict many time in one iteration""" if not self.__is_predicted_cur_iter[data_idx]: tmp_out_len = ctypes.c_int64(0) diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index 14d3b3d8d..7a1e2828b 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -181,6 +181,8 @@ public: } else { num_iteration_for_pred_ = static_cast(models_.size()) / num_class_; } + num_iteration_for_pred_ = std::min(num_iteration_for_pred_, + static_cast(models_.size()) / num_class_); } /*! 
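The gbdt.h change above caps the number of iterations used for prediction by what the model actually contains. Restated as a hedged sketch (illustrative only; names are invented), since a multiclass model stores one tree per class per boosting round:

    def num_iteration_for_pred(requested, num_models, num_class):
        # total full boosting rounds stored in the model
        total_rounds = num_models // num_class
        return total_rounds if requested <= 0 else min(requested, total_rounds)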
From 63eddae0b0451c9313882d1bb36d8bcc22078264 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sun, 27 Nov 2016 18:34:03 +0800 Subject: [PATCH 34/60] provide a light weight interface for reset learning rate --- include/LightGBM/boosting.h | 6 ++++ python-package/lightgbm/basic.py | 8 +++-- src/application/application.cpp | 54 ++++++++++++++++++-------------- src/boosting/gbdt.h | 8 +++++ src/c_api.cpp | 7 ++++- 5 files changed, 55 insertions(+), 28 deletions(-) diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h index f725d54e2..f415a3de9 100644 --- a/include/LightGBM/boosting.h +++ b/include/LightGBM/boosting.h @@ -51,6 +51,12 @@ public: */ virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector& training_metrics) = 0; + /*! + * \brief Reset shrinkage_rate data for current boosting + * \param shrinkage_rate Configs for boosting + */ + virtual void ResetShrinkageRate(double shrinkage_rate) = 0; + /*! * \brief Add a validation data * \param valid_data Validation data diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index c42164112..7946ba26a 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -128,9 +128,11 @@ def param_dict_to_str(data): return "" pairs = [] for key, val in data.items(): - if isinstance(val, list): - pairs.append(str(key)+'='+','.join(val)) - elif isinstance(val, (int, float, str, bool)): + if is_str(val): + pairs.append(str(key)+'='+str(val)) + elif isinstance(val, (list, tuple)): + pairs.append(str(key)+'='+','.join(map(str,val))) + elif isinstance(val, (int, float, bool)): pairs.append(str(key)+'='+str(val)) else: raise TypeError('unknow type of parameter:%s , got:%s' %(key, type(val).__name__)) diff --git a/src/application/application.cpp b/src/application/application.cpp index 3501f7299..262ebcfbb 100644 --- a/src/application/application.cpp +++ b/src/application/application.cpp @@ -144,33 +144,39 @@ void Application::LoadData() { } } train_metric_.shrink_to_fit(); - // Add validation data, if it exists - for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) { - // add - auto new_dataset = std::unique_ptr( - dataset_loader.LoadFromFileAlignWithOtherDataset( - config_.io_config.valid_data_filenames[i].c_str(), - train_data_.get()) - ); - valid_datas_.push_back(std::move(new_dataset)); - // need save binary file - if (config_.io_config.is_save_binary_file) { - valid_datas_.back()->SaveBinaryFile(nullptr); - } - // add metric for validation data - valid_metrics_.emplace_back(); - for (auto metric_type : config_.metric_types) { - auto metric = std::unique_ptr(Metric::CreateMetric(metric_type, config_.metric_config)); - if (metric == nullptr) { continue; } - metric->Init(valid_datas_.back()->metadata(), - valid_datas_.back()->num_data()); - valid_metrics_.back().push_back(std::move(metric)); + + if (config_.metric_types.size() > 0) { + // only when have metrics then need to construct validation data + + // Add validation data, if it exists + for (size_t i = 0; i < config_.io_config.valid_data_filenames.size(); ++i) { + // add + auto new_dataset = std::unique_ptr( + dataset_loader.LoadFromFileAlignWithOtherDataset( + config_.io_config.valid_data_filenames[i].c_str(), + train_data_.get()) + ); + valid_datas_.push_back(std::move(new_dataset)); + // need save binary file + if (config_.io_config.is_save_binary_file) { + valid_datas_.back()->SaveBinaryFile(nullptr); + } + + // add 
metric for validation data + valid_metrics_.emplace_back(); + for (auto metric_type : config_.metric_types) { + auto metric = std::unique_ptr(Metric::CreateMetric(metric_type, config_.metric_config)); + if (metric == nullptr) { continue; } + metric->Init(valid_datas_.back()->metadata(), + valid_datas_.back()->num_data()); + valid_metrics_.back().push_back(std::move(metric)); + } + valid_metrics_.back().shrink_to_fit(); } - valid_metrics_.back().shrink_to_fit(); + valid_datas_.shrink_to_fit(); + valid_metrics_.shrink_to_fit(); } - valid_datas_.shrink_to_fit(); - valid_metrics_.shrink_to_fit(); auto end_time = std::chrono::high_resolution_clock::now(); // output used time on each iteration Log::Info("Finished loading data in %f seconds", diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index 7a1e2828b..dbdf30770 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -68,6 +68,14 @@ public: */ void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector& training_metrics) override; + /*! + * \brief Reset shrinkage_rate data for current boosting + * \param shrinkage_rate Configs for boosting + */ + void ResetShrinkageRate(double shrinkage_rate) override { + shrinkage_rate_ = shrinkage_rate; + } + /*! * \brief Adding a validation dataset * \param valid_data Validation dataset diff --git a/src/c_api.cpp b/src/c_api.cpp index 680acbcc3..fc117e753 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -72,7 +72,12 @@ public: Log::Fatal("cannot change boosting_type during training"); } config_.Set(param); - ResetTrainingData(train_data_); + if (param.size() == 1 && (param.count("learning_rate") || param.count("shrinkage_rate"))) { + // only need to set learning rate + boosting_->ResetShrinkageRate(config_.boosting_config.learning_rate); + } else { + ResetTrainingData(train_data_); + } } void AddValidData(const Dataset* valid_data) { From 27624755bfff3c773e85c61e774eb5b18feb392c Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sun, 27 Nov 2016 20:14:53 +0800 Subject: [PATCH 35/60] add main training logic and callbacks --- python-package/lightgbm/basic.py | 33 +++-- python-package/lightgbm/callback.py | 191 +++++++++++++++++++++++++++ python-package/lightgbm/engine.py | 198 ++++++++++++++++++++++++++++ 3 files changed, 409 insertions(+), 13 deletions(-) create mode 100644 python-package/lightgbm/callback.py create mode 100644 python-package/lightgbm/engine.py diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 7946ba26a..b7dad56a4 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -787,7 +787,6 @@ class Dataset(object): ctypes.byref(ret))) return ret.value - class Booster(object): """"A Booster of of LightGBM. 
""" @@ -808,6 +807,7 @@ class Booster(object): self.handle = ctypes.c_void_p() self.__need_reload_eval_info = True self.__is_manage_handle = True + self.__train_data_name = "training" params = {} if params is None else params if silent: params["verbose"] = 0 @@ -861,6 +861,9 @@ class Booster(object): if self.handle is not None and self.__is_manage_handle: _safe_call(_LIB.LGBM_BoosterFree(self.handle)) + def set_train_data_name(self, name): + self.__train_data_name = name + def add_valid(self, data, name): """Add an validation data @@ -882,7 +885,7 @@ class Booster(object): self.__inner_predict_buffer.append(None) self.__is_predicted_cur_iter.append(False) - def reset_parameter(self, params, silent=False): + def reset_parameter(self, params): """Reset parameters for booster Parameters @@ -892,11 +895,8 @@ class Booster(object): silent : boolean, optional Whether print messages during construction """ - self.__need_reload_eval_info = True - if silent: - params["verbose"] = 0 - elif "verbose" not in params: - params["verbose"] = 1 + if 'metric' in params: + self.__need_reload_eval_info = True params_str = param_dict_to_str(params) _safe_call(_LIB.LGBM_BoosterResetParameter( self.handle, @@ -1040,7 +1040,7 @@ class Booster(object): result: str Evaluation result list. """ - return self.__inner_eval("training", 0, feval) + return self.__inner_eval(self.__train_data_name, 0, feval) def eval_valid(self, feval=None): """Evaluate for validation data @@ -1129,7 +1129,7 @@ class Booster(object): if tmp_out_len.value != self.__num_inner_eval: raise ValueError("incorrect number of eval results") for i in range(self.__num_inner_eval): - ret.append((data_name, self.__name_inner_eval[i], result[i])) + ret.append((data_name, self.__name_inner_eval[i], result[i], self.__higher_better_inner_eval[i])) if feval is not None: if data_idx == 0: cur_data = self.train_set @@ -1137,11 +1137,11 @@ class Booster(object): cur_data = self.valid_sets[data_idx - 1] feval_ret = feval(self.__inner_predict(data_idx), cur_data) if isinstance(feval_ret, list): - for eval_name, val in feval_ret: - ret.append((data_name, eval_name, val)) + for eval_name, val, is_higher_better in feval_ret: + ret.append((data_name, eval_name, val, is_higher_better)) else: - eval_name, val = feval_ret - ret.append((data_name, eval_name, val)) + eval_name, val, is_higher_better = feval_ret + ret.append((data_name, eval_name, val, is_higher_better)) return ret def __inner_predict(self, data_idx): @@ -1197,3 +1197,10 @@ class Booster(object): self.__name_inner_eval = [] for i in range(self.__num_inner_eval): self.__name_inner_eval.append(string_buffers[i].value.decode()) + self.__higher_better_inner_eval = [] + higher_better_metric = ['auc', 'ndcg'] + for name in self.__name_inner_eval: + if any(name.startswith(x) for x in higher_better_metric): + self.__higher_better_inner_eval.append(True) + else: + self.__higher_better_inner_eval.append(False) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py new file mode 100644 index 000000000..65cd5e225 --- /dev/null +++ b/python-package/lightgbm/callback.py @@ -0,0 +1,191 @@ +from __future__ import absolute_import + +class EarlyStopException(Exception): + """Exception of early stopping. + Parameters + ---------- + best_iteration : int + The best iteration stopped. 
+ """ + def __init__(self, best_iteration): + super(EarlyStopException, self).__init__() + self.best_iteration = best_iteration + +# Callback environment used by callbacks +CallbackEnv = collections.namedtuple( + "LightGBMCallbackEnv", + ["model", + "cvfolds", + "iteration", + "begin_iteration", + "end_iteration", + "evaluation_result_list"]) + +def _format_eval_result(value, show_stdv=True): + """format metric string""" + if len(value) == 4: + return '%s_%s:%g' % (value[0], value[1], value[2]) + elif len(value) == 5: + if show_stdv: + return '%s_%s:%g+%g' % (value[0], value[1], value[2], value[4]) + else: + return '%s_%s:%g' % (value[0], value[1], value[2]) + else: + raise ValueError("wrong metric value") + + +def print_evaluation(period=1, show_stdv=True): + """Create a callback that print evaluation result. + + Parameters + ---------- + period : int + The period to log the evaluation results + + show_stdv : bool, optional + Whether show stdv if provided + + Returns + ------- + callback : function + A callback that print evaluation every period iterations. + """ + def callback(env): + """internal function""" + if len(env.evaluation_result_list) == 0 or period is False: + return + if (env.iteration % period == 0 or env.iteration + 1 == env.begin_iteration): + result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list]) + print('[%d]\t%s\n' % (env.iteration, result)) + return callback + + +def record_evaluation(eval_result): + """Create a call back that records the evaluation history into eval_result. + + Parameters + ---------- + eval_result : dict + A dictionary to store the evaluation results. + + Returns + ------- + callback : function + The requested callback function. + """ + if not isinstance(eval_result, dict): + raise TypeError('eval_result has to be a dictionary') + eval_result.clear() + + def init(env): + """internal function""" + for data_name, eval_name, _ in env.evaluation_result_list: + if data_name not in eval_result: + eval_result[data_name] = {} + if eval_name not in eval_result[data_name]: + eval_result[data_name][eval_name] = [] + + def callback(env): + """internal function""" + if len(eval_result) == 0: + init(env) + for data_name, eval_name, result in env.evaluation_result_list: + eval_result[data_name][eval_name].append(result) + return callback + + +def reset_learning_rate(learning_rates): + """Reset learning rate after iteration 1 + + NOTE: the initial learning rate will still take in-effect on first iteration. + + Parameters + ---------- + learning_rates: list or function + List of learning rate for each boosting round + or a customized function that calculates learning_rate in terms of + current number of round and the total number of boosting round (e.g. yields + learning rate decay) + - list l: learning_rate = l[current_round] + - function f: learning_rate = f(current_round, total_boost_round) + + Returns + ------- + callback : function + The requested callback function. 
+ """ + def callback(env): + """internal function""" + booster = env.model + i = env.iteration + if isinstance(learning_rates, list): + if len(learning_rates) != env.end_iteration: + raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.") + booster.reset_parameter({'learning_rate':learning_rates[i]}) + else: + booster.reset_parameter({'learning_rate':learning_rates(i, env.end_iteration)}) + callback.before_iteration = True + return callback + + +def early_stop(stopping_rounds, verbose=True): + """Create a callback that activates early stoppping. + Activates early stopping. + Requires at least one validation data and one metric + If there's more than one, will check all of them + + Parameters + ---------- + stopp_rounds : int + The stopping rounds before the trend occur. + + verbose : optional, bool + Whether to print message about early stopping information. + + Returns + ------- + callback : function + The requested callback function. + """ + is_init = False + + def init(env): + """internal function""" + bst = env.model + + if len(env.evaluation_result_list) == 0: + raise ValueError('For early stopping you need at least one set in evals.') + + if verbose: + msg = "Will train until hasn't improved in {} rounds.\n" + print(msg.format(stopping_rounds)) + best_scores = [ float('-inf') for _ in range(len(env.evaluation_result_list))] + best_iter = [ 0 for _ in range(len(env.evaluation_result_list))] + if verbose: + best_msg = [ "" for _ in range(len(env.evaluation_result_list))] + factor_to_bigger_better = [-1.0 for _ in range(len(env.evaluation_result_list))] + for i in range(len(env.evaluation_result_list)): + if evaluation.evaluation_result_list[i][3]: + factor_to_bigger_better[i] = 1.0 + is_init = True + + def callback(env): + """internal function""" + if not is_init: + init(env) + for i in range(len(env.evaluation_result_list)): + score = env.evaluation_result_list[i][2] * factor_to_bigger_better[i] + if score > best_score[i]: + best_score[i] = score + best_iter[i] = env.iteration + if verbose: + best_msg[i] = '[%d]\t%s' % ( env.iteration, + '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list])) + else: + if env.iteration - best_iter[i] >= stopping_rounds: + if env.model is not None: + env.model.set_attr(best_iteration=str(best_iter[i])) + if verbose: + print('early stopping, best message is:\n {} '.format(best_msg[i])) + raise EarlyStopException(best_iter[i]) + return callback diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py new file mode 100644 index 000000000..4592bb722 --- /dev/null +++ b/python-package/lightgbm/engine.py @@ -0,0 +1,198 @@ +"""Training Library containing training routines of LightGBM.""" +from __future__ import absolute_import + +import collections +import numpy as np +from .basic import LightGBMError, Predictor, Dataset, Booster, is_str +from . 
import callback + + + +def _construct_dataset(x, y, reference=None, + params=None, other_fields=None, predictor=None): + if 'max_bin' in params: + max_bin = int(params['max_bin']) + else: + max_bin = 255 + weight = None + group = None + init_score = None + if other_fields is not None: + if not is isinstance(other_fields, dict): + raise TypeError("other filed data should be dict type") + weight = None if 'weight' not in other_fields else other_fields['weight'] + group = None if 'group' not in other_fields else other_fields['group'] + init_score = None if 'init_score' not in other_fields else other_fields['init_score'] + if reference is None: + ret = Dataset(x, y, max_bin=max_bin, + weight=weight, group=group, predictor=predictor, params=params) + else: + ret = reference.create_valid(x, y, weight, group, params=params) + if init_score is not None: + ret.set_init_score(init_score) + return ret + +def train(params, train_data, num_boost_round=100, + valid_datas=None, valid_names=None, + fobj=None, feval=None, init_model=None, + train_fields=None, valid_fields=None, + early_stopping_rounds=None, out_eval_result=None, + verbose_eval=True, learning_rates=None, callbacks=None): + """Train with given parameters. + + Parameters + ---------- + params : dict + params. + train_data : pair, (X, y) + Data to be trained. + num_boost_round: int + Number of boosting iterations. + valid_datas: list of pairs (valid_X, valid_y) + List of data to be evaluated during training + valid_names: list of string + names of valid_datas + fobj : function + Customized objective function. + feval : function + Customized evaluation function. + Note: should return (eval_name, eval_result, is_higher_better) of list of this + init_model : file name of lightgbm model or 'Booster' instance + model used for continued train + train_fields : dict + other data file in training data. e.g. train_fields['weight'] is weight data + support fields: weight, group, init_score + valid_fields : dict + other data file in training data. e.g. valid_fields[0]['weight'] is weight data for first valid data + support fields: weight, group, init_score + early_stopping_rounds: int + Activates early stopping. + Requires at least one validation data and one metric + If there's more than one, will check all of them + Returns the model with (best_iter + early_stopping_rounds) + If early stopping occurs, the model will add 'best_iteration' field + out_eval_result: dict or None + This dictionary used to store all evaluation results of all the items in valid_datas. + Example: with a valid_datas containing [dtest, dtrain] and valid_names containing ['eval', 'train'] and + a paramater containing ('metric':'logloss') + Returns: {'train': {'logloss': ['0.48253', '0.35953', ...]}, + 'eval': {'logloss': ['0.480385', '0.357756', ...]}} + passed with None means no using this function + verbose_eval : bool or int + Requires at least one item in evals. + If `verbose_eval` is True then the evaluation metric on the validation set is + printed at each boosting stage. + If `verbose_eval` is an integer then the evaluation metric on the validation set + is printed at every given `verbose_eval` boosting stage. The last boosting stage + / the boosting stage found by using `early_stopping_rounds` is also printed. + Example: with verbose_eval=4 and at least one item in evals, an evaluation metric + is printed every 4 boosting stages, instead of every boosting stage. 
+ learning_rates: list or function + List of learning rate for each boosting round + or a customized function that calculates learning_rate in terms of + current number of round and the total number of boosting round (e.g. yields + learning rate decay) + - list l: learning_rate = l[current_round] + - function f: learning_rate = f(current_round, total_boost_round) + callbacks : list of callback functions + List of callback functions that are applied at end of each iteration. + + Returns + ------- + booster : a trained booster model + """ + """create predictor first""" + if is_str(init_model): + predictor = Predictor(model_file=init_model) + elif isinstance(init_model, Booster): + predictor = Booster.to_predictor() + elif isinstance(init_model, Predictor): + predictor = init_model + else: + predictor = None + """create dataset""" + train_set = _construct_dataset(train_data[0], train_data[1], None, params, train_fields, predictor, silent) + is_valid_contain_train = False + train_data_name = "training" + valid_sets = [] + name_valid_sets = [] + if valid_datas is not None: + for i in range(len(valid_datas)): + other_fields = None if valid_fields is None else valid_fields[i] + """reduce cost for prediction training data""" + if valid_datas[i] is train_data: + is_valid_contain_train = True + train_data_name = valid_names[i] + continue + valid_set = _construct_dataset( + valid_datas[i][0], + valid_datas[i][1], + train_set, + params, + other_fields, + predictor, + silent) + valid_sets.append(valid_set) + name_valid_sets.append(valid_names[i]) + """process callbacks""" + callbacks = [] if callbacks is None else callbacks + + # Most of legacy advanced options becomes callbacks + if isinstance(verbose_eval, bool) and verbose_eval: + callbacks.append(callback.print_evaluation()) + else: + if isinstance(verbose_eval, int): + callbacks.append(callback.print_evaluation(verbose_eval)) + + if early_stopping_rounds is not None: + callbacks.append(callback.early_stop(early_stopping_rounds, + verbose=bool(verbose_eval))) + if learning_rates is not None: + callbacks.append(callback.reset_learning_rate(learning_rates)) + + if evals_result is not None: + callbacks.append(callback.record_evaluation(evals_result)) + + callbacks_before_iter = [ + cb for cb in callbacks if cb.__dict__.get('before_iteration', False)] + callbacks_after_iter = [ + cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)] + """construct booster""" + booster = Booster(params=params, train_set=train_set, silent=silent) + if is_valid_contain_train: + booster.set_train_data_name(train_data_name) + for i in range(len(valid_sets)): + booster.add_valid(valid_sets[i], name_valid_sets[i]) + """start training""" + for i in range(num_boost_round): + for cb in callbacks_before_iter: + cb(CallbackEnv(model=booster, + cvfolds=None, + iteration=i, + begin_iteration=0, + end_iteration=num_boost_round, + evaluation_result_list=None)) + + booster.update(fobj=fobj) + + evaluation_result_list = [] + # check evaluation result. 
+ if len(valid_sets) != 0: + if is_valid_contain_train: + evaluation_result_list.extend(booster.eval_train(feval)) + evaluation_result_list.extend(booster.eval_valid(feval)) + try: + for cb in callbacks_after_iter: + cb(CallbackEnv(model=booster, + cvfolds=None, + iteration=i, + begin_iteration=0, + end_iteration=num_boost_round, + evaluation_result_list=evaluation_result_list)) + except EarlyStopException: + break + if booster.attr('best_iteration') is not None: + booster.best_iteration = int(booster.attr('best_iteration')) + else: + booster.best_iteration = num_boost_round - 1 + return num_boost_round \ No newline at end of file From 595c10ab26165691502dc27219b454b21c24bd1d Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sun, 27 Nov 2016 20:25:28 +0800 Subject: [PATCH 36/60] some naming fix --- python-package/lightgbm/basic.py | 7 ++++--- python-package/lightgbm/engine.py | 8 ++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index b7dad56a4..4a1670642 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -898,9 +898,10 @@ class Booster(object): if 'metric' in params: self.__need_reload_eval_info = True params_str = param_dict_to_str(params) - _safe_call(_LIB.LGBM_BoosterResetParameter( - self.handle, - c_str(params_str))) + if params_str: + _safe_call(_LIB.LGBM_BoosterResetParameter( + self.handle, + c_str(params_str))) def update(self, train_set=None, fobj=None): """ diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 4592bb722..f08e426e3 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -166,7 +166,7 @@ def train(params, train_data, num_boost_round=100, """start training""" for i in range(num_boost_round): for cb in callbacks_before_iter: - cb(CallbackEnv(model=booster, + cb(callback.CallbackEnv(model=booster, cvfolds=None, iteration=i, begin_iteration=0, @@ -183,16 +183,16 @@ def train(params, train_data, num_boost_round=100, evaluation_result_list.extend(booster.eval_valid(feval)) try: for cb in callbacks_after_iter: - cb(CallbackEnv(model=booster, + cb(callback.CallbackEnv(model=booster, cvfolds=None, iteration=i, begin_iteration=0, end_iteration=num_boost_round, evaluation_result_list=evaluation_result_list)) - except EarlyStopException: + except callback.EarlyStopException: break if booster.attr('best_iteration') is not None: booster.best_iteration = int(booster.attr('best_iteration')) else: booster.best_iteration = num_boost_round - 1 - return num_boost_round \ No newline at end of file + return num_boost_round From 81f459474674bdacd45577d728a8c7a0bfa16cc2 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Tue, 29 Nov 2016 14:36:38 +0800 Subject: [PATCH 37/60] add cv support --- python-package/lightgbm/callback.py | 3 +- python-package/lightgbm/engine.py | 228 ++++++++++++++++++++++++++-- 2 files changed, 216 insertions(+), 15 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 65cd5e225..6b9e5c054 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -148,7 +148,7 @@ def early_stop(stopping_rounds, verbose=True): The requested callback function. 
""" is_init = False - + final_best_iter = 0 def init(env): """internal function""" bst = env.model @@ -183,6 +183,7 @@ def early_stop(stopping_rounds, verbose=True): '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list])) else: if env.iteration - best_iter[i] >= stopping_rounds: + final_best_iter = best_iter[i] if env.model is not None: env.model.set_attr(best_iteration=str(best_iter[i])) if verbose: diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index f08e426e3..d621357c3 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -6,9 +6,7 @@ import numpy as np from .basic import LightGBMError, Predictor, Dataset, Booster, is_str from . import callback - - -def _construct_dataset(x, y, reference=None, +def _construct_dataset(data, reference=None, params=None, other_fields=None, predictor=None): if 'max_bin' in params: max_bin = int(params['max_bin']) @@ -24,10 +22,17 @@ def _construct_dataset(x, y, reference=None, group = None if 'group' not in other_fields else other_fields['group'] init_score = None if 'init_score' not in other_fields else other_fields['init_score'] if reference is None: - ret = Dataset(x, y, max_bin=max_bin, + if is_str(data): + ret = Dataset(data, label=None, max_bin=max_bin, weight=weight, group=group, predictor=predictor, params=params) + else: + ret = Dataset(data[0], data[1], max_bin=max_bin, + weight=weight, group=group, predictor=predictor, params=params) else: - ret = reference.create_valid(x, y, weight, group, params=params) + if is_str(data): + ret = reference.create_valid(data, label=None, weight=weight, group=group, params=params) + else: + ret = reference.create_valid(data[0], data[1], weight, group, params=params) if init_score is not None: ret.set_init_score(init_score) return ret @@ -44,11 +49,11 @@ def train(params, train_data, num_boost_round=100, ---------- params : dict params. - train_data : pair, (X, y) + train_data : pair, (X, y) or filename of data Data to be trained. num_boost_round: int Number of boosting iterations. - valid_datas: list of pairs (valid_X, valid_y) + valid_datas: list of pairs (valid_X, valid_y) or filename of data List of data to be evaluated during training valid_names: list of string names of valid_datas @@ -73,7 +78,7 @@ def train(params, train_data, num_boost_round=100, If early stopping occurs, the model will add 'best_iteration' field out_eval_result: dict or None This dictionary used to store all evaluation results of all the items in valid_datas. 
- Example: with a valid_datas containing [dtest, dtrain] and valid_names containing ['eval', 'train'] and + Example: with a valid_datas containing [valid_set, train_set] and valid_names containing ['eval', 'train'] and a paramater containing ('metric':'logloss') Returns: {'train': {'logloss': ['0.48253', '0.35953', ...]}, 'eval': {'logloss': ['0.480385', '0.357756', ...]}} @@ -111,7 +116,7 @@ def train(params, train_data, num_boost_round=100, else: predictor = None """create dataset""" - train_set = _construct_dataset(train_data[0], train_data[1], None, params, train_fields, predictor, silent) + train_set = _construct_dataset(train_data, None, params, train_fields, predictor) is_valid_contain_train = False train_data_name = "training" valid_sets = [] @@ -125,13 +130,11 @@ def train(params, train_data, num_boost_round=100, train_data_name = valid_names[i] continue valid_set = _construct_dataset( - valid_datas[i][0], - valid_datas[i][1], + valid_datas[i], train_set, params, other_fields, - predictor, - silent) + predictor) valid_sets.append(valid_set) name_valid_sets.append(valid_names[i]) """process callbacks""" @@ -158,7 +161,7 @@ def train(params, train_data, num_boost_round=100, callbacks_after_iter = [ cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)] """construct booster""" - booster = Booster(params=params, train_set=train_set, silent=silent) + booster = Booster(params=params, train_set=train_set) if is_valid_contain_train: booster.set_train_data_name(train_data_name) for i in range(len(valid_sets)): @@ -196,3 +199,200 @@ def train(params, train_data, num_boost_round=100, else: booster.best_iteration = num_boost_round - 1 return num_boost_round + + +class CVBooster(object): + """"Auxiliary datastruct to hold one fold of CV.""" + def __init__(self, train_set, valid_test, param): + """"Initialize the CVBooster""" + self.train_set = train_set + self.valid_test = valid_test + self.booster = Booster(params=params, train_set=train_set) + self.booster.add_valid(valid_test, 'valid') + + def update(self, fobj): + """"Update the boosters for one iteration""" + self.booster.update(fobj=fobj) + + def eval(self, feval): + """"Evaluate the CVBooster for one iteration.""" + return self.booster.eval_valid(feval) + +try: + try: + from sklearn.model_selection import KFold, StratifiedKFold + except ImportError: + from sklearn.cross_validation import KFold, StratifiedKFold + SKLEARN_StratifiedKFold = True +except ImportError: + SKLEARN_StratifiedKFold = False + +def _make_n_folds(full_data, nfold, param, seed, fpreproc=None, stratified=False): + """ + Make an n-fold list of CVBooster from random indices. 
+ """ + np.random.seed(seed) + if stratified: + if SKLEARN_StratifiedKFold: + sfk = StratifiedKFold(n_splits=nfold, shuffle=True, random_state=seed) + idset = [x[1] for x in sfk.split(X=full_data.get_label(), y=full_data.get_label())] + else: + raise LightGBMError('sklearn needs to be installed in order to use stratified cv') + else: + randidx = np.random.permutation(full_data.num_data()) + kstep = int(len(randidx) / nfold) + idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)] + + ret = [] + for k in range(nfold): + train_set = full_data.subset(np.concatenate([idset[i] for i in range(nfold) if k != i])) + valid_set = full_data.subset(idset[k]) + # run preprocessing on the data set if needed + if fpreproc is not None: + train_set, valid_set, tparam = fpreproc(train_set, valid_set, param.copy()) + else: + tparam = param + ret.append(CVBooster(train_set, valid_set, tparam)) + return ret + +def _agg_cv_result(raw_results): + # pylint: disable=invalid-name + """ + Aggregate cross-validation results. + """ + cvmap = {} + metric_type = {} + for one_result in raw_results: + for one_line in one_result: + key = one_line[1] + metric_type[key] = one_line[3] + if key not in cvmap: + cvmap[key] = [] + cvmap[key].append(one_result[2]) + results = [] + for k, v in cvmap.items(): + v = np.array(v) + mean, std = np.mean(v), np.std(v) + results.extend(['cv_agg', k, mean, metric_type[k], std]) + return results + +def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False, + metrics=(), fobj=None, feval=None, train_fields=None, early_stopping_rounds=None, + fpreproc=None, verbose_eval=None, show_stdv=True, seed=0, + callbacks=None): + # pylint: disable = invalid-name + """Cross-validation with given paramaters. + + Parameters + ---------- + params : dict + Booster params. + train_data : pair, (X, y) or filename of data + Data to be trained. + num_boost_round : int + Number of boosting iterations. + nfold : int + Number of folds in CV. + stratified : bool + Perform stratified sampling. + folds : a KFold or StratifiedKFold instance + Sklearn KFolds or StratifiedKFolds. + metrics : string or list of strings + Evaluation metrics to be watched in CV. + fobj : function + Custom objective function. + feval : function + Custom evaluation function. + train_fields : dict + other data file in training data. e.g. train_fields['weight'] is weight data + support fields: weight, group, init_score + early_stopping_rounds: int + Activates early stopping. CV error needs to decrease at least + every round(s) to continue. + Last entry in evaluation history is the one from best iteration. + fpreproc : function + Preprocessing function that takes (dtrain, dtest, param) and returns + transformed versions of those. + verbose_eval : bool, int, or None, default None + Whether to display the progress. If None, progress will be displayed + when np.ndarray is returned. If True, progress will be displayed at + boosting stage. If an integer is given, progress will be displayed + at every given `verbose_eval` boosting stage. + show_stdv : bool, default True + Whether to display the standard deviation in progress. + Results are not affected, and always contains std. + seed : int + Seed used to generate the folds (passed to numpy.random.seed). + callbacks : list of callback functions + List of callback functions that are applied at end of each iteration. 
+ + Returns + ------- + evaluation history : list(string) + """ + + if isinstance(metrics, str): + metrics = [metrics] + + if isinstance(params, list): + params = dict(params) + + if not 'metric' in params: + params['metric'] = [] + + if len(metric) > 0: + params['metric'].extend(metric) + + train_set = _construct_dataset(train_data, None, params, train_fields) + + results = {} + cvfolds = _make_n_folds(train_set, nfold, params, seed, fpreproc, stratified) + + # setup callbacks + callbacks = [] if callbacks is None else callbacks + if early_stopping_rounds is not None: + callbacks.append(callback.early_stop(early_stopping_rounds, + verbose=False)) + if isinstance(verbose_eval, bool) and verbose_eval: + callbacks.append(callback.print_evaluation(show_stdv=show_stdv)) + else: + if isinstance(verbose_eval, int): + callbacks.append(callback.print_evaluation(verbose_eval, show_stdv=show_stdv)) + + callbacks_before_iter = [ + cb for cb in callbacks if cb.__dict__.get('before_iteration', False)] + callbacks_after_iter = [ + cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)] + + for i in range(num_boost_round): + for cb in callbacks_before_iter: + cb(callback.CallbackEnv(model=None, + cvfolds=cvfolds, + iteration=i, + begin_iteration=0, + end_iteration=num_boost_round, + evaluation_result_list=None)) + for fold in cvfolds: + fold.update(fobj) + res = aggcv([f.eval(feval) for f in cvfolds]) + + for _, key, mean, _, std in res: + if key + '-mean' not in results: + results[key + '-mean'] = [] + if key + '-std' not in results: + results[key + '-std'] = [] + results[key + '-mean'].append(mean) + results[key + '-std'].append(std) + try: + for cb in callbacks_after_iter: + cb(callback.CallbackEnv(model=None, + cvfolds=cvfolds, + iteration=i, + begin_iteration=0, + end_iteration=num_boost_round, + evaluation_result_list=res)) + except callback.EarlyStopException as e: + for k in results.keys(): + results[k] = results[k][:(e.final_best_iter + 1)] + break + return results From 6b288215c9f24a8f0e81e884cd1d93dbc678605b Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Tue, 29 Nov 2016 14:55:36 +0800 Subject: [PATCH 38/60] fix keyword error in VS2013 --- include/LightGBM/c_api.h | 5 ++++- include/LightGBM/utils/log.h | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 5176d2bd4..f054af481 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -495,9 +495,12 @@ ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indi std::vector SampleFromOneColumn(const std::vector>& data, const std::vector& indices); - +#if defined(_MSC_VER) // exception handle and error msg +static char* LastErrorMsg() { static __declspec(thread) char err_msg[512] = "Everything is fine"; return err_msg; } +#else static char* LastErrorMsg() { static thread_local char err_msg[512] = "Everything is fine"; return err_msg; } +#endif inline void LGBM_SetLastError(const char* msg) { std::strcpy(LastErrorMsg(), msg); diff --git a/include/LightGBM/utils/log.h b/include/LightGBM/utils/log.h index eb2efc49e..9957a0d1f 100644 --- a/include/LightGBM/utils/log.h +++ b/include/LightGBM/utils/log.h @@ -89,7 +89,11 @@ private: // a trick to use static variable in header file. 
// May be not good, but avoid to use an additional cpp file +#if defined(_MSC_VER) + static LogLevel& GetLevel() { static __declspec(thread) LogLevel level = LogLevel::Info; return level; } +#else static LogLevel& GetLevel() { static thread_local LogLevel level = LogLevel::Info; return level; } +#endif }; From c861be930efa69a0684cb0c955b225a2baefa980 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Tue, 29 Nov 2016 15:28:22 +0800 Subject: [PATCH 39/60] some bugs fix --- python-package/lightgbm/basic.py | 36 ++++++++++++++++++++++++++++- python-package/lightgbm/callback.py | 1 + python-package/lightgbm/engine.py | 32 +++++++++++++++---------- 3 files changed, 56 insertions(+), 13 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 4a1670642..e0af50518 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -519,7 +519,7 @@ class Dataset(object): params_str = param_dict_to_str(params) _safe_call(_LIB.LGBM_DatasetGetSubset( ctypes.byref(self.handle), - used_indices.data_as(ctypes.POINTER(ctypes.c_int32)), + used_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), used_indices.shape[0], c_str(params_str), ctypes.byref(ret.handle))) @@ -808,6 +808,7 @@ class Booster(object): self.__need_reload_eval_info = True self.__is_manage_handle = True self.__train_data_name = "training" + self.__attr = {} params = {} if params is None else params if silent: params["verbose"] = 0 @@ -1205,3 +1206,36 @@ class Booster(object): self.__higher_better_inner_eval.append(True) else: self.__higher_better_inner_eval.append(False) + def attr(self, key): + """Get attribute string from the Booster. + + Parameters + ---------- + key : str + The key to get attribute from. + + Returns + ------- + value : str + The attribute value of the key, returns None if attribute do not exist. + """ + if key in self.__attr: + return self.__attr[key] + else: + return None + + def set_attr(self, **kwargs): + """Set the attribute of the Booster. + + Parameters + ---------- + **kwargs + The attributes to set. Setting a value to None deletes an attribute. + """ + for key, value in kwargs.items(): + if value is not None: + if not isinstance(value, STRING_TYPES): + raise ValueError("Set Attr only accepts string values") + self.__attr[key] = value + else: + self.__attr.pop(key, None) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 6b9e5c054..f89c93ae3 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -1,4 +1,5 @@ from __future__ import absolute_import +import collections class EarlyStopException(Exception): """Exception of early stopping. 
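
The in-memory attribute store added to Booster above is how the best round is
handed back after early stopping: the value is stored as a string and parsed
once training finishes. A minimal sketch of the round-trip (illustrative only;
bst stands for any constructed Booster):

    bst.set_attr(best_iteration='42')   # only string values are accepted
    bst.attr('best_iteration')          # -> '42'
    bst.set_attr(best_iteration=None)   # None deletes the attribute
    bst.attr('best_iteration')          # -> None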
diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index d621357c3..01f5f53c0 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -16,7 +16,7 @@ def _construct_dataset(data, reference=None, group = None init_score = None if other_fields is not None: - if not is isinstance(other_fields, dict): + if not isinstance(other_fields, dict): raise TypeError("other filed data should be dict type") weight = None if 'weight' not in other_fields else other_fields['weight'] group = None if 'group' not in other_fields else other_fields['group'] @@ -127,7 +127,8 @@ def train(params, train_data, num_boost_round=100, """reduce cost for prediction training data""" if valid_datas[i] is train_data: is_valid_contain_train = True - train_data_name = valid_names[i] + if valid_names is not None: + train_data_name = valid_names[i] continue valid_set = _construct_dataset( valid_datas[i], @@ -136,7 +137,10 @@ def train(params, train_data, num_boost_round=100, other_fields, predictor) valid_sets.append(valid_set) - name_valid_sets.append(valid_names[i]) + if valid_names is not None: + name_valid_sets.append(valid_names[i]) + else: + name_valid_sets.append('valid_'+str(i)) """process callbacks""" callbacks = [] if callbacks is None else callbacks @@ -153,8 +157,8 @@ def train(params, train_data, num_boost_round=100, if learning_rates is not None: callbacks.append(callback.reset_learning_rate(learning_rates)) - if evals_result is not None: - callbacks.append(callback.record_evaluation(evals_result)) + if out_eval_result is not None: + callbacks.append(callback.record_evaluation(out_eval_result)) callbacks_before_iter = [ cb for cb in callbacks if cb.__dict__.get('before_iteration', False)] @@ -203,7 +207,7 @@ def train(params, train_data, num_boost_round=100, class CVBooster(object): """"Auxiliary datastruct to hold one fold of CV.""" - def __init__(self, train_set, valid_test, param): + def __init__(self, train_set, valid_test, params): """"Initialize the CVBooster""" self.train_set = train_set self.valid_test = valid_test @@ -268,12 +272,12 @@ def _agg_cv_result(raw_results): metric_type[key] = one_line[3] if key not in cvmap: cvmap[key] = [] - cvmap[key].append(one_result[2]) + cvmap[key].append(one_line[2]) results = [] for k, v in cvmap.items(): v = np.array(v) mean, std = np.mean(v), np.std(v) - results.extend(['cv_agg', k, mean, metric_type[k], std]) + results.append(('cv_agg', k, mean, metric_type[k], std)) return results def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False, @@ -339,9 +343,14 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False, if not 'metric' in params: params['metric'] = [] + else: + if is_str(params['metric']): + params['metric'] = params['metric'].split(',') + else: + params['metric'] = list(params['metric']) - if len(metric) > 0: - params['metric'].extend(metric) + if metrics is not None and len(metrics) > 0: + params['metric'].extend(metrics) train_set = _construct_dataset(train_data, None, params, train_fields) @@ -374,8 +383,7 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False, evaluation_result_list=None)) for fold in cvfolds: fold.update(fobj) - res = aggcv([f.eval(feval) for f in cvfolds]) - + res = _agg_cv_result([f.eval(feval) for f in cvfolds]) for _, key, mean, _, std in res: if key + '-mean' not in results: results[key + '-mean'] = [] From 49d8564274d5ac82ed68d116ed073ef1b672353a Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Tue, 29 Nov 2016 
17:46:41 +0800 Subject: [PATCH 40/60] add sklearn like basic model --- include/LightGBM/config.h | 9 + python-package/lightgbm/callback.py | 4 +- python-package/lightgbm/engine.py | 14 +- python-package/lightgbm/sklearn.py | 270 ++++++++++++++++++++++++++++ src/io/config.cpp | 1 + src/objective/binary_objective.hpp | 3 + 6 files changed, 295 insertions(+), 6 deletions(-) create mode 100644 python-package/lightgbm/sklearn.py diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index 5ec22cb63..19ae190f2 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -138,6 +138,8 @@ public: bool is_unbalance = false; // for multiclass int num_class = 1; + // Balancing of positive and negative weights + double scale_pos_weight = 1.0f; void Set(const std::unordered_map& params) override; }; @@ -333,14 +335,18 @@ struct ParameterAlias { { "min_sum_hessian_per_leaf", "min_sum_hessian_in_leaf" }, { "min_sum_hessian", "min_sum_hessian_in_leaf" }, { "min_hessian", "min_sum_hessian_in_leaf" }, + { "min_child_weight", "min_sum_hessian_in_leaf" }, { "num_leaf", "num_leaves" }, { "sub_feature", "feature_fraction" }, + { "colsample_bytree", "feature_fraction" }, { "num_iteration", "num_iterations" }, { "num_tree", "num_iterations" }, { "num_round", "num_iterations" }, { "num_trees", "num_iterations" }, { "num_rounds", "num_iterations" }, { "sub_row", "bagging_fraction" }, + { "subsample", "bagging_fraction" }, + { "subsample_freq", "bagging_freq" }, { "shrinkage_rate", "learning_rate" }, { "tree", "tree_learner" }, { "num_machine", "num_machines" }, @@ -363,6 +369,9 @@ struct ParameterAlias { { "blacklist", "ignore_column" }, { "predict_raw_score", "is_predict_raw_score" }, { "predict_leaf_index", "is_predict_leaf_index" }, + { "gamma", "min_gain_to_split" }, + { "reg_alpha", "lambda_l1" }, + { "reg_lambda", "lambda_l2" }, { "num_classes", "num_class" } }); std::unordered_map tmp_map; diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index f89c93ae3..c574b875a 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -80,7 +80,7 @@ def record_evaluation(eval_result): def init(env): """internal function""" - for data_name, eval_name, _ in env.evaluation_result_list: + for data_name, eval_name, _, _ in env.evaluation_result_list: if data_name not in eval_result: eval_result[data_name] = {} if eval_name not in eval_result[data_name]: @@ -90,7 +90,7 @@ def record_evaluation(eval_result): """internal function""" if len(eval_result) == 0: init(env) - for data_name, eval_name, result in env.evaluation_result_list: + for data_name, eval_name, result, _ in env.evaluation_result_list: eval_result[data_name][eval_name].append(result) return callback diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 01f5f53c0..65eb598ea 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -41,7 +41,7 @@ def train(params, train_data, num_boost_round=100, valid_datas=None, valid_names=None, fobj=None, feval=None, init_model=None, train_fields=None, valid_fields=None, - early_stopping_rounds=None, out_eval_result=None, + early_stopping_rounds=None, evals_result=None, verbose_eval=True, learning_rates=None, callbacks=None): """Train with given parameters. 
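
The aliases registered in config.h above let sklearn/XGBoost-style names
resolve to native LightGBM parameters, so the two dicts in this sketch should
configure the same model (illustrative only):

    params_sklearn_style = {'colsample_bytree': 0.8, 'subsample': 0.9,
                            'min_child_weight': 5, 'reg_lambda': 1.0}
    params_native = {'feature_fraction': 0.8, 'bagging_fraction': 0.9,
                     'min_sum_hessian_in_leaf': 5, 'lambda_l2': 1.0}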
@@ -76,7 +76,7 @@ def train(params, train_data, num_boost_round=100, If there's more than one, will check all of them Returns the model with (best_iter + early_stopping_rounds) If early stopping occurs, the model will add 'best_iteration' field - out_eval_result: dict or None + evals_result: dict or None This dictionary used to store all evaluation results of all the items in valid_datas. Example: with a valid_datas containing [valid_set, train_set] and valid_names containing ['eval', 'train'] and a paramater containing ('metric':'logloss') @@ -157,14 +157,20 @@ def train(params, train_data, num_boost_round=100, if learning_rates is not None: callbacks.append(callback.reset_learning_rate(learning_rates)) - if out_eval_result is not None: - callbacks.append(callback.record_evaluation(out_eval_result)) + if evals_result is not None: + callbacks.append(callback.record_evaluation(evals_result)) callbacks_before_iter = [ cb for cb in callbacks if cb.__dict__.get('before_iteration', False)] callbacks_after_iter = [ cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)] """construct booster""" + if 'metric' in params: + if is_str(params['metric']): + params['metric'] = params['metric'].split(',') + else: + params['metric'] = list(params['metric']) + booster = Booster(params=params, train_set=train_set) if is_valid_contain_train: booster.set_train_data_name(train_data_name) diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py new file mode 100644 index 000000000..227872a78 --- /dev/null +++ b/python-package/lightgbm/sklearn.py @@ -0,0 +1,270 @@ +"""Scikit-Learn Wrapper interface for LightGBM.""" +from __future__ import absolute_import + +import numpy as np +from .basic import LightGBMError, Predictor, Dataset, Booster, is_str +from .engine import train +# sklearn +try: + from sklearn.base import BaseEstimator + from sklearn.base import RegressorMixin, ClassifierMixin + from sklearn.preprocessing import LabelEncoder + SKLEARN_INSTALLED = True + LGBMModelBase = BaseEstimator + LGBMRegressorBase = RegressorMixin + LGBMClassifierBase = ClassifierMixin + LGBMLabelEncoder = LabelEncoder +except ImportError: + SKLEARN_INSTALLED = False + LGBMModelBase = object + LGBMClassifierBase = object + LGBMRegressorBase = object + LGBMLabelEncoder = None + +def _objective_decorator(func): + """Decorate an objective function + + Converts an objective function using the typical sklearn metrics to LightGBM ffobj + + Note: for multi-class task, the label/pred is group by class_id first, then group by row_id + if you want to get i-th row label/pred in j-th class, the access way is label/pred[j*num_data+i] + and you should group grad and hess in this way as well + Parameters + ---------- + func: callable + Expects a callable with signature ``func(y_true, y_pred)``: + + y_true: array_like of shape [n_samples] + The target values + y_pred: array_like of shape [n_samples] + The predicted values + + Returns + ------- + new_func: callable + The new objective function as expected by ``lightgbm.engine.train``. + The signature is ``new_func(preds, dataset)``: + + preds: array_like, shape [n_samples] + The predicted values + dataset: ``dataset`` + The training set from which the labels will be extracted using + ``dataset.get_label()`` + """ + def inner(preds, dataset): + """internal function""" + labels = dataset.get_label() + return func(labels, preds) + return inner + +class LGBMModel(LGBMModelBase): + """Implementation of the Scikit-Learn API for LightGBM. 
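+
+    A least-squares objective in the func(y_true, y_pred) form that the
+    objective parameter accepts (a sketch mirroring the decorator above;
+    see also the note at the end of this docstring):
+
+        def objective_ls(y_true, y_pred):
+            grad = y_pred - y_true          # d/dpred of 0.5 * (pred - true)**2
+            hess = np.ones(len(y_true))     # second derivative is constant
+            return grad, hess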
+ + Parameters + ---------- + num_leaves : int + Maximum tree leaves for base learners. + max_depth : int + Maximum tree depth for base learners, -1 means not limit. + learning_rate : float + Boosting learning rate + n_estimators : int + Number of boosted trees to fit. + silent : boolean + Whether to print messages while running boosting. + objective : string or callable + Specify the learning task and the corresponding learning objective or + a custom objective function to be used (see note below). + num_class: int + only affect for multi-class training. + nthread : int + Number of parallel threads + gamma : float + Minimum loss reduction required to make a further partition on a leaf node of the tree. + min_child_weight : int + Minimum sum of instance weight(hessian) needed in a child. + subsample : float + Subsample ratio of the training instance. + subsample_freq : int + frequence of subsample, <=0 means no enable + colsample_bytree : float + Subsample ratio of columns when constructing each tree. + colsample_byleaf : float + Subsample ratio of columns when constructing each leaf. + reg_alpha : float + L1 regularization term on weights + reg_lambda : float + L2 regularization term on weights + scale_pos_weight : float + Balancing of positive and negative weights. + is_unbalance : bool + Is unbalance for binary classification + seed : int + Random number seed. + + Note + ---- + A custom objective function can be provided for the ``objective`` + parameter. In this case, it should have the signature + ``objective(y_true, y_pred) -> grad, hess``: + + y_true: array_like of shape [n_samples] + The target values + y_pred: array_like of shape [n_samples] + The predicted values + + grad: array_like of shape [n_samples] + The value of the gradient for each sample point. + hess: array_like of shape [n_samples] + The value of the second derivative for each sample point + + for multi-class task, the label/pred is group by class_id first, then group by row_id + if you want to get i-th row label/pred in j-th class, the access way is label/pred[j*num_data+i] + and you should group grad and hess in this way as well + """ + + def __init__(self, num_leaves=63, max_depth=-1, + learning_rate=0.1, n_estimators=100, max_bin=255, + silent=True, objective="regression", num_class=1, + nthread=-1, gamma=0, min_child_weight=1, + subsample=1, subsample_freq=1, colsample_bytree=1, colsample_byleaf=1, + reg_alpha=0, reg_lambda=0, scale_pos_weight=1, + is_unbalance=False, seed=0): + if not SKLEARN_INSTALLED: + raise LightGBMError('sklearn needs to be installed in order to use this module') + + self.num_leaves = num_leaves + self.max_depth = max_depth + self.learning_rate = learning_rate + self.n_estimators = n_estimators + self.max_bin = max_bin + self.silent = silent + self.objective = objective + self.num_class = num_class + self.nthread = nthread + self.gamma = gamma + self.min_child_weight = min_child_weight + self.subsample = subsample + self.subsample_freq = subsample_freq + self.colsample_bytree = colsample_bytree + self.colsample_byleaf = colsample_byleaf + self.reg_alpha = reg_alpha + self.reg_lambda = reg_lambda + self.scale_pos_weight = scale_pos_weight + self.is_unbalance = is_unbalance + self.seed = seed + self._Booster = None + + def booster(self): + """Get the underlying lightgbm Booster of this model. 
+ + This will raise an exception when fit was not called + + Returns + ------- + booster : a lightgbm booster of underlying model + """ + if self._Booster is None: + raise LightGBMError('need to call fit beforehand') + return self._Booster + + def get_params(self, deep=False): + """Get parameter.s""" + params = super(LGBMModel, self).get_params(deep=deep) + params['verbose'] = 0 if self.silent else 1 + if self.nthread <= 0: + params.pop('nthread', None) + return params + + def fit(self, X, y, eval_set=None, eval_metric=None, + early_stopping_rounds=None, verbose=True): + """ + Fit the gradient boosting model + + Parameters + ---------- + X : array_like + Feature matrix + y : array_like + Labels + eval_set : list, optional + A list of (X, y) tuple pairs to use as a validation set for early-stopping + eval_metric : str, list of str, callable, optional + If a str, should be a built-in evaluation metric to use. See + doc/parameter.md. If callable, a custom evaluation metric. The call + signature is func(y_predicted, y_true) where y_true will be a + Dataset fobject such that you may need to call the get_label + method. And it must return (eval_name, feature_result, is_bigger_better) + early_stopping_rounds : int + verbose : bool + If `verbose` and an evaluation set is used, writes the evaluation + metric measured on the validation set to stderr. + """ + evals_result = {} + params = self.get_params() + + if callable(self.objective): + fobj = _objective_decorator(self.objective) + params["objective"] = "None" + else: + fobj = None + if callable(eval_metric): + feval = eval_metric + else: + feval = None + params.update({'metric': eval_metric}) + feval = eval_metric if callable(eval_metric) else None + + + self._Booster = train(params, (X, y), + self.n_estimators, valid_datas=eval_set, + early_stopping_rounds=early_stopping_rounds, + evals_result=evals_result, fobj=fobj, feval=feval, + verbose_eval=verbose) + + if evals_result: + for val in evals_result.items(): + evals_result_key = list(val[1].keys())[0] + evals_result[val[0]][evals_result_key] = val[1][evals_result_key] + self.evals_result_ = evals_result + + if early_stopping_rounds is not None: + self.best_iteration = self._Booster.best_iteration + return self + + def predict(self, data, raw_score=False, num_iteration=0): + return self.booster().predict(data, + raw_score=raw_score, + num_iteration=num_iteration) + + def apply(self, X, num_iteration=0): + """Return the predicted leaf every tree for each sample. + + Parameters + ---------- + X : array_like, shape=[n_samples, n_features] + Input features matrix. + + ntree_limit : int + Limit number of trees in the prediction; defaults to 0 (use all trees). + + Returns + ------- + X_leaves : array_like, shape=[n_samples, n_trees] + """ + return self.booster().predict(X, + pred_leaf=True, + num_iteration=num_iteration) + + def evals_result(self): + """Return the evaluation results. 
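+
+        The dictionary is keyed by validation set name and then by metric
+        name, mirroring the example in train's docstring (illustrative):
+
+            {'eval': {'logloss': [0.62, 0.55, ...]},
+             'train': {'logloss': [0.60, 0.51, ...]}}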
+ Returns + ------- + evals_result : dictionary + """ + if self.evals_result_: + evals_result = self.evals_result_ + else: + raise LightGBMError('No results.') + + return evals_result \ No newline at end of file diff --git a/src/io/config.cpp b/src/io/config.cpp index 24ba003d0..b9337e921 100644 --- a/src/io/config.cpp +++ b/src/io/config.cpp @@ -213,6 +213,7 @@ void ObjectiveConfig::Set(const std::unordered_map& pa CHECK(max_position > 0); GetInt(params, "num_class", &num_class); CHECK(num_class >= 1); + GetDouble(params, "scale_pos_weight", &scale_pos_weight); std::string tmp_str = ""; if (GetString(params, "label_gain", &tmp_str)) { label_gain = Common::StringToDoubleArray(tmp_str, ','); diff --git a/src/objective/binary_objective.hpp b/src/objective/binary_objective.hpp index 0e78f3d22..fcfb05ae3 100644 --- a/src/objective/binary_objective.hpp +++ b/src/objective/binary_objective.hpp @@ -18,6 +18,7 @@ public: if (sigmoid_ <= 0.0) { Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_); } + scale_pos_weight_ = static_cast(config.scale_pos_weight); } ~BinaryLogloss() {} void Init(const Metadata& metadata, data_size_t num_data) override { @@ -55,6 +56,7 @@ public: label_weights_[0] = 1.0f; } } + label_weights_[1] *= scale_pos_weight_; } void GetGradients(const score_t* score, score_t* gradients, score_t* hessians) const override { @@ -104,6 +106,7 @@ private: score_t label_weights_[2]; /*! \brief Weights for data */ const float* weights_; + score_t scale_pos_weight_; }; } // namespace LightGBM From d806836792a1ab6f6670b3313a1e059716cb44bc Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Tue, 29 Nov 2016 18:16:54 +0800 Subject: [PATCH 41/60] add simple setup for pip install --- python-package/lightgbm/__init__.py | 24 ++++++++++++++++++++ python-package/lightgbm/basic.py | 25 +-------------------- python-package/lightgbm/libpath.py | 27 +++++++++++++++++++++++ python-package/setup.py | 34 +++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 24 deletions(-) create mode 100644 python-package/lightgbm/libpath.py create mode 100644 python-package/setup.py diff --git a/python-package/lightgbm/__init__.py b/python-package/lightgbm/__init__.py index e69de29bb..6aaa0ff43 100644 --- a/python-package/lightgbm/__init__.py +++ b/python-package/lightgbm/__init__.py @@ -0,0 +1,24 @@ +# coding: utf-8 +"""LightGBM, Light Gradient Boosting Machine. + +Contributors: https://github.com/Microsoft/LightGBM/graphs/contributors +""" + +from __future__ import absolute_import + +import os + +from .basic import Predictor, Dataset, Booster +from .engine import train, cv +try: + from .sklearn import LGBMModel +except ImportError: + pass + +VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION') +with open(VERSION_FILE) as f: + __version__ = f.read().strip() + +__all__ = ['Dataset', 'Booster', + 'train', 'cv', + 'LGBMModel'] \ No newline at end of file diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index e0af50518..f4d2a6524 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -11,33 +11,10 @@ import tempfile import numpy as np import scipy.sparse +from .libpath import find_lib_path IS_PY3 = (sys.version_info[0] == 3) - -def find_lib_path(): - """Find the path to LightGBM library files. 
- Returns - ------- - lib_path: list(string) - List of all found library path to LightGBM - """ - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - dll_path = [curr_path, os.path.join(curr_path, '../../lib/'), - os.path.join(curr_path, '../../'), - os.path.join(curr_path, './lib/'), - os.path.join(sys.prefix, 'lightgbm')] - if os.name == 'nt': - dll_path.append(os.path.join(curr_path, '../../windows/x64/Dll/')) - dll_path.append(os.path.join(curr_path, './windows/x64/Dll/')) - dll_path = [os.path.join(p, 'lib_lightgbm.dll') for p in dll_path] - else: - dll_path = [os.path.join(p, 'lib_lightgbm.so') for p in dll_path] - lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)] - if not lib_path: - raise Exception('Cannot find lightgbm Library') - return lib_path - def _load_lib(): """Load LightGBM Library.""" lib_path = find_lib_path() diff --git a/python-package/lightgbm/libpath.py b/python-package/lightgbm/libpath.py new file mode 100644 index 000000000..9efd6cc74 --- /dev/null +++ b/python-package/lightgbm/libpath.py @@ -0,0 +1,27 @@ +import os +import platform +import sys + + +def find_lib_path(): + """Find the path to LightGBM library files. + Returns + ------- + lib_path: list(string) + List of all found library path to LightGBM + """ + curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + dll_path = [curr_path, os.path.join(curr_path, '../../lib/'), + os.path.join(curr_path, '../../'), + os.path.join(curr_path, './lib/'), + os.path.join(sys.prefix, 'lightgbm')] + if os.name == 'nt': + dll_path.append(os.path.join(curr_path, '../../windows/x64/Dll/')) + dll_path.append(os.path.join(curr_path, './windows/x64/Dll/')) + dll_path = [os.path.join(p, 'lib_lightgbm.dll') for p in dll_path] + else: + dll_path = [os.path.join(p, 'lib_lightgbm.so') for p in dll_path] + lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)] + if not lib_path: + raise Exception('Cannot find lightgbm Library') + return lib_path diff --git a/python-package/setup.py b/python-package/setup.py new file mode 100644 index 000000000..77ee59bb3 --- /dev/null +++ b/python-package/setup.py @@ -0,0 +1,34 @@ +# pylint: disable=invalid-name, exec-used +"""Setup lightgbm package.""" +from __future__ import absolute_import +import sys +import os +from setuptools import setup, find_packages +# import subprocess +sys.path.insert(0, '.') + +CURRENT_DIR = os.path.dirname(__file__) + +libpath_py = os.path.join(CURRENT_DIR, 'lightgbm/libpath.py') +libpath = {'__file__': libpath_py} +exec(compile(open(libpath_py, "rb").read(), libpath_py, 'exec'), libpath, libpath) + +LIB_PATH = libpath['find_lib_path']() +print("Install lib_lightgbm from: %s" % LIB_PATH) + +# Please use setup_pip.py for generating and deploying pip installation +# detailed instruction in setup_pip.py +setup(name='lightgbm', + version=open(os.path.join(CURRENT_DIR, 'lightgbm/VERSION')).read().strip(), + description="LightGBM Python Package", + install_requires=[ + 'numpy', + 'scipy', + ], + maintainer='Guolin Ke', + maintainer_email='guolin.ke@microsoft.com', + zip_safe=False, + packages=find_packages(), + include_package_data=True, + data_files=[('lightgbm', LIB_PATH)], + url='hhttps://github.com/Microsoft/LightGBM') From 05754b957921d37bc3a2be25f5afe979c7e2f647 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Tue, 29 Nov 2016 19:30:29 +0800 Subject: [PATCH 42/60] finish other models. 
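
LGBMRegressor and LGBMClassifier below reuse LGBMModel: the classifier
label-encodes y, chooses the binary or multiclass objective from the number of
classes, and decodes predictions back. A minimal sketch of the resulting API
(illustrative; assumes sklearn-style arrays X and y):

    from lightgbm import LGBMClassifier

    clf = LGBMClassifier().fit(X, y)
    labels = clf.predict(X)        # inverse-transformed to the original classes
    probs = clf.predict_proba(X)   # two columns in the binary case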
--- python-package/lightgbm/__init__.py | 4 +- python-package/lightgbm/sklearn.py | 83 +++++++++++++++++++++++++---- 2 files changed, 76 insertions(+), 11 deletions(-) diff --git a/python-package/lightgbm/__init__.py b/python-package/lightgbm/__init__.py index 6aaa0ff43..966a73034 100644 --- a/python-package/lightgbm/__init__.py +++ b/python-package/lightgbm/__init__.py @@ -11,7 +11,7 @@ import os from .basic import Predictor, Dataset, Booster from .engine import train, cv try: - from .sklearn import LGBMModel + from .sklearn import LGBMModel, LGBMRegressor, LGBMClassifier except ImportError: pass @@ -21,4 +21,4 @@ with open(VERSION_FILE) as f: __all__ = ['Dataset', 'Booster', 'train', 'cv', - 'LGBMModel'] \ No newline at end of file + 'LGBMModel','LGBMRegressor', 'LGBMClassifier'] \ No newline at end of file diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 227872a78..5146d2daa 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -75,8 +75,6 @@ class LGBMModel(LGBMModelBase): objective : string or callable Specify the learning task and the corresponding learning objective or a custom objective function to be used (see note below). - num_class: int - only affect for multi-class training. nthread : int Number of parallel threads gamma : float @@ -125,7 +123,7 @@ class LGBMModel(LGBMModelBase): def __init__(self, num_leaves=63, max_depth=-1, learning_rate=0.1, n_estimators=100, max_bin=255, - silent=True, objective="regression", num_class=1, + silent=True, objective="regression", nthread=-1, gamma=0, min_child_weight=1, subsample=1, subsample_freq=1, colsample_bytree=1, colsample_byleaf=1, reg_alpha=0, reg_lambda=0, scale_pos_weight=1, @@ -140,7 +138,6 @@ class LGBMModel(LGBMModelBase): self.max_bin = max_bin self.silent = silent self.objective = objective - self.num_class = num_class self.nthread = nthread self.gamma = gamma self.min_child_weight = min_child_weight @@ -169,7 +166,7 @@ class LGBMModel(LGBMModelBase): return self._Booster def get_params(self, deep=False): - """Get parameter.s""" + """Get parameters""" params = super(LGBMModel, self).get_params(deep=deep) params['verbose'] = 0 if self.silent else 1 if self.nthread <= 0: @@ -177,7 +174,7 @@ class LGBMModel(LGBMModelBase): return params def fit(self, X, y, eval_set=None, eval_metric=None, - early_stopping_rounds=None, verbose=True): + early_stopping_rounds=None, verbose=True, train_fields=None, valid_fields=None, other_params=None): """ Fit the gradient boosting model @@ -198,7 +195,14 @@ class LGBMModel(LGBMModelBase): early_stopping_rounds : int verbose : bool If `verbose` and an evaluation set is used, writes the evaluation - metric measured on the validation set to stderr. + train_fields : dict + other data file in training data. e.g. train_fields['weight'] is weight data + support fields: weight, group, init_score + valid_fields : dict + other data file in training data. e.g. 
valid_fields[0]['weight'] is weight data for first valid data + support fields: weight, group, init_score + other_params: dict + other parameters """ evals_result = {} params = self.get_params() @@ -207,6 +211,7 @@ class LGBMModel(LGBMModelBase): fobj = _objective_decorator(self.objective) params["objective"] = "None" else: + params["objective"] = self.objective fobj = None if callable(eval_metric): feval = eval_metric @@ -215,12 +220,14 @@ class LGBMModel(LGBMModelBase): params.update({'metric': eval_metric}) feval = eval_metric if callable(eval_metric) else None + if other_params is not None: + params.update(other_params) self._Booster = train(params, (X, y), self.n_estimators, valid_datas=eval_set, early_stopping_rounds=early_stopping_rounds, evals_result=evals_result, fobj=fobj, feval=feval, - verbose_eval=verbose) + verbose_eval=verbose, train_fields=train_fields, valid_fields=valid_fields) if evals_result: for val in evals_result.items(): @@ -267,4 +274,62 @@ class LGBMModel(LGBMModelBase): else: raise LightGBMError('No results.') - return evals_result \ No newline at end of file + return evals_result + + +class LGBMRegressor(LGBMModel, LGBMRegressorBase): + __doc__ = """Implementation of the scikit-learn API for LightGBM regression. + """ + '\n'.join(LGBMModel.__doc__.split('\n')[2:]) + +class LGBMClassifier(LGBMModel, LGBMClassifierBase): + __doc__ = """Implementation of the scikit-learn API for LGBMoost classification. + + """ + '\n'.join(LGBMModel.__doc__.split('\n')[2:]) + + def fit(self, X, y, eval_set=None, eval_metric=None, + early_stopping_rounds=None, verbose=True, + train_fields=None, valid_fields=None, other_params=None): + + self.classes_ = np.unique(y) + self.n_classes_ = len(self.classes_) + if other_params is None: + other_params = {} + if self.n_classes_ > 2: + # Switch to using a multiclass objective in the underlying LGBM instance + other_params["objective"] = "multiclass" + other_params['num_class'] = self.n_classes_ + else: + other_params["objective"] = "binary" + + self._le = LGBMLabelEncoder().fit(y) + training_labels = self._le.transform(y) + + if eval_set is not None: + eval_set = list( (x[0], self._le.transform(x[1])) for x in eval_set ) + + super(LGBMClassifier, self).fit(X, training_labels, eval_set, eval_metric, + early_stopping_rounds, verbose, train_fields, valid_fields, other_params) + return self + + def predict(self, data, raw_score=False, num_iteration=0): + class_probs = self.booster().predict(data, + raw_score=raw_score, + num_iteration=num_iteration) + if len(class_probs.shape) > 1: + column_indexes = np.argmax(class_probs, axis=1) + else: + column_indexes = np.repeat(0, class_probs.shape[0]) + column_indexes[class_probs > 0.5] = 1 + return self._le.inverse_transform(column_indexes) + + def predict_proba(self, data, raw_score=False, num_iteration=0): + class_probs = self.booster().predict(data, + raw_score=raw_score, + num_iteration=num_iteration) + if self.n_classes_ > 2: + return class_probs + else: + classone_probs = class_probs + classzero_probs = 1.0 - classone_probs + return np.vstack((classzero_probs, classone_probs)).transpose() + From 3051b7711874aaa39d26ab80b030bea9a71dd7d2 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Wed, 30 Nov 2016 10:11:46 +0800 Subject: [PATCH 43/60] some typo --- python-package/lightgbm/basic.py | 2 +- python-package/lightgbm/sklearn.py | 4 ++-- python-package/setup.py | 7 +++---- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py 
index f4d2a6524..ab599c05a 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1211,7 +1211,7 @@ class Booster(object): """ for key, value in kwargs.items(): if value is not None: - if not isinstance(value, STRING_TYPES): + if not is_str(value): raise ValueError("Set Attr only accepts string values") self.__attr[key] = value else: diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 5146d2daa..b73a9be21 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -24,7 +24,7 @@ except ImportError: def _objective_decorator(func): """Decorate an objective function - Converts an objective function using the typical sklearn metrics to LightGBM ffobj + Converts an objective function using the typical sklearn metrics to LightGBM fobj Note: for multi-class task, the label/pred is group by class_id first, then group by row_id if you want to get i-th row label/pred in j-th class, the access way is label/pred[j*num_data+i] @@ -282,7 +282,7 @@ class LGBMRegressor(LGBMModel, LGBMRegressorBase): """ + '\n'.join(LGBMModel.__doc__.split('\n')[2:]) class LGBMClassifier(LGBMModel, LGBMClassifierBase): - __doc__ = """Implementation of the scikit-learn API for LGBMoost classification. + __doc__ = """Implementation of the scikit-learn API for LightGBM classification. """ + '\n'.join(LGBMModel.__doc__.split('\n')[2:]) diff --git a/python-package/setup.py b/python-package/setup.py index 77ee59bb3..834af8ce9 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -16,10 +16,9 @@ exec(compile(open(libpath_py, "rb").read(), libpath_py, 'exec'), libpath, libpat LIB_PATH = libpath['find_lib_path']() print("Install lib_lightgbm from: %s" % LIB_PATH) -# Please use setup_pip.py for generating and deploying pip installation -# detailed instruction in setup_pip.py + setup(name='lightgbm', - version=open(os.path.join(CURRENT_DIR, 'lightgbm/VERSION')).read().strip(), + version=0.1, description="LightGBM Python Package", install_requires=[ 'numpy', @@ -31,4 +30,4 @@ setup(name='lightgbm', packages=find_packages(), include_package_data=True, data_files=[('lightgbm', LIB_PATH)], - url='hhttps://github.com/Microsoft/LightGBM') + url='https://github.com/Microsoft/LightGBM') From 44fcf16c6fb97db54f770439b11726a4b5c9c74b Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Wed, 30 Nov 2016 10:27:27 +0800 Subject: [PATCH 44/60] fix travis --- .travis.yml | 2 ++ python-package/lightgbm/__init__.py | 5 ++--- tests/python_package_test/test_basic.py | 5 +---- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/.travis.yml b/.travis.yml index 08b411101..851017b0f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -21,10 +21,12 @@ script: - cd $TRAVIS_BUILD_DIR - mkdir build && cd build && cmake .. 
&& make -j - cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py +- cd $TRAVIS_BUILD_DIR/python-package && python setup.py install - cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py - cd $TRAVIS_BUILD_DIR - rm -rf build && mkdir build && cd build && cmake -DUSE_MPI=ON ..&& make -j - cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py +- cd $TRAVIS_BUILD_DIR/python-package && python setup.py install - cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py notifications: diff --git a/python-package/lightgbm/__init__.py b/python-package/lightgbm/__init__.py index 966a73034..cb18af4d6 100644 --- a/python-package/lightgbm/__init__.py +++ b/python-package/lightgbm/__init__.py @@ -15,9 +15,8 @@ try: except ImportError: pass -VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION') -with open(VERSION_FILE) as f: - __version__ = f.read().strip() + +__version__ = 0.1 __all__ = ['Dataset', 'Booster', 'train', 'cv', diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index d35f70a80..66a898139 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -1,9 +1,6 @@ import numpy as np from sklearn import datasets, metrics, model_selection -import importlib.util -spec = importlib.util.spec_from_file_location("module.name", "../../python-package/lightgbm/basic.py") -lgb = importlib.util.module_from_spec(spec) -spec.loader.exec_module(lgb) +import lightgbm as lgb X, Y = datasets.make_classification(n_samples=100000, n_features=100) From 452b41f015bc299e0b06152f11cba1cfce248bae Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Wed, 30 Nov 2016 11:40:31 +0800 Subject: [PATCH 45/60] add more tests --- python-package/lightgbm/engine.py | 2 +- python-package/lightgbm/sklearn.py | 16 ++-- tests/python_package_test/test_sklearn.py | 103 ++++++++++++++++++++++ 3 files changed, 114 insertions(+), 7 deletions(-) create mode 100644 tests/python_package_test/test_sklearn.py diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 65eb598ea..5a1341e58 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -208,7 +208,7 @@ def train(params, train_data, num_boost_round=100, booster.best_iteration = int(booster.attr('best_iteration')) else: booster.best_iteration = num_boost_round - 1 - return num_boost_round + return booster class CVBooster(object): diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index b73a9be21..3d996a128 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -207,6 +207,9 @@ class LGBMModel(LGBMModelBase): evals_result = {} params = self.get_params() + if other_params is not None: + params.update(other_params) + if callable(self.objective): fobj = _objective_decorator(self.objective) params["objective"] = "None" @@ -215,14 +218,13 @@ class LGBMModel(LGBMModelBase): fobj = None if callable(eval_metric): feval = eval_metric - else: + elif is_str(eval_metric) or isinstance(eval_metric, list): feval = None params.update({'metric': eval_metric}) + else: + feval = None feval = eval_metric if callable(eval_metric) else None - if other_params is not None: - params.update(other_params) - self._Booster = train(params, (X, y), self.n_estimators, valid_datas=eval_set, early_stopping_rounds=early_stopping_rounds, @@ -296,10 +298,12 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase): other_params = {} if self.n_classes_ > 2: # Switch to 
using a multiclass objective in the underlying LGBM instance - other_params["objective"] = "multiclass" + if not callable(self.objective): + self.objective = "multiclass" other_params['num_class'] = self.n_classes_ else: - other_params["objective"] = "binary" + if not callable(self.objective): + self.objective = "binary" self._le = LGBMLabelEncoder().fit(y) training_labels = self._le.transform(y) diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py new file mode 100644 index 000000000..3c57deec2 --- /dev/null +++ b/tests/python_package_test/test_sklearn.py @@ -0,0 +1,103 @@ +import numpy as np +import random +import lightgbm as lgb + + +rng = np.random.RandomState(2016) + +def test_binary_classification(): + + from sklearn import datasets, metrics, model_selection + + X, y = datasets.make_classification(n_samples=10000, n_features=100) + x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1) + lgb_model = lgb.LGBMClassifier().fit(x_train, y_train, eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='binary_logloss') + from sklearn.datasets import load_digits + digits = load_digits(2) + y = digits['target'] + X = digits['data'] + x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2) + lgb_model = lgb.LGBMClassifier().fit(x_train, y_train, eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='binary_logloss') + preds = lgb_model.predict(x_test) + err = sum(1 for i in range(len(preds)) + if int(preds[i] > 0.5) != y_test[i]) / float(len(preds)) + assert err < 0.1 + +def test_multiclass_classification(): + from sklearn.datasets import load_iris + from sklearn import datasets, metrics, model_selection + + def check_pred(preds, labels): + err = sum(1 for i in range(len(preds)) + if int(preds[i] > 0.5) != labels[i]) / float(len(preds)) + assert err < 0.7 + + + X, y = datasets.make_classification(n_samples=10000, n_features=100, n_classes=4, n_informative=3) + + x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1) + + lgb_model = lgb.LGBMClassifier().fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='multi_logloss') + preds = lgb_model.predict(x_test) + + check_pred(preds, y_test) + +def test_regression(): + from sklearn.metrics import mean_squared_error + from sklearn.datasets import load_boston + from sklearn.cross_validation import KFold + from sklearn import datasets, metrics, model_selection + + boston = load_boston() + y = boston['target'] + X = boston['data'] + x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1) + lgb_model = lgb.LGBMRegressor().fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='l2') + preds = lgb_model.predict(x_test) + assert mean_squared_error(preds, y_test) < 30 + +def test_regression_with_custom_objective(): + from sklearn.metrics import mean_squared_error + from sklearn.datasets import load_boston + from sklearn.cross_validation import KFold + from sklearn import datasets, metrics, model_selection + def objective_ls(y_true, y_pred): + grad = (y_pred - y_true) + hess = np.ones(len(y_true)) + return grad, hess + boston = load_boston() + y = boston['target'] + X = boston['data'] + x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1) + lgb_model = lgb.LGBMRegressor(objective=objective_ls).fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='l2') + 
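+    # Boston targets are house prices in $1000s, so an MSE below 30 on the
+    # held-out split is a loose sanity bound for the custom objective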
preds = lgb_model.predict(x_test) + assert mean_squared_error(preds, y_test) < 30 + + +def test_binary_classification_with_custom_objective(): + + from sklearn import datasets, metrics, model_selection + def logregobj(y_true, y_pred): + y_pred = 1.0 / (1.0 + np.exp(-y_pred)) + grad = y_pred - y_true + hess = y_pred * (1.0 - y_pred) + return grad, hess + X, y = datasets.make_classification(n_samples=10000, n_features=100) + x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1) + lgb_model = lgb.LGBMClassifier(objective=logregobj).fit(x_train, y_train, eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='binary_logloss') + from sklearn.datasets import load_digits + digits = load_digits(2) + y = digits['target'] + X = digits['data'] + x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2) + lgb_model = lgb.LGBMClassifier(objective=logregobj).fit(x_train, y_train, eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='binary_logloss') + preds = lgb_model.predict(x_test) + err = sum(1 for i in range(len(preds)) + if int(preds[i] > 0.5) != y_test[i]) / float(len(preds)) + assert err < 0.1 + +test_binary_classification() +test_multiclass_classification() +test_regression() +test_regression_with_custom_objective() +test_binary_classification_with_custom_objective() \ No newline at end of file From a1567983bff3025c2118cb70a378b6223404b470 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Wed, 30 Nov 2016 11:45:36 +0800 Subject: [PATCH 46/60] update travis --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 851017b0f..fa03a964c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -22,12 +22,12 @@ script: - mkdir build && cd build && cmake .. && make -j - cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py - cd $TRAVIS_BUILD_DIR/python-package && python setup.py install -- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py +- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_sklearn.py - cd $TRAVIS_BUILD_DIR - rm -rf build && mkdir build && cd build && cmake -DUSE_MPI=ON ..&& make -j - cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py - cd $TRAVIS_BUILD_DIR/python-package && python setup.py install -- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py +- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_sklearn.py notifications: email: false From b12f99682ee23dc4535ac66f5d259dd33a0b46fc Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Wed, 30 Nov 2016 12:06:01 +0800 Subject: [PATCH 47/60] add pandas support --- python-package/lightgbm/engine.py | 60 ++++++++++++++++++++++++------- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 5a1341e58..86a3e3075 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -6,7 +6,41 @@ import numpy as np from .basic import LightGBMError, Predictor, Dataset, Booster, is_str from . 
import callback -def _construct_dataset(data, reference=None, +# pandas +try: + from pandas import DataFrame +except ImportError: + class DataFrame(object): + pass + +PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int', + 'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int', + 'float16': 'float', 'float32': 'float', 'float64': 'float', + 'bool': 'i'} + +def _data_from_pandas(data): + if isinstance(data, DataFrame): + data_dtypes = data.dtypes + if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes): + bad_fields = [data.columns[i] for i, dtype in + enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER] + + msg = """DataFrame.dtypes for data must be int, float or bool. Did not expect the data types in fields """ + raise ValueError(msg + ', '.join(bad_fields)) + data = data.values.astype('float') + return data + +def _label_from_pandas(label): + if isinstance(label, DataFrame): + if len(label.columns) > 1: + raise ValueError('DataFrame for label cannot have multiple columns') + label_dtypes = label.dtypes + if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in label_dtypes): + raise ValueError('DataFrame.dtypes for label must be int, float or bool') + label = label.values.astype('float') + return label + +def _construct_dataset(X_y, reference=None, params=None, other_fields=None, predictor=None): if 'max_bin' in params: max_bin = int(params['max_bin']) @@ -21,18 +55,20 @@ def _construct_dataset(data, reference=None, weight = None if 'weight' not in other_fields else other_fields['weight'] group = None if 'group' not in other_fields else other_fields['group'] init_score = None if 'init_score' not in other_fields else other_fields['init_score'] - if reference is None: - if is_str(data): - ret = Dataset(data, label=None, max_bin=max_bin, - weight=weight, group=group, predictor=predictor, params=params) - else: - ret = Dataset(data[0], data[1], max_bin=max_bin, - weight=weight, group=group, predictor=predictor, params=params) + if is_str(X_y): + data = X_y + label = None else: - if is_str(data): - ret = reference.create_valid(data, label=None, weight=weight, group=group, params=params) - else: - ret = reference.create_valid(data[0], data[1], weight, group, params=params) + if len(X_y) != 2: + raise TypeError("should pass (data, label) pair") + data = _data_from_pandas(X_y[0]) + label = _label_from_pandas(X_y[1]) + if reference is None: + ret = Dataset(data, label=label, max_bin=max_bin, + weight=weight, group=group, predictor=predictor, params=params) + + else: + ret = reference.create_valid(data, label=label, weight=weight, group=group, params=params) if init_score is not None: ret.set_init_score(init_score) return ret From c67d289086cb2825d6bcf0c2f4e6f0dee23f8bcd Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Wed, 30 Nov 2016 12:24:32 +0800 Subject: [PATCH 48/60] move pandas support into basic.py --- python-package/lightgbm/basic.py | 35 +++++++++++++++++++++ python-package/lightgbm/engine.py | 38 ++--------------------- tests/python_package_test/test_sklearn.py | 4 +-- 3 files changed, 39 insertions(+), 38 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index ab599c05a..0cd37ff87 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -358,6 +358,39 @@ class Predictor(object): raise ValueError("incorrect number for predict result") return preds, nrow +# pandas +try: + from pandas import DataFrame +except ImportError: + class 
DataFrame(object): + pass + +PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int', + 'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int', + 'float16': 'float', 'float32': 'float', 'float64': 'float', + 'bool': 'i'} + +def _data_from_pandas(data): + if isinstance(data, DataFrame): + data_dtypes = data.dtypes + if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes): + bad_fields = [data.columns[i] for i, dtype in + enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER] + + msg = """DataFrame.dtypes for data must be int, float or bool. Did not expect the data types in fields """ + raise ValueError(msg + ', '.join(bad_fields)) + data = data.values.astype('float') + return data + +def _label_from_pandas(label): + if isinstance(label, DataFrame): + if len(label.columns) > 1: + raise ValueError('DataFrame for label cannot have multiple columns') + label_dtypes = label.dtypes + if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in label_dtypes): + raise ValueError('DataFrame.dtypes for label must be int, float or bool') + label = label.values.astype('float') + return label class Dataset(object): """Dataset used in LightGBM. @@ -398,6 +431,8 @@ class Dataset(object): if data is None: self.handle = None return + data = _data_from_pandas(data) + label = _label_from_pandas(label) self.data_has_header = False """process for args""" params = {} if params is None else params diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 86a3e3075..d9cc89402 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -6,40 +6,6 @@ import numpy as np from .basic import LightGBMError, Predictor, Dataset, Booster, is_str from . import callback -# pandas -try: - from pandas import DataFrame -except ImportError: - class DataFrame(object): - pass - -PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int', - 'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int', - 'float16': 'float', 'float32': 'float', 'float64': 'float', - 'bool': 'i'} - -def _data_from_pandas(data): - if isinstance(data, DataFrame): - data_dtypes = data.dtypes - if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes): - bad_fields = [data.columns[i] for i, dtype in - enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER] - - msg = """DataFrame.dtypes for data must be int, float or bool. 
Did not expect the data types in fields """ - raise ValueError(msg + ', '.join(bad_fields)) - data = data.values.astype('float') - return data - -def _label_from_pandas(label): - if isinstance(label, DataFrame): - if len(label.columns) > 1: - raise ValueError('DataFrame for label cannot have multiple columns') - label_dtypes = label.dtypes - if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in label_dtypes): - raise ValueError('DataFrame.dtypes for label must be int, float or bool') - label = label.values.astype('float') - return label - def _construct_dataset(X_y, reference=None, params=None, other_fields=None, predictor=None): if 'max_bin' in params: @@ -61,8 +27,8 @@ def _construct_dataset(X_y, reference=None, else: if len(X_y) != 2: raise TypeError("should pass (data, label) pair") - data = _data_from_pandas(X_y[0]) - label = _label_from_pandas(X_y[1]) + data = X_y[0] + label = X_y[1] if reference is None: ret = Dataset(data, label=label, max_bin=max_bin, weight=weight, group=group, predictor=predictor, params=params) diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 3c57deec2..4280adfb7 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -54,7 +54,7 @@ def test_regression(): x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1) lgb_model = lgb.LGBMRegressor().fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='l2') preds = lgb_model.predict(x_test) - assert mean_squared_error(preds, y_test) < 30 + assert mean_squared_error(preds, y_test) < 40 def test_regression_with_custom_objective(): from sklearn.metrics import mean_squared_error @@ -71,7 +71,7 @@ def test_regression_with_custom_objective(): x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1) lgb_model = lgb.LGBMRegressor(objective=objective_ls).fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='l2') preds = lgb_model.predict(x_test) - assert mean_squared_error(preds, y_test) < 30 + assert mean_squared_error(preds, y_test) < 40 def test_binary_classification_with_custom_objective(): From f65164f67ed37e0f147d154bf7d39e391c9467f8 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Wed, 30 Nov 2016 12:46:30 +0800 Subject: [PATCH 49/60] less verbose in test --- tests/python_package_test/test_sklearn.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 4280adfb7..085a8732e 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -11,13 +11,13 @@ def test_binary_classification(): X, y = datasets.make_classification(n_samples=10000, n_features=100) x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1) - lgb_model = lgb.LGBMClassifier().fit(x_train, y_train, eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='binary_logloss') + lgb_model = lgb.LGBMClassifier().fit(x_train, y_train) from sklearn.datasets import load_digits digits = load_digits(2) y = digits['target'] X = digits['data'] x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2) - lgb_model = lgb.LGBMClassifier().fit(x_train, y_train, eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='binary_logloss') + lgb_model = lgb.LGBMClassifier().fit(x_train, y_train) preds = 
lgb_model.predict(x_test) err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != y_test[i]) / float(len(preds)) @@ -37,7 +37,7 @@ def test_multiclass_classification(): x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1) - lgb_model = lgb.LGBMClassifier().fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='multi_logloss') + lgb_model = lgb.LGBMClassifier().fit(x_train, y_train) preds = lgb_model.predict(x_test) check_pred(preds, y_test) @@ -52,7 +52,7 @@ def test_regression(): y = boston['target'] X = boston['data'] x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1) - lgb_model = lgb.LGBMRegressor().fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='l2') + lgb_model = lgb.LGBMRegressor().fit(x_train, y_train) preds = lgb_model.predict(x_test) assert mean_squared_error(preds, y_test) < 40 @@ -69,7 +69,7 @@ def test_regression_with_custom_objective(): y = boston['target'] X = boston['data'] x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1) - lgb_model = lgb.LGBMRegressor(objective=objective_ls).fit(x_train, y_train,eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='l2') + lgb_model = lgb.LGBMRegressor(objective=objective_ls).fit(x_train, y_train) preds = lgb_model.predict(x_test) assert mean_squared_error(preds, y_test) < 40 @@ -84,13 +84,13 @@ def test_binary_classification_with_custom_objective(): return grad, hess X, y = datasets.make_classification(n_samples=10000, n_features=100) x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1) - lgb_model = lgb.LGBMClassifier(objective=logregobj).fit(x_train, y_train, eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='binary_logloss') + lgb_model = lgb.LGBMClassifier(objective=logregobj).fit(x_train, y_train) from sklearn.datasets import load_digits digits = load_digits(2) y = digits['target'] X = digits['data'] x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2) - lgb_model = lgb.LGBMClassifier(objective=logregobj).fit(x_train, y_train, eval_set=[[x_train, y_train],(x_test, y_test)], eval_metric='binary_logloss') + lgb_model = lgb.LGBMClassifier(objective=logregobj).fit(x_train, y_train) preds = lgb_model.predict(x_test) err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != y_test[i]) / float(len(preds)) From f8267a504437655a93a5ccc355d9ef3afefbac21 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Wed, 30 Nov 2016 13:19:27 +0800 Subject: [PATCH 50/60] add min_data, fix test --- python-package/lightgbm/sklearn.py | 11 +++++++---- tests/python_package_test/test_sklearn.py | 18 +++++++++--------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 3d996a128..8a5f3139c 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -80,7 +80,9 @@ class LGBMModel(LGBMModelBase): gamma : float Minimum loss reduction required to make a further partition on a leaf node of the tree. min_child_weight : int - Minimum sum of instance weight(hessian) needed in a child. + Minimum sum of instance weight(hessian) needed in a child(leaf) + min_data : int + Minimum number of data need in a child(leaf) subsample : float Subsample ratio of the training instance. 
     subsample_freq : int
@@ -121,10 +123,10 @@ class LGBMModel(LGBMModelBase):
        and you should group grad and hess in this way as well
    """

-    def __init__(self, num_leaves=63, max_depth=-1,
-                 learning_rate=0.1, n_estimators=100, max_bin=255,
+    def __init__(self, num_leaves=31, max_depth=-1,
+                 learning_rate=0.1, n_estimators=10, max_bin=255,
                  silent=True, objective="regression",
-                 nthread=-1, gamma=0, min_child_weight=1,
+                 nthread=-1, gamma=0, min_child_weight=5, min_data=10,
                  subsample=1, subsample_freq=1, colsample_bytree=1, colsample_byleaf=1,
                  reg_alpha=0, reg_lambda=0, scale_pos_weight=1,
                  is_unbalance=False, seed=0):
@@ -141,6 +143,7 @@ class LGBMModel(LGBMModelBase):
        self.nthread = nthread
        self.gamma = gamma
        self.min_child_weight = min_child_weight
+       self.min_data = min_data
        self.subsample = subsample
        self.subsample_freq = subsample_freq
        self.colsample_bytree = colsample_bytree
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py
index 085a8732e..3688daafe 100644
--- a/tests/python_package_test/test_sklearn.py
+++ b/tests/python_package_test/test_sklearn.py
@@ -10,13 +10,13 @@ def test_binary_classification():
     from sklearn import datasets, metrics, model_selection

     X, y = datasets.make_classification(n_samples=10000, n_features=100)
-    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1, random_state=1)
     lgb_model = lgb.LGBMClassifier().fit(x_train, y_train)
     from sklearn.datasets import load_digits
     digits = load_digits(2)
     y = digits['target']
     X = digits['data']
-    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2)
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1, random_state=1)
     lgb_model = lgb.LGBMClassifier().fit(x_train, y_train)
     preds = lgb_model.predict(x_test)
     err = sum(1 for i in range(len(preds))
@@ -35,7 +35,7 @@ def test_multiclass_classification():
     X, y = datasets.make_classification(n_samples=10000, n_features=100,
                                         n_classes=4, n_informative=3)

-    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1, random_state=1)

     lgb_model = lgb.LGBMClassifier().fit(x_train, y_train)

@@ -51,10 +51,10 @@ def test_regression():
     boston = load_boston()
     y = boston['target']
     X = boston['data']
-    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1, random_state=1)
     lgb_model = lgb.LGBMRegressor().fit(x_train, y_train)
     preds = lgb_model.predict(x_test)
-    assert mean_squared_error(preds, y_test) < 40
+    assert mean_squared_error(preds, y_test) < 100

@@ -68,10 +68,10 @@ def test_regression_with_custom_objective():
     boston = load_boston()
     y = boston['target']
     X = boston['data']
-    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1, random_state=1)
     lgb_model = lgb.LGBMRegressor(objective=objective_ls).fit(x_train, y_train)
     preds = lgb_model.predict(x_test)
-    assert mean_squared_error(preds, y_test) < 40
+    assert mean_squared_error(preds, y_test) < 100

 def test_binary_classification_with_custom_objective():
@@ -83,13 +83,13 @@ def test_binary_classification_with_custom_objective():
         hess = y_pred * (1.0 - y_pred)
         return grad, hess
     X, y = datasets.make_classification(n_samples=10000, n_features=100)
-    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1)
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1, random_state=1)
     lgb_model = lgb.LGBMClassifier(objective=logregobj).fit(x_train, y_train)
     from sklearn.datasets import load_digits
     digits = load_digits(2)
     y = digits['target']
     X = digits['data']
-    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2)
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2, random_state=1)
     lgb_model = lgb.LGBMClassifier(objective=logregobj).fit(x_train, y_train)
     preds = lgb_model.predict(x_test)
     err = sum(1 for i in range(len(preds))

From b59a5a4c110d0ef56337f7cdc894c9f0c278fc44 Mon Sep 17 00:00:00 2001
From: Guolin Ke
Date: Wed, 30 Nov 2016 13:39:19 +0800
Subject: [PATCH 51/60] test for early_stopping

---
 python-package/lightgbm/callback.py       | 26 +++++++++++++----------
 python-package/lightgbm/engine.py         |  4 ++--
 tests/python_package_test/test_sklearn.py | 20 ++++++++++++++++-
 3 files changed, 36 insertions(+), 14 deletions(-)

diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py
index c574b875a..555a9a5f8 100644
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -148,8 +148,11 @@ def early_stop(stopping_rounds, verbose=True):
     callback : function
         The requested callback function.
     """
-    is_init = False
-    final_best_iter = 0
+    state = {}
+    factor_to_bigger_better = {}
+    best_score = {}
+    best_iter = {}
+    best_msg = {}
     def init(env):
         """internal function"""
         bst = env.model
@@ -160,19 +163,20 @@ def early_stop(stopping_rounds, verbose=True):
         if verbose:
             msg = "Will train until hasn't improved in {} rounds.\n"
             print(msg.format(stopping_rounds))
-        best_scores = [ float('-inf') for _ in range(len(env.evaluation_result_list))]
-        best_iter = [ 0 for _ in range(len(env.evaluation_result_list))]
-        if verbose:
-            best_msg = [ "" for _ in range(len(env.evaluation_result_list))]
-        factor_to_bigger_better = [-1.0 for _ in range(len(env.evaluation_result_list))]
+
         for i in range(len(env.evaluation_result_list)):
-            if evaluation.evaluation_result_list[i][3]:
+            best_score[i] = float('-inf')
+            best_iter[i] = 0
+            if verbose:
+                best_msg[i] = ""
+            factor_to_bigger_better[i] = -1.0
+            if env.evaluation_result_list[i][3]:
                 factor_to_bigger_better[i] = 1.0
-        is_init = True
+        state['best_iter'] = 0

     def callback(env):
         """internal function"""
-        if not is_init:
+        if len(best_score) == 0:
             init(env)
         for i in range(len(env.evaluation_result_list)):
@@ -184,7 +188,7 @@ def early_stop(stopping_rounds, verbose=True):
                         '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list]))
             else:
                 if env.iteration - best_iter[i] >= stopping_rounds:
-                    final_best_iter = best_iter[i]
+                    state['best_iter'] = best_iter[i]
                     if env.model is not None:
                         env.model.set_attr(best_iteration=str(best_iter[i]))
                     if verbose:
diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py
index d9cc89402..262f0f8b0 100644
--- a/python-package/lightgbm/engine.py
+++ b/python-package/lightgbm/engine.py
@@ -112,7 +112,7 @@ def train(params, train_data, num_boost_round=100,
     if is_str(init_model):
         predictor = Predictor(model_file=init_model)
     elif isinstance(init_model, Booster):
-        predictor = Booster.to_predictor()
+        predictor = init_model.to_predictor()
     elif isinstance(init_model, Predictor):
         predictor = init_model
     else:
@@ -409,6 +409,6 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
                                         evaluation_result_list=res))
         except callback.EarlyStopException as e:
             for k in results.keys():
-                results[k] = results[k][:(e.final_best_iter + 1)]
+                results[k] = results[k][:(e.state['best_iter'] + 1)]
             break
     return results
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py
index 3688daafe..d1bec620b 100644
--- a/tests/python_package_test/test_sklearn.py
+++ b/tests/python_package_test/test_sklearn.py
@@ -96,8 +96,26 @@ def test_binary_classification_with_custom_objective():
                if int(preds[i] > 0.5) != y_test[i]) / float(len(preds))
     assert err < 0.1

+def test_early_stopping():
+    from sklearn.metrics import mean_squared_error
+    from sklearn.datasets import load_boston
+    from sklearn.cross_validation import KFold
+    from sklearn import datasets, metrics, model_selection
+
+    boston = load_boston()
+    y = boston['target']
+    X = boston['data']
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1, random_state=1)
+    lgb_model = lgb.LGBMRegressor(n_estimators=500) \
+        .fit(x_train, y_train, eval_set=[(x_test, y_test)],
+             eval_metric='l2',
+             early_stopping_rounds=10,
+             verbose=10)
+    print(lgb_model.best_iteration)
+
 test_binary_classification()
 test_multiclass_classification()
 test_regression()
 test_regression_with_custom_objective()
-test_binary_classification_with_custom_objective()
\ No newline at end of file
+test_binary_classification_with_custom_objective()
+test_early_stopping()
\ No newline at end of file

From 164524d8fbc5411337400aa85c71e43dd9354e02 Mon Sep 17 00:00:00 2001
From: Guolin Ke
Date: Wed, 30 Nov 2016 15:20:06 +0800
Subject: [PATCH 52/60] weighted objective function

---
 include/LightGBM/config.h              |  3 +-
 python-package/lightgbm/basic.py       |  2 +-
 python-package/lightgbm/sklearn.py     | 59 ++++++++++++++++---------
 tests/python_package_test/test_basic.py |  5 ++-
 4 files changed, 43 insertions(+), 26 deletions(-)

diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 19ae190f2..e02b4fcaa 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -332,6 +332,7 @@ struct ParameterAlias {
       { "ndcg_at", "ndcg_eval_at" },
       { "min_data_per_leaf", "min_data_in_leaf" },
       { "min_data", "min_data_in_leaf" },
+      { "min_child_samples", "min_data_in_leaf" },
       { "min_sum_hessian_per_leaf", "min_sum_hessian_in_leaf" },
       { "min_sum_hessian", "min_sum_hessian_in_leaf" },
       { "min_hessian", "min_sum_hessian_in_leaf" },
@@ -369,7 +370,7 @@ struct ParameterAlias {
       { "blacklist", "ignore_column" },
       { "predict_raw_score", "is_predict_raw_score" },
       { "predict_leaf_index", "is_predict_leaf_index" },
-      { "gamma", "min_gain_to_split" },
+      { "min_split_gain", "min_gain_to_split" },
       { "reg_alpha", "lambda_l1" },
       { "reg_lambda", "lambda_l2" },
       { "num_classes", "num_class" }
diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 0cd37ff87..5afd75935 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -127,7 +127,7 @@ C_API_PREDICT_RAW_SCORE =1
 C_API_PREDICT_LEAF_INDEX =2

 FIELD_TYPE_MAPPER = {"label":C_API_DTYPE_FLOAT32,
-"wegiht":C_API_DTYPE_FLOAT32,
+"weight":C_API_DTYPE_FLOAT32,
"init_score":C_API_DTYPE_FLOAT32, "group":C_API_DTYPE_INT32, } diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 8a5f3139c..4384b9f4b 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -21,13 +21,13 @@ except ImportError: LGBMRegressorBase = object LGBMLabelEncoder = None -def _objective_decorator(func): +def _point_wise_objective(func): """Decorate an objective function Converts an objective function using the typical sklearn metrics to LightGBM fobj - Note: for multi-class task, the label/pred is group by class_id first, then group by row_id - if you want to get i-th row label/pred in j-th class, the access way is label/pred[j*num_data+i] + Note: for multi-class task, the y_pred is group by class_id first, then group by row_id + if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i] and you should group grad and hess in this way as well Parameters ---------- @@ -36,16 +36,17 @@ def _objective_decorator(func): y_true: array_like of shape [n_samples] The target values - y_pred: array_like of shape [n_samples] + y_pred: array_like of shape [n_samples] or shape[n_samples* n_class] The predicted values + Returns ------- new_func: callable The new objective function as expected by ``lightgbm.engine.train``. The signature is ``new_func(preds, dataset)``: - preds: array_like, shape [n_samples] + preds: array_like, shape [n_samples] or shape[n_samples* n_class] The predicted values dataset: ``dataset`` The training set from which the labels will be extracted using @@ -54,9 +55,26 @@ def _objective_decorator(func): def inner(preds, dataset): """internal function""" labels = dataset.get_label() - return func(labels, preds) + grad, hess = func(labels, preds) + """weighted for objective""" + weight = dataset.get_weight() + if weight is not None: + """only one class""" + if len(weight) == len(grad): + grad = np.multiply(grad, weight) + hess = np.multiply(hess, weight) + else: + num_data = len(weight) + num_class = len(grad) // num_data + for k in range(num_class): + for i in range(num_data): + idx = k * num_data + i + grad[idx] *= weight[i] + hess[idx] *= weight[i] + return grad, hess return inner + class LGBMModel(LGBMModelBase): """Implementation of the Scikit-Learn API for LightGBM. @@ -77,11 +95,11 @@ class LGBMModel(LGBMModelBase): a custom objective function to be used (see note below). nthread : int Number of parallel threads - gamma : float + min_split_gain : float Minimum loss reduction required to make a further partition on a leaf node of the tree. min_child_weight : int Minimum sum of instance weight(hessian) needed in a child(leaf) - min_data : int + min_child_samples : int Minimum number of data need in a child(leaf) subsample : float Subsample ratio of the training instance. @@ -89,8 +107,6 @@ class LGBMModel(LGBMModelBase): frequence of subsample, <=0 means no enable colsample_bytree : float Subsample ratio of columns when constructing each tree. - colsample_byleaf : float - Subsample ratio of columns when constructing each leaf. reg_alpha : float L1 regularization term on weights reg_lambda : float @@ -108,26 +124,26 @@ class LGBMModel(LGBMModelBase): parameter. 
         In this case, it should have the signature ``objective(y_true, y_pred) -> grad, hess``:

-                y_true: array_like of shape [n_samples]
+            y_true: array_like of shape [n_samples]
                 The target values
-            y_pred: array_like of shape [n_samples]
+            y_pred: array_like of shape [n_samples] or shape[n_samples* n_class]
                 The predicted values
-            grad: array_like of shape [n_samples]
+            grad: array_like of shape [n_samples] or shape[n_samples* n_class]
                 The value of the gradient for each sample point.
-            hess: array_like of shape [n_samples]
+            hess: array_like of shape [n_samples] or shape[n_samples* n_class]
                 The value of the second derivative for each sample point

-        for multi-class task, the label/pred is group by class_id first, then group by row_id
-        if you want to get i-th row label/pred in j-th class, the access way is label/pred[j*num_data+i]
+        for multi-class task, the y_pred is group by class_id first, then group by row_id
+        if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i]
         and you should group grad and hess in this way as well
     """

     def __init__(self, num_leaves=31, max_depth=-1,
                  learning_rate=0.1, n_estimators=10, max_bin=255,
                  silent=True, objective="regression",
-                 nthread=-1, gamma=0, min_child_weight=5, min_data=10,
-                 subsample=1, subsample_freq=1, colsample_bytree=1, colsample_byleaf=1,
+                 nthread=-1, min_split_gain=0, min_child_weight=5, min_child_samples=10,
+                 subsample=1, subsample_freq=1, colsample_bytree=1,
                  reg_alpha=0, reg_lambda=0, scale_pos_weight=1,
                  is_unbalance=False, seed=0):
         if not SKLEARN_INSTALLED:
@@ -141,13 +157,12 @@ class LGBMModel(LGBMModelBase):
         self.silent = silent
         self.objective = objective
         self.nthread = nthread
-        self.gamma = gamma
+        self.min_split_gain = min_split_gain
         self.min_child_weight = min_child_weight
-        self.min_data = min_data
+        self.min_child_samples = min_child_samples
         self.subsample = subsample
         self.subsample_freq = subsample_freq
         self.colsample_bytree = colsample_bytree
-        self.colsample_byleaf = colsample_byleaf
         self.reg_alpha = reg_alpha
         self.reg_lambda = reg_lambda
         self.scale_pos_weight = scale_pos_weight
@@ -214,7 +229,7 @@ class LGBMModel(LGBMModelBase):
             params.update(other_params)

         if callable(self.objective):
-            fobj = _objective_decorator(self.objective)
+            fobj = _point_wise_objective(self.objective)
             params["objective"] = "None"
         else:
             params["objective"] = self.objective
diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index 66a898139..cb2b1d0dc 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -15,7 +15,8 @@ bst.add_valid(valid_data,"valid_1")

 for i in range(100):
     bst.update()
-    print(bst.eval_train())
-    print(bst.eval_valid())
+    if i % 10 == 0:
+        print(bst.eval_train())
+        print(bst.eval_valid())

 bst.save_model("model.txt")

From 80a52ad443cf05af5d200bf64382791ec9101200 Mon Sep 17 00:00:00 2001
From: Guolin Ke
Date: Wed, 30 Nov 2016 16:00:59 +0800
Subject: [PATCH 53/60] add LGBMRanker

---
 python-package/lightgbm/__init__.py |   2 +-
 python-package/lightgbm/sklearn.py  | 110 +++++++++++++++++++++++-----
 2 files changed, 94 insertions(+), 18 deletions(-)

diff --git a/python-package/lightgbm/__init__.py b/python-package/lightgbm/__init__.py
index cb18af4d6..105f63b61 100644
--- a/python-package/lightgbm/__init__.py
+++ b/python-package/lightgbm/__init__.py
@@ -20,4 +20,4 @@ __version__ = 0.1

 __all__ = ['Dataset', 'Booster', 'train', 'cv',
-           'LGBMModel','LGBMRegressor', 'LGBMClassifier']
\ No newline at end of file
+           'LGBMModel','LGBMRegressor', 'LGBMClassifier', 'LGBMRanker']
\ No newline at end of file
diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index 4384b9f4b..55f7f1c98 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -23,9 +23,6 @@ except ImportError:

 def _point_wise_objective(func):
     """Decorate an objective function
-
-    Converts an objective function using the typical sklearn metrics
-    to LightGBM fobj
     Note: for multi-class task, the y_pred is group by class_id first, then group by row_id
     if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i]
     and you should group grad and hess in this way as well
@@ -36,7 +33,7 @@ def _point_wise_objective(func):
         y_true: array_like of shape [n_samples]
             The target values
-        y_pred: array_like of shape [n_samples] or shape[n_samples* n_class]
+        y_pred: array_like of shape [n_samples] or shape[n_samples* n_class] (for multi-class)
             The predicted values
@@ -66,6 +63,8 @@ def _point_wise_objective(func):
             else:
                 num_data = len(weight)
                 num_class = len(grad) // num_data
+                if num_class * num_data != len(grad):
+                    raise ValueError("lenght of grad and hess should equal with num_class * num_data")
                 for k in range(num_class):
                     for i in range(num_data):
                         idx = k * num_data + i
@@ -74,7 +73,6 @@ def _point_wise_objective(func):
         return grad, hess
     return inner
-
 class LGBMModel(LGBMModelBase):
     """Implementation of the Scikit-Learn API for LightGBM.
@@ -169,6 +167,10 @@ class LGBMModel(LGBMModelBase):
         self.is_unbalance = is_unbalance
         self.seed = seed
         self._Booster = None
+        if callable(self.objective):
+            self.fobj = _point_wise_objective(self.objective)
+        else:
+            self.fobj = None

     def booster(self):
         """Get the underlying lightgbm Booster of this model.
@@ -205,11 +207,11 @@ class LGBMModel(LGBMModelBase):
         eval_set : list, optional
             A list of (X, y) tuple pairs to use as a validation set for
             early-stopping
         eval_metric : str, list of str, callable, optional
-            If a str, should be a built-in evaluation metric to use. See
-            doc/parameter.md. If callable, a custom evaluation metric. The call
-            signature is func(y_predicted, y_true) where y_true will be a
+            If a str, should be a built-in evaluation metric to use.
+            If callable, a custom evaluation metric. The call
+            signature is func(y_predicted, dataset) where dataset will be a
             Dataset object such that you may need to call the get_label
-            method.
+            method. And it must return (eval_name->str, eval_result->float, is_bigger_better->Bool)
         early_stopping_rounds : int
         verbose : bool
             If `verbose` and an evaluation set is used, writes the evaluation
@@ -228,12 +230,11 @@ class LGBMModel(LGBMModelBase):
         if other_params is not None:
             params.update(other_params)

-        if callable(self.objective):
-            fobj = _point_wise_objective(self.objective)
+        if self.fobj:
             params["objective"] = "None"
         else:
             params["objective"] = self.objective
-            fobj = None
+
         if callable(eval_metric):
             feval = eval_metric
         elif is_str(eval_metric) or isinstance(eval_metric, list):
@@ -246,7 +247,7 @@ class LGBMModel(LGBMModelBase):
         self._Booster = train(params, (X, y), self.n_estimators,
                               valid_datas=eval_set,
                               early_stopping_rounds=early_stopping_rounds,
-                              evals_result=evals_result, fobj=fobj, feval=feval,
+                              evals_result=evals_result, fobj=self.fobj, feval=feval,
                               verbose_eval=verbose, train_fields=train_fields,
                               valid_fields=valid_fields)
         if evals_result:
@@ -316,12 +317,10 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
             other_params = {}
         if self.n_classes_ > 2:
             # Switch to using a multiclass objective in the underlying LGBM instance
-            if not callable(self.objective):
-                self.objective = "multiclass"
+            self.objective = "multiclass"
             other_params['num_class'] = self.n_classes_
         else:
-            if not callable(self.objective):
-                self.objective = "binary"
+            self.objective = "binary"

         self._le = LGBMLabelEncoder().fit(y)
         training_labels = self._le.transform(y)
@@ -355,3 +354,80 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
             classzero_probs = 1.0 - classone_probs
             return np.vstack((classzero_probs, classone_probs)).transpose()
+
+def _group_wise_objective(func):
+    """Decorate an objective function
+    Parameters
+    ----------
+    func: callable
+        Expects a callable with signature ``func(y_true, group, y_pred)``:
+
+        y_true: array_like of shape [n_samples]
+            The target values
+        group : array_like
+            group size of data
+        y_pred: array_like of shape [n_samples] or shape[n_samples* n_class] (for multi-class)
+            The predicted values
+    Returns
+    -------
+    new_func: callable
+        The new objective function as expected by ``lightgbm.engine.train``.
+        The signature is ``new_func(preds, dataset)``:
+
+        preds: array_like, shape [n_samples] or shape[n_samples* n_class]
+            The predicted values
+        dataset: ``dataset``
+            The training set from which the labels will be extracted using
+            ``dataset.get_label()``
+    """
+    def inner(preds, dataset):
+        """internal function"""
+        labels = dataset.get_label()
+        group = dataset.get_group()
+        if group is None:
+            raise ValueError("group should not be None for ranking task")
+        grad, hess = func(labels, group, preds)
+        """weighted for objective"""
+        weight = dataset.get_weight()
+        if weight is not None:
+            """only one class"""
+            if len(weight) == len(grad):
+                grad = np.multiply(grad, weight)
+                hess = np.multiply(hess, weight)
+            else:
+                raise ValueError("length of grad and hess should equal with num_data")
+        return grad, hess
+    return inner
+
+class LGBMRanker(LGBMModel):
+    __doc__ = """Implementation of the scikit-learn API for LightGBM ranking application.
+ + """ + '\n'.join(LGBMModel.__doc__.split('\n')[2:]) + + def fit(self, X, y, eval_set=None, eval_metric=None, + early_stopping_rounds=None, verbose=True, + train_fields=None, valid_fields=None, other_params=None): + + """check group data""" + if "group" not in train_fields: + raise ValueError("should set group in train_fields for ranking task") + + if eval_set is not None: + if valid_fields is None: + raise ValueError("valid_fields cannot be None when eval_set is not None") + elif len(valid_fields) != len(eval_set): + raise ValueError("lenght of valid_fields should equal with eval_set") + else: + for inner in valid_fields: + if "group" not in inner: + raise ValueError("should set group in valid_fields for ranking task") + + if callable(self.objective): + self.fobj = _group_wise_objective(self.objective) + else: + self.objective = "lambdarank" + self.fobj = None + + super(LGBMRanker, self).fit(X, y, eval_set, eval_metric, + early_stopping_rounds, verbose, train_fields, valid_fields, other_params) + return self From 673e3ea290edf8e7c5b2a4425f3672c7aa0f2bf2 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Wed, 30 Nov 2016 16:48:26 +0800 Subject: [PATCH 54/60] add import LGBMRanker in __init__.py --- python-package/lightgbm/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/lightgbm/__init__.py b/python-package/lightgbm/__init__.py index 105f63b61..cf5a5d7b5 100644 --- a/python-package/lightgbm/__init__.py +++ b/python-package/lightgbm/__init__.py @@ -11,7 +11,7 @@ import os from .basic import Predictor, Dataset, Booster from .engine import train, cv try: - from .sklearn import LGBMModel, LGBMRegressor, LGBMClassifier + from .sklearn import LGBMModel, LGBMRegressor, LGBMClassifier, LGBMRanker except ImportError: pass From 1a8c23ed30d0d27ddc9ef02f27357adcffd7f67d Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Thu, 1 Dec 2016 15:16:18 +0800 Subject: [PATCH 55/60] some typo --- python-package/lightgbm/basic.py | 4 ++-- python-package/lightgbm/callback.py | 2 +- python-package/lightgbm/sklearn.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 5afd75935..bc39c5871 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -19,7 +19,7 @@ def _load_lib(): """Load LightGBM Library.""" lib_path = find_lib_path() if len(lib_path) == 0: - return None + raise Exception("cannot find LightGBM library") lib = ctypes.cdll.LoadLibrary(lib_path[0]) lib.LGBM_GetLastError.restype = ctypes.c_char_p return lib @@ -1034,7 +1034,7 @@ class Booster(object): if data is self.valid_sets[i]: data_idx = i + 1 break - """need push new valid data""" + """need to push new valid data""" if data_idx == -1: self.add_valid(data, name) data_idx = self.__num_dataset - 1 diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 555a9a5f8..b6861ee0c 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -130,7 +130,7 @@ def reset_learning_rate(learning_rates): def early_stop(stopping_rounds, verbose=True): - """Create a callback that activates early stoppping. + """Create a callback that activates early stopping. Activates early stopping. 
     Requires at least one validation data and one metric
     If there's more than one, will check all of them
diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index 55f7f1c98..d66fd8392 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -64,7 +64,7 @@ def _point_wise_objective(func):
             num_data = len(weight)
             num_class = len(grad) // num_data
             if num_class * num_data != len(grad):
-                raise ValueError("lenght of grad and hess should equal with num_class * num_data")
+                raise ValueError("length of grad and hess should equal with num_class * num_data")
             for k in range(num_class):
                 for i in range(num_data):
                     idx = k * num_data + i
@@ -81,7 +81,7 @@ class LGBMModel(LGBMModelBase):
     num_leaves : int
         Maximum tree leaves for base learners.
     max_depth : int
-        Maximum tree depth for base learners, -1 means not limit.
+        Maximum tree depth for base learners, -1 means no limit.
     learning_rate : float
         Boosting learning rate
     n_estimators : int

From 5b53978891e94fb95a038e62d8f36317548508af Mon Sep 17 00:00:00 2001
From: Guolin Ke
Date: Thu, 1 Dec 2016 16:17:18 +0800
Subject: [PATCH 56/60] fix some pep8 check

---
 python-package/lightgbm/basic.py    | 177 +++++++++++++++------------
 python-package/lightgbm/callback.py |   7 +-
 python-package/lightgbm/engine.py   |  70 +++++------
 python-package/lightgbm/libpath.py  |   6 +-
 python-package/lightgbm/sklearn.py  |  19 +--
 5 files changed, 146 insertions(+), 133 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index bc39c5871..f8abe2916 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -4,8 +4,6 @@ from __future__ import absolute_import
 import sys
 import os
 import ctypes
-import collections
-import re
 import tempfile

 import numpy as np
@@ -59,7 +57,7 @@ def is_1d_list(data):
     if not isinstance(data, list):
         return False
     if len(data) > 0:
-        if not isinstance(data[0], (int, float, bool) ):
+        if not isinstance(data[0], (int, float, bool)):
             return False
     return True
@@ -108,29 +106,29 @@ def param_dict_to_str(data):
         if is_str(val):
             pairs.append(str(key)+'='+str(val))
         elif isinstance(val, (list, tuple)):
-            pairs.append(str(key)+'='+','.join(map(str,val)))
+            pairs.append(str(key)+'='+','.join(map(str, val)))
         elif isinstance(val, (int, float, bool)):
             pairs.append(str(key)+'='+str(val))
         else:
-            raise TypeError('unknow type of parameter:%s , got:%s' %(key, type(val).__name__))
+            raise TypeError('unknow type of parameter:%s , got:%s'
+                            % (key, type(val).__name__))
     return ' '.join(pairs)

 """marco definition of data type in c_api of LightGBM"""
-C_API_DTYPE_FLOAT32 =0
-C_API_DTYPE_FLOAT64 =1
-C_API_DTYPE_INT32 =2
-C_API_DTYPE_INT64 =3
+C_API_DTYPE_FLOAT32 = 0
+C_API_DTYPE_FLOAT64 = 1
+C_API_DTYPE_INT32 = 2
+C_API_DTYPE_INT64 = 3
 """Matric is row major in python"""
-C_API_IS_ROW_MAJOR =1
+C_API_IS_ROW_MAJOR = 1

-C_API_PREDICT_NORMAL =0
-C_API_PREDICT_RAW_SCORE =1
-C_API_PREDICT_LEAF_INDEX =2
+C_API_PREDICT_NORMAL = 0
+C_API_PREDICT_RAW_SCORE = 1
+C_API_PREDICT_LEAF_INDEX = 2

-FIELD_TYPE_MAPPER = {"label":C_API_DTYPE_FLOAT32,
-"weight":C_API_DTYPE_FLOAT32,
-"init_score":C_API_DTYPE_FLOAT32,
-"group":C_API_DTYPE_INT32,
- }
+FIELD_TYPE_MAPPER = {"label": C_API_DTYPE_FLOAT32,
+                     "weight": C_API_DTYPE_FLOAT32,
+                     "init_score": C_API_DTYPE_FLOAT32,
+                     "group": C_API_DTYPE_INT32}

 def c_float_array(data):
     """Convert numpy array / list to c float array."""
@@ -144,7 +142,8 @@ def c_float_array(data):
             ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
             type_data = C_API_DTYPE_FLOAT64
         else:
-            raise TypeError("expected np.float32 or np.float64, met type({})".format(data.dtype))
+            raise TypeError("expected np.float32 or np.float64, met type({})"
+                            .format(data.dtype))
     else:
         raise TypeError("Unknow type({})".format(type(data).__name__))
     return (ptr_data, type_data)
@@ -161,7 +160,8 @@ def c_int_array(data):
             ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int64))
             type_data = C_API_DTYPE_INT64
         else:
-            raise TypeError("expected np.int32 or np.int64, met type({})".format(data.dtype))
+            raise TypeError("expected np.int32 or np.int64, met type({})"
+                            .format(data.dtype))
     else:
         raise TypeError("Unknow type({})".format(type(data).__name__))
     return (ptr_data, type_data)
@@ -169,13 +169,13 @@ def c_int_array(data):

 class Predictor(object):
     """"A Predictor of LightGBM.
     """
-    def __init__(self,model_file=None, booster_handle=None, is_manage_handle=True):
+    def __init__(self, model_file=None, booster_handle=None, is_manage_handle=True):
         """Initialize the Predictor.

         Parameters
         ----------
         model_file : string
-            Path to the model file.
+            Path to the model file.
         """
         self.handle = ctypes.c_void_p()
         self.__is_manage_handle = True
@@ -191,7 +191,7 @@ class Predictor(object):
                 self.handle,
                 ctypes.byref(out_num_class)))
             self.num_class = out_num_class.value
-            self.__num_total_iteration = out_num_iterations.value
+            self.__num_total_iteration = out_num_iterations.value
         elif booster_handle is not None:
             self.__is_manage_handle = is_manage_handle
             self.handle = booster_handle
@@ -204,7 +204,7 @@ class Predictor(object):
             _safe_call(_LIB.LGBM_BoosterGetCurrentIteration(
                 self.handle,
                 ctypes.byref(out_num_iterations)))
-            self.__num_total_iteration = out_num_iterations.value
+            self.__num_total_iteration = out_num_iterations.value
         else:
             raise TypeError('Need Model file to create a booster')
@@ -213,7 +213,9 @@ class Predictor(object):
             _safe_call(_LIB.LGBM_BoosterFree(self.handle))

-    def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True):
+    def predict(self, data, num_iteration=-1,
+                raw_score=False, pred_leaf=False, data_has_header=False,
+                is_reshape=True):
         """
         Predict logic

@@ -222,23 +224,24 @@ class Predictor(object):
         data : string/numpy array/scipy.sparse
             Data source for prediction
             When data is string type, it represents the path of txt file,
-        num_iteration :
+        num_iteration : int
             used iteration for prediction
-        raw_score : bool
+        raw_score : bool
             True for predict raw score
         pred_leaf : bool
             True for predict leaf index
         data_has_header : bool
             Used for txt data
         is_reshape : bool
-            True for reshape to [nrow, ...]
+            True for reshape to [nrow, ...]
         Returns
         -------
         Prediction result
         """
         if isinstance(data, Dataset):
-            raise TypeError("cannot use Dataset instance for prediction, please use raw data instead")
+            raise TypeError("cannot use Dataset instance for prediction, \
+                please use raw data instead")
         predict_type = C_API_PREDICT_NORMAL
         if raw_score:
             predict_type = C_API_PREDICT_RAW_SCORE
@@ -251,12 +254,12 @@ class Predictor(object):
             tmp_pred_fname = tempfile.NamedTemporaryFile(prefix="lightgbm_tmp_pred_").name
             _safe_call(_LIB.LGBM_BoosterPredictForFile(
                 self.handle,
-                c_str(data),
+                c_str(data),
                 int_data_has_header,
                 predict_type,
                 num_iteration,
                 c_str(tmp_pred_fname)))
-            tmp_file = open(tmp_pred_fname,"r")
+            tmp_file = open(tmp_pred_fname, "r")
             lines = tmp_file.readlines()
             tmp_file.close()
             nrow = len(lines)
@@ -267,15 +270,19 @@ class Predictor(object):
             preds = np.array(preds, copy=False)
             os.remove(tmp_pred_fname)
         elif isinstance(data, scipy.sparse.csr_matrix):
-            preds, nrow = self.__pred_for_csr(data, num_iteration, predict_type)
+            preds, nrow = self.__pred_for_csr(data, num_iteration,
+                                              predict_type)
         elif isinstance(data, np.ndarray):
-            preds, nrow = self.__pred_for_np2d(data, num_iteration, predict_type)
+            preds, nrow = self.__pred_for_np2d(data, num_iteration,
+                                               predict_type)
         else:
             try:
                 csr = scipy.sparse.csr_matrix(data)
-                preds, nrow = self.__pred_for_csr(csr, num_iteration, predict_type)
+                preds, nrow = self.__pred_for_csr(csr, num_iteration,
+                                                  predict_type)
             except:
-                raise TypeError('can not predict data for type {}'.format(type(data).__name__))
+                raise TypeError('can not predict data for type {}'.
+                                format(type(data).__name__))
         if pred_leaf:
             preds = preds.astype(np.int32)
         if preds.size != nrow and is_reshape:
@@ -283,7 +290,8 @@ class Predictor(object):
                 ncol = int(preds.size / nrow)
                 preds = preds.reshape(nrow, ncol)
             else:
-                raise ValueError('len of predict result(%d) cannot be divide nrow(%d)' %(preds.size, nrow) )
+                raise ValueError('len of predict result(%d) cannot be divide nrow (%d)'
+                                 % (preds.size, nrow))
         return preds

     def __get_num_preds(self, num_iteration, nrow, predict_type):
@@ -308,12 +316,13 @@ class Predictor(object):
             """change non-float data to float data, need to copy"""
             data = np.array(mat.reshape(mat.size), dtype=np.float32)
         ptr_data, type_ptr_data = c_float_array(data)
-        n_preds = self.__get_num_preds(num_iteration, mat.shape[0], predict_type)
+        n_preds = self.__get_num_preds(num_iteration, mat.shape[0],
+                                       predict_type)
         preds = np.zeros(n_preds, dtype=np.float32)
         out_num_preds = ctypes.c_int64(0)
         _safe_call(_LIB.LGBM_BoosterPredictForMat(
             self.handle,
-            ptr_data,
+            ptr_data,
             type_ptr_data,
             mat.shape[0],
             mat.shape[1],
@@ -341,12 +350,12 @@ class Predictor(object):

         _safe_call(_LIB.LGBM_BoosterPredictForCSR(
             self.handle,
-            ptr_indptr,
+            ptr_indptr,
             type_ptr_indptr,
             csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
             ptr_data,
-            type_ptr_data,
-            len(csr.indptr),
+            type_ptr_data,
+            len(csr.indptr),
             len(csr.data),
             csr.shape[1],
             predict_type,
@@ -365,10 +374,10 @@ except ImportError:
     class DataFrame(object):
         pass

-PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
-                       'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
-                       'float16': 'float', 'float32': 'float', 'float64': 'float',
-                       'bool': 'i'}
+PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int',
+                       'int64': 'int', 'uint8': 'int', 'uint16': 'int',
+                       'uint32': 'int', 'uint64': 'int', 'float16': 'float',
+                       'float32': 'float', 'float64': 'float', 'bool': 'i'}

 def _data_from_pandas(data):
     if isinstance(data, DataFrame):
@@ -399,8 +408,8 @@ class Dataset(object):
     """

     def __init__(self, data, label=None, max_bin=255, reference=None,
-        weight=None, group=None, predictor=None,
-        silent=False, params=None):
+                 weight=None, group=None, predictor=None,
+                 silent=False, params=None):
         """
         Dataset used in LightGBM.
@@ -412,7 +421,7 @@ class Dataset(object):
         label : list or numpy 1-D array, optional
             Label of the data
         max_bin : int, required
-            max number of discrete bin for features
+            max number of discrete bin for features
         reference : Other Dataset, optional
             If this dataset validation, need to use training data as reference
         weight : list or numpy 1-D array , optional
@@ -482,10 +491,10 @@ class Dataset(object):
             self.set_group(group)
         # load init score
         if self.predictor is not None and isinstance(self.predictor, Predictor):
-            init_score = self.predictor.predict(data,
-                raw_score=True,
-                data_has_header=self.data_has_header,
-                is_reshape=False)
+            init_score = self.predictor.predict(data,
+                                                raw_score=True,
+                                                data_has_header=self.data_has_header,
+                                                is_reshape=False)
             if self.predictor.num_class > 1:
                 # need re group init score
                 new_init_score = np.zeros(init_score.size(), dtype=np.float32)
@@ -496,8 +505,8 @@ class Dataset(object):
                 init_score = new_init_score
             self.set_init_score(init_score)

-    def create_valid(self, data, label=None, weight=None, group=None,
-        silent=False, params=None):
+    def create_valid(self, data, label=None, weight=None, group=None,
+                     silent=False, params=None):
         """
         Create validation data align with current dataset
@@ -518,8 +527,8 @@ class Dataset(object):
             other parameters
         """
         return Dataset(data, label=label, max_bin=self.max_bin, reference=self,
-            weight=weight, group=group, predictor=self.predictor,
-            silent=silent, params=params)
+                       weight=weight, group=group, predictor=self.predictor,
+                       silent=silent, params=params)

     def subset(self, used_indices, params=None):
         """
@@ -530,10 +539,10 @@ class Dataset(object):
         ret.handle = ctypes.c_void_p()
         params_str = param_dict_to_str(params)
         _safe_call(_LIB.LGBM_DatasetGetSubset(
-            ctypes.byref(self.handle),
+            ctypes.byref(self.handle),
             used_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
             used_indices.shape[0],
-            c_str(params_str),
+            c_str(params_str),
             ctypes.byref(ret.handle)))
         ret.max_bin = self.max_bin
         ret.predictor = self.predictor
@@ -557,13 +566,13 @@ class Dataset(object):
         ptr_data, type_ptr_data = c_float_array(data)
         _safe_call(_LIB.LGBM_DatasetCreateFromMat(
-            ptr_data,
+            ptr_data,
             type_ptr_data,
             mat.shape[0],
             mat.shape[1],
             C_API_IS_ROW_MAJOR,
-            c_str(params_str),
-            ref_dataset,
+            c_str(params_str),
+            ref_dataset,
             ctypes.byref(self.handle)))

     def __init_from_csr(self, csr, params_str, ref_dataset):
@@ -578,16 +587,16 @@ class Dataset(object):
         ptr_data, type_ptr_data = c_float_array(csr.data)

         _safe_call(_LIB.LGBM_DatasetCreateFromCSR(
-            ptr_indptr,
+            ptr_indptr,
             type_ptr_indptr,
             csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
             ptr_data,
-            type_ptr_data,
-            len(csr.indptr),
+            type_ptr_data,
+            len(csr.indptr),
             len(csr.data),
-            csr.shape[1],
-            c_str(params_str),
-            ref_dataset,
+            csr.shape[1],
+            c_str(params_str),
+            ref_dataset,
             ctypes.byref(self.handle)))

     def __del__(self):
@@ -784,7 +793,7 @@ class Dataset(object):
         """
         ret = ctypes.c_int64()
         _safe_call(_LIB.LGBM_DatasetGetNumData(self.handle,
-            ctypes.byref(ret)))
+                                               ctypes.byref(ret)))
         return ret.value

     def num_feature(self):
@@ -796,7 +805,7 @@ class Dataset(object):
         """
         ret = ctypes.c_int64()
         _safe_call(_LIB.LGBM_DatasetGetNumFeature(self.handle,
-            ctypes.byref(ret)))
+                                                  ctypes.byref(ret)))
         return ret.value

 class Booster(object):
@@ -812,7 +821,7 @@ class Booster(object):
         train_set : Dataset
             training dataset
         model_file : string
-            Path to the model file.
+            Path to the model file.
         silent : boolean, optional
             Whether print messages during construction
         """
@@ -833,7 +842,7 @@ class Booster(object):
             params_str = param_dict_to_str(params)
             """construct booster object"""
             _safe_call(_LIB.LGBM_BoosterCreate(
-                train_set.handle,
+                train_set.handle,
                 c_str(params_str),
                 ctypes.byref(self.handle)))
             """save reference to data"""
@@ -859,7 +868,7 @@ class Booster(object):
             """Prediction task"""
             out_num_iterations = ctypes.c_int64(0)
             _safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
-                c_str(model_file),
+                c_str(model_file),
                 ctypes.byref(out_num_iterations),
                 ctypes.byref(self.handle)))
             out_num_class = ctypes.c_int64(0)
@@ -939,13 +948,13 @@ class Booster(object):
                 raise Exception("Replace training data failed, you should use same predictor for these data")
             self.train_set = train_set
             _safe_call(_LIB.LGBM_BoosterResetTrainingData(
-                self.handle,
+                self.handle,
                 self.train_set.handle))
             self.__inner_predict_buffer[0] = None
         is_finished = ctypes.c_int(0)
         if fobj is None:
             _safe_call(_LIB.LGBM_BoosterUpdateOneIter(
-                self.handle,
+                self.handle,
                 ctypes.byref(is_finished)))
             self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
             return is_finished.value == 1
@@ -1080,7 +1089,7 @@ class Booster(object):
         Parameters
         ----------
         filename : str
-            filename to save
+            filename to save
         num_iteration: int
             number of iteration that want to save. < 0 means save all
         """
@@ -1098,16 +1107,16 @@ class Booster(object):
         data : string/numpy array/scipy.sparse
             Data source for prediction
             When data is string type, it represents the path of txt file,
-        num_iteration :
+        num_iteration : int
            used iteration for prediction
-        raw_score : bool
+        raw_score : bool
            True for predict raw score
        pred_leaf : bool
            True for predict leaf index
        data_has_header : bool
            Used for txt data
        is_reshape : bool
-            True for reshape to [nrow, ...]
+            True for reshape to [nrow, ...]

        Returns
        -------
@@ -1136,8 +1145,8 @@ class Booster(object):
         result = np.array([0.0 for _ in range(self.__num_inner_eval)], dtype=np.float32)
         tmp_out_len = ctypes.c_int64(0)
         _safe_call(_LIB.LGBM_BoosterGetEval(
-            self.handle,
-            data_idx,
+            self.handle,
+            data_idx,
             ctypes.byref(tmp_out_len),
             result.ctypes.data_as(ctypes.POINTER(ctypes.c_float))))
         if tmp_out_len.value != self.__num_inner_eval:
@@ -1176,12 +1185,12 @@ class Booster(object):
         tmp_out_len = ctypes.c_int64(0)
         data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_float))
         _safe_call(_LIB.LGBM_BoosterGetPredict(
-            self.handle,
-            data_idx,
-            ctypes.byref(tmp_out_len),
+            self.handle,
+            data_idx,
+            ctypes.byref(tmp_out_len),
             data_ptr))
         if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]):
-            raise ValueError("incorrect number of predict results for data %d" %(data_idx) )
+            raise ValueError("incorrect number of predict results for data %d" % (data_idx) )
         self.__is_predicted_cur_iter[data_idx] = True
         return self.__inner_predict_buffer[data_idx]
diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py
index b6861ee0c..8dae4181f 100644
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -148,7 +148,6 @@ def early_stop(stopping_rounds, verbose=True):
     callback : function
         The requested callback function.
""" - state = {} factor_to_bigger_better = {} best_score = {} best_iter = {} @@ -172,23 +171,21 @@ def early_stop(stopping_rounds, verbose=True): factor_to_bigger_better[i] = -1.0 if env.evaluation_result_list[i][3]: factor_to_bigger_better[i] = 1.0 - state['best_iter'] = 0 def callback(env): """internal function""" if len(best_score) == 0: init(env) - for i in range(len(env.evaluation_result_list)): + for i in range(len(env.evaluation_result_list)): score = env.evaluation_result_list[i][2] * factor_to_bigger_better[i] if score > best_score[i]: best_score[i] = score best_iter[i] = env.iteration if verbose: - best_msg[i] = '[%d]\t%s' % ( env.iteration, + best_msg[i] = '[%d]\t%s' % ( env.iteration, '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list])) else: if env.iteration - best_iter[i] >= stopping_rounds: - state['best_iter'] = best_iter[i] if env.model is not None: env.model.set_attr(best_iteration=str(best_iter[i])) if verbose: diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 262f0f8b0..81fb00e47 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -1,13 +1,13 @@ """Training Library containing training routines of LightGBM.""" from __future__ import absolute_import -import collections import numpy as np from .basic import LightGBMError, Predictor, Dataset, Booster, is_str from . import callback def _construct_dataset(X_y, reference=None, - params=None, other_fields=None, predictor=None): + params=None, other_fields=None, + predictor=None): if 'max_bin' in params: max_bin = int(params['max_bin']) else: @@ -31,20 +31,22 @@ def _construct_dataset(X_y, reference=None, label = X_y[1] if reference is None: ret = Dataset(data, label=label, max_bin=max_bin, - weight=weight, group=group, predictor=predictor, params=params) + weight=weight, group=group, + predictor=predictor, params=params) else: - ret = reference.create_valid(data, label=label, weight=weight, group=group, params=params) + ret = reference.create_valid(data, label=label, weight=weight, + group=group, params=params) if init_score is not None: ret.set_init_score(init_score) return ret -def train(params, train_data, num_boost_round=100, - valid_datas=None, valid_names=None, - fobj=None, feval=None, init_model=None, - train_fields=None, valid_fields=None, - early_stopping_rounds=None, evals_result=None, - verbose_eval=True, learning_rates=None, callbacks=None): +def train(params, train_data, num_boost_round=100, + valid_datas=None, valid_names=None, + fobj=None, feval=None, init_model=None, + train_fields=None, valid_fields=None, + early_stopping_rounds=None, evals_result=None, + verbose_eval=True, learning_rates=None, callbacks=None): """Train with given parameters. 
     Parameters
@@ -134,9 +136,9 @@ def train(params, train_data, num_boost_round=100,
             continue
         valid_set = _construct_dataset(
             valid_datas[i],
-            train_set,
-            params,
-            other_fields,
+            train_set,
+            params,
+            other_fields,
             predictor)
         valid_sets.append(valid_set)
         if valid_names is not None:
@@ -182,11 +184,11 @@ def train(params, train_data, num_boost_round=100,
     for i in range(num_boost_round):
         for cb in callbacks_before_iter:
             cb(callback.CallbackEnv(model=booster,
-                cvfolds=None,
-                iteration=i,
-                begin_iteration=0,
-                end_iteration=num_boost_round,
-                evaluation_result_list=None))
+                                    cvfolds=None,
+                                    iteration=i,
+                                    begin_iteration=0,
+                                    end_iteration=num_boost_round,
+                                    evaluation_result_list=None))

         booster.update(fobj=fobj)

@@ -199,11 +201,11 @@ def train(params, train_data, num_boost_round=100,
         try:
             for cb in callbacks_after_iter:
                 cb(callback.CallbackEnv(model=booster,
-                    cvfolds=None,
-                    iteration=i,
-                    begin_iteration=0,
-                    end_iteration=num_boost_round,
-                    evaluation_result_list=evaluation_result_list))
+                                        cvfolds=None,
+                                        iteration=i,
+                                        begin_iteration=0,
+                                        end_iteration=num_boost_round,
+                                        evaluation_result_list=evaluation_result_list))
         except callback.EarlyStopException:
             break
     if booster.attr('best_iteration') is not None:
@@ -384,11 +386,11 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
     for i in range(num_boost_round):
         for cb in callbacks_before_iter:
             cb(callback.CallbackEnv(model=None,
-                cvfolds=cvfolds,
-                iteration=i,
-                begin_iteration=0,
-                end_iteration=num_boost_round,
-                evaluation_result_list=None))
+                                    cvfolds=cvfolds,
+                                    iteration=i,
+                                    begin_iteration=0,
+                                    end_iteration=num_boost_round,
+                                    evaluation_result_list=None))
         for fold in cvfolds:
             fold.update(fobj)
         res = _agg_cv_result([f.eval(feval) for f in cvfolds])
@@ -402,13 +404,13 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
         try:
             for cb in callbacks_after_iter:
                 cb(callback.CallbackEnv(model=None,
-                    cvfolds=cvfolds,
-                    iteration=i,
-                    begin_iteration=0,
-                    end_iteration=num_boost_round,
-                    evaluation_result_list=res))
+                                        cvfolds=cvfolds,
+                                        iteration=i,
+                                        begin_iteration=0,
+                                        end_iteration=num_boost_round,
+                                        evaluation_result_list=res))
         except callback.EarlyStopException as e:
             for k in results.keys():
-                results[k] = results[k][:(e.state['best_iter'] + 1)]
+                results[k] = results[k][:(e.best_iteration + 1)]
             break
     return results
diff --git a/python-package/lightgbm/libpath.py b/python-package/lightgbm/libpath.py
index 9efd6cc74..1e980ed94 100644
--- a/python-package/lightgbm/libpath.py
+++ b/python-package/lightgbm/libpath.py
@@ -12,9 +12,9 @@ def find_lib_path():
     """
     curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
     dll_path = [curr_path, os.path.join(curr_path, '../../lib/'),
-        os.path.join(curr_path, '../../'),
-        os.path.join(curr_path, './lib/'),
-        os.path.join(sys.prefix, 'lightgbm')]
+                os.path.join(curr_path, '../../'),
+                os.path.join(curr_path, './lib/'),
+                os.path.join(sys.prefix, 'lightgbm')]
     if os.name == 'nt':
         dll_path.append(os.path.join(curr_path, '../../windows/x64/Dll/'))
         dll_path.append(os.path.join(curr_path, './windows/x64/Dll/'))
diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index d66fd8392..b2f66b8de 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -194,7 +194,8 @@ class LGBMModel(LGBMModelBase):
         return params

     def fit(self, X, y, eval_set=None, eval_metric=None,
-            early_stopping_rounds=None, verbose=True, train_fields=None, valid_fields=None, other_params=None):
+            early_stopping_rounds=None, verbose=True,
+            train_fields=None, valid_fields=None, other_params=None):
         """
         Fit the gradient boosting model

@@ -308,7 +309,7 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
     """ + '\n'.join(LGBMModel.__doc__.split('\n')[2:])

     def fit(self, X, y, eval_set=None, eval_metric=None,
-            early_stopping_rounds=None, verbose=True,
+            early_stopping_rounds=None, verbose=True,
             train_fields=None, valid_fields=None, other_params=None):

         self.classes_ = np.unique(y)
@@ -328,8 +329,10 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
         if eval_set is not None:
             eval_set = list( (x[0], self._le.transform(x[1])) for x in eval_set )

-        super(LGBMClassifier, self).fit(X, training_labels, eval_set, eval_metric,
-            early_stopping_rounds, verbose, train_fields, valid_fields, other_params)
+        super(LGBMClassifier, self).fit(X, training_labels, eval_set,
+                                        eval_metric, early_stopping_rounds,
+                                        verbose, train_fields, valid_fields,
+                                        other_params)
         return self

     def predict(self, data, raw_score=False, num_iteration=0):
@@ -405,7 +408,7 @@ class LGBMRanker(LGBMModel):
     """ + '\n'.join(LGBMModel.__doc__.split('\n')[2:])

     def fit(self, X, y, eval_set=None, eval_metric=None,
-            early_stopping_rounds=None, verbose=True,
+            early_stopping_rounds=None, verbose=True,
             train_fields=None, valid_fields=None, other_params=None):

         """check group data"""
@@ -428,6 +431,8 @@ class LGBMRanker(LGBMModel):
             self.objective = "lambdarank"
             self.fobj = None

-        super(LGBMRanker, self).fit(X, y, eval_set, eval_metric,
-            early_stopping_rounds, verbose, train_fields, valid_fields, other_params)
+        super(LGBMRanker, self).fit(X, y, eval_set, eval_metric,
+                                    early_stopping_rounds, verbose,
+                                    train_fields, valid_fields,
+                                    other_params)
         return self

From 29cf97e902d14a81f68a0c7103e67b26dee18efc Mon Sep 17 00:00:00 2001
From: Guolin Ke
Date: Thu, 1 Dec 2016 16:32:47 +0800
Subject: [PATCH 57/60] add example and readme

---
 examples/python-guide/simple_example.py | 10 ++++++++++
 python-package/README.rst               | 19 +++++++++++++++++++
 2 files changed, 29 insertions(+)
 create mode 100644 examples/python-guide/simple_example.py
 create mode 100644 python-package/README.rst

diff --git a/examples/python-guide/simple_example.py b/examples/python-guide/simple_example.py
new file mode 100644
index 000000000..ea7f99237
--- /dev/null
+++ b/examples/python-guide/simple_example.py
@@ -0,0 +1,10 @@
+import numpy as np
+import random
+import lightgbm as lgb
+from sklearn import datasets, metrics, model_selection
+
+rng = np.random.RandomState(2016)
+
+X, y = datasets.make_classification(n_samples=10000, n_features=100)
+x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1, random_state=1)
+lgb_model = lgb.LGBMClassifier(n_estimators=100).fit(x_train, y_train, [(x_test, y_test)], eval_metric="auc")
diff --git a/python-package/README.rst b/python-package/README.rst
new file mode 100644
index 000000000..1e8fffdb3
--- /dev/null
+++ b/python-package/README.rst
@@ -0,0 +1,19 @@
+LightGBM Python Package
+=======================
+
+Installation
+------------
+
+1. Follow the `Installation Guide <https://github.com/Microsoft/LightGBM/wiki/Installation-Guide>`__ to build first.
+   For Windows users, please change the build config to ``DLL``.
+2. Install with ``cd python-package; python setpy.py install``
+
+Note: Make sure you have `setuptools <https://pypi.python.org/pypi/setuptools>`__ installed.
+
+
+
+Examples
+--------
+
+- Refer also to the walk-through examples in the `python-guide
+  folder <https://github.com/Microsoft/LightGBM/tree/master/examples/python-guide>`__

From 9c3e2718a5345d2de57f6fd8fc28956fddd1f50e Mon Sep 17 00:00:00 2001
From: Guolin Ke
Date: Thu, 1 Dec 2016 19:14:22 +0800
Subject: [PATCH 58/60] Update simple_example.py

---
 examples/python-guide/simple_example.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/examples/python-guide/simple_example.py b/examples/python-guide/simple_example.py
index ea7f99237..0ce96b057 100644
--- a/examples/python-guide/simple_example.py
+++ b/examples/python-guide/simple_example.py
@@ -8,3 +8,10 @@ rng = np.random.RandomState(2016)
 X, y = datasets.make_classification(n_samples=10000, n_features=100)
 x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1, random_state=1)
 lgb_model = lgb.LGBMClassifier(n_estimators=100).fit(x_train, y_train, [(x_test, y_test)], eval_metric="auc")
+lgb_model.predict(x_test)
+# save model
+lgb_model.booster().save_model('model.txt')
+# load model
+booster = lgb.Booster(model_file='model.txt')
+# predict
+print(booster.predict(x_test))

From e2fd11ce75dfc69e5291b3efd2e352047a027370 Mon Sep 17 00:00:00 2001
From: Guolin Ke
Date: Fri, 2 Dec 2016 01:28:22 +0800
Subject: [PATCH 59/60] Update README.rst

---
 python-package/README.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python-package/README.rst b/python-package/README.rst
index 1e8fffdb3..d1a5af9a4 100644
--- a/python-package/README.rst
+++ b/python-package/README.rst
@@ -6,7 +6,7 @@ Installation

 1. Follow the `Installation Guide <https://github.com/Microsoft/LightGBM/wiki/Installation-Guide>`__ to build first.
    For Windows users, please change the build config to ``DLL``.
-2. Install with ``cd python-package; python setpy.py install``
+2. Install with ``cd python-package; python setup.py install``

 Note: Make sure you have `setuptools <https://pypi.python.org/pypi/setuptools>`__ installed.

From 6c1a74ab76d02a8d0565947ac25ea7bc1e8887dd Mon Sep 17 00:00:00 2001
From: Guolin Ke
Date: Fri, 2 Dec 2016 09:55:23 +0800
Subject: [PATCH 60/60] Update README.md

---
 README.md | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 0a86c107e..45e724fab 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
 LightGBM, Light Gradient Boosting Machine
-==========
+=========================================

 [![Build Status](https://travis-ci.org/Microsoft/LightGBM.svg?branch=master)](https://travis-ci.org/Microsoft/LightGBM)

 LightGBM is a gradient boosting framework that uses tree based learning algorithms. It is designed to be distributed and efficient with the following advantages:
@@ -14,6 +14,11 @@ For more details, please refer to [Features](https://github.com/Microsoft/LightG

 [Experiments](https://github.com/Microsoft/LightGBM/wiki/Experiments#comparison-experiment) on public datasets show that LightGBM can outperform other existing boosting framework on both efficiency and accuracy, with significant lower memory consumption. What's more, the [experiments](https://github.com/Microsoft/LightGBM/wiki/Experiments#parallel-experiment) show that LightGBM can achieve a linear speed-up by using multiple machines for training in specific settings.

+News
+----
+
+12/02/2016 : Release [python-package](https://github.com/Microsoft/LightGBM/tree/master/python-package) beta version; you are welcome to try it out and provide issues and feedback.
+
 Get Started
 ------------
 To get started, please follow the [Installation Guide](https://github.com/Microsoft/LightGBM/wiki/Installation-Guide) and [Quick Start](https://github.com/Microsoft/LightGBM/wiki/Quick-Start).
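
Taken together, patches 49-60 leave the Python package with a scikit-learn style workflow: build an estimator, fit with an eval_set so the early_stop callback can pick the best round, then persist the underlying Booster. The following sketch simply combines pieces already shown in test_early_stopping (patch 51) and simple_example.py (patch 58); it is an illustrative summary under those patches' API, not part of the patch series itself.

import lightgbm as lgb
from sklearn.datasets import load_boston
from sklearn import model_selection

# Same Boston housing data used by test_early_stopping in patch 51.
boston = load_boston()
X, y = boston['data'], boston['target']
x_train, x_test, y_train, y_test = model_selection.train_test_split(
    X, y, test_size=0.1, random_state=1)

# n_estimators is intentionally large; early stopping selects the best round.
lgb_model = lgb.LGBMRegressor(n_estimators=500).fit(
    x_train, y_train,
    eval_set=[(x_test, y_test)],
    eval_metric='l2',
    early_stopping_rounds=10,
    verbose=10)
print(lgb_model.best_iteration)

# Persist and reload through the low-level Booster, as in patch 58.
lgb_model.booster().save_model('model.txt')
booster = lgb.Booster(model_file='model.txt')
print(booster.predict(x_test))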
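The grouping convention documented by _point_wise_objective (patches 52-53) is the easiest part of this API to get wrong: for a multi-class task the flat prediction vector is class-major, so the score of row i for class j lives at y_pred[j * num_data + i], and the returned grad and hess must use the same layout. Below is a minimal sketch of a softmax-style custom objective written against that layout; the objective itself (including the conventional 2*p*(1-p) hessian approximation) is illustrative and not taken from the patches.

import numpy as np

def multiclass_softmax_objective(y_true, y_pred):
    # y_pred is class-major: the entry for row i, class j is y_pred[j * num_data + i].
    num_data = len(y_true)
    num_class = len(y_pred) // num_data
    scores = y_pred.reshape(num_class, num_data)   # row j = raw scores for class j
    prob = np.exp(scores - scores.max(axis=0))     # stabilized softmax over classes
    prob /= prob.sum(axis=0)
    grad = prob.copy()
    grad[y_true.astype(int), np.arange(num_data)] -= 1.0   # softmax gradient
    hess = 2.0 * prob * (1.0 - prob)
    # C-order flatten restores the class-major layout LightGBM expects.
    return grad.reshape(-1), hess.reshape(-1)

A callable like this can be passed as LGBMClassifier(objective=multiclass_softmax_objective); per patch 53, the wrapper routes it through _point_wise_objective, which also applies per-instance weights to grad and hess when the dataset carries them.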
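Finally, LGBMRanker (patch 53) enforces a contract that is easy to miss: train_fields must contain a "group" entry, and valid_fields must be a list with one dict per eval_set entry, each also containing "group". The sketch below shows that calling shape with synthetic data; the per-query group-size encoding and the 'ndcg' metric name are assumptions about what Dataset.set_group and the built-in metrics accept, since the patches themselves do not spell them out.

import numpy as np
import lightgbm as lgb

# Synthetic ranking data: 5 training queries and 2 validation queries,
# each with 20 documents; labels are relevance grades 0-3.
X_train = np.random.rand(100, 10)
y_train = np.random.randint(4, size=100)
q_train = [20] * 5                      # assumed per-query group sizes
X_valid = np.random.rand(40, 10)
y_valid = np.random.randint(4, size=40)
q_valid = [20] * 2

ranker = lgb.LGBMRanker(n_estimators=50).fit(
    X_train, y_train,
    train_fields={'group': q_train},    # required: "group" must be present
    eval_set=[(X_valid, y_valid)],
    valid_fields=[{'group': q_valid}],  # one dict per eval_set entry
    eval_metric='ndcg')

With no callable objective supplied, LGBMRanker.fit falls back to the built-in "lambdarank" objective, as set in patch 53.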