[docs][ci][python] added docstring style test and fixed errors in existing docstrings (#1759)

* added docstring style test and fixed errors in existing docstrings

* hotfix

* hotfix

* fix grammar

* hotfix
This commit is contained in:
Nikita Titov 2018-10-16 09:19:36 +03:00 коммит произвёл Qiwei Ye
Родитель dfdf88618e
Коммит ccf2570ca3
15 изменённых файлов: 453 добавлений и 297 удалений

Просмотреть файл

@ -45,8 +45,9 @@ if [[ $TRAVIS == "true" ]] && [[ $TASK == "check-docs" ]]; then
fi
if [[ $TASK == "pylint" ]]; then
conda install -y -n $CONDA_ENV pycodestyle
conda install -y -n $CONDA_ENV pycodestyle pydocstyle
pycodestyle --ignore=E501,W503 --exclude=./compute,./.nuget . || exit -1
pydocstyle --convention=numpy --add-ignore=D105 --match-dir="^(?!^compute|test|example).*" --match="(?!^test_|setup).*\.py" . || exit -1
exit 0
fi

Просмотреть файл

@ -1,3 +1,5 @@
# coding: utf-8
"""Script for generating files with NuGet package metadata."""
import os
import sys

Просмотреть файл

@ -16,7 +16,7 @@
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute.
"""Sphinx configuration file."""
import datetime
import os
import sys
@ -128,4 +128,11 @@ htmlhelp_basename = 'LightGBMdoc'
def setup(app):
"""Add new elements at Sphinx initialization time.
Parameters
----------
app : object
The application object representing the Sphinx process.
"""
app.add_javascript("js/script.js")

Просмотреть файл

@ -1,6 +1,7 @@
# coding: utf-8
# pylint: disable = invalid-name, C0111
'''
"""Comparison of `binary` and `xentropy` objectives.
BLUF: The `xentropy` objective does logistic regression and generalizes
to the case where labels are probabilistic (i.e. numbers between 0 and 1).
@ -9,7 +10,7 @@ Details: Both `binary` and `xentropy` minimize the log loss and use
between them with default settings is that `binary` may achieve a slight
speed improvement by assuming that the labels are binary instead of
probabilistic.
'''
"""
import time
@ -46,19 +47,28 @@ DATA = {
#################
# Set up a couple of utilities for our experiments
def log_loss(preds, labels):
''' logarithmic loss with non-necessarily-binary labels '''
"""Logarithmic loss with non-necessarily-binary labels."""
log_likelihood = np.sum(labels * np.log(preds)) / len(preds)
return -log_likelihood
def experiment(objective, label_type, data):
'''
Measure performance of an objective
:param objective: (str) 'binary' or 'xentropy'
:param label_type: (str) 'binary' or 'probability'
:param data: DATA
:return: dict with experiment summary stats
'''
"""Measure performance of an objective.
Parameters
----------
objective : string 'binary' or 'xentropy'
Objective function.
label_type : string 'binary' or 'probability'
Type of the label.
data : dict
Data for training.
Returns
-------
result : dict
Experiment summary stats.
"""
np.random.seed(0)
nrounds = 5
lgb_data = data['lgb_with_' + label_type + '_labels']

Просмотреть файл

@ -1,5 +1,7 @@
# coding: utf-8
"""This script generates LightGBM/src/io/config_auto.cpp file
"""Helper script for generating config file and parameters list.
This script generates LightGBM/src/io/config_auto.cpp file
with list of all parameters, aliases table and other routines
along with parameters description in LightGBM/docs/Parameters.rst file
from the information in LightGBM/include/LightGBM/config.h file.
@ -7,7 +9,19 @@ from the information in LightGBM/include/LightGBM/config.h file.
import os
def GetParameterInfos(config_hpp):
def get_parameter_infos(config_hpp):
"""Parse config header file.
Parameters
----------
config_hpp : string
Path to the config header file.
Returns
-------
infos : tuple
Tuple with names and content of sections.
"""
is_inparameter = False
parameter_group = None
cur_key = None
@ -63,7 +77,19 @@ def GetParameterInfos(config_hpp):
return keys, member_infos
def GetNames(infos):
def get_names(infos):
"""Get names of all parameters.
Parameters
----------
infos : list
Content of the config header file.
Returns
-------
names : list
Names of all parameters.
"""
names = []
for x in infos:
for y in x:
@ -71,7 +97,19 @@ def GetNames(infos):
return names
def GetAlias(infos):
def get_alias(infos):
"""Get aliases of all parameters.
Parameters
----------
infos : list
Content of the config header file.
Returns
-------
pairs : list
List of tuples (param alias, param name).
"""
pairs = []
for x in infos:
for y in x:
@ -83,7 +121,23 @@ def GetAlias(infos):
return pairs
def SetOneVarFromString(name, param_type, checks):
def set_one_var_from_string(name, param_type, checks):
"""Construct code for auto config file for one param value.
Parameters
----------
name : string
Name of the parameter.
param_type : string
Type of the parameter.
checks : list
Constraints of the parameter.
Returns
-------
ret : string
Lines of the auto config file that get and check the value of one parameter.
"""
ret = ""
univar_mapper = {"int": "GetInt", "double": "GetDouble", "bool": "GetBool", "std::string": "GetString"}
if "vector" not in param_type:
@ -103,9 +157,33 @@ def SetOneVarFromString(name, param_type, checks):
return ret
def GenParameterDescription(sections, descriptions, params_rst):
def gen_parameter_description(sections, descriptions, params_rst):
"""Write descriptions of parameters to the documentation file.
Parameters
----------
sections : list
Names of parameters sections.
descriptions : list
Structured descriptions of parameters.
params_rst : string
Path to the file with parameters documentation.
"""
def parse_check(check, reverse=False):
"""Parse the constraint.
Parameters
----------
check : string
String representation of the constraint.
reverse : bool, optional (default=False)
Whether to reverse the sign of the constraint.
Returns
-------
pair : tuple
Parsed constraint in the form of tuple (value, sign).
"""
try:
idx = 1
float(check[idx:])
@ -164,10 +242,24 @@ def GenParameterDescription(sections, descriptions, params_rst):
new_params_file.write(after)
def GenParameterCode(config_hpp, config_out_cpp):
keys, infos = GetParameterInfos(config_hpp)
names = GetNames(infos)
alias = GetAlias(infos)
def gen_parameter_code(config_hpp, config_out_cpp):
"""Generate auto config file.
Parameters
----------
config_hpp : string
Path to the config header file.
config_out_cpp : string
Path to the auto config file.
Returns
-------
infos : tuple
Tuple with names and content of sections.
"""
keys, infos = get_parameter_infos(config_hpp)
names = get_names(infos)
alias = get_alias(infos)
str_to_write = "/// This file is auto generated by LightGBM\\helper\\parameter_generator.py from LightGBM\\include\\LightGBM\\config.h file.\n"
str_to_write += "#include<LightGBM/config.h>\nnamespace LightGBM {\n"
# alias table
@ -192,7 +284,7 @@ def GenParameterCode(config_hpp, config_out_cpp):
checks = []
if "check" in y:
checks = y["check"]
tmp = SetOneVarFromString(name, param_type, checks)
tmp = set_one_var_from_string(name, param_type, checks)
str_to_write += tmp
# tails
str_to_write += "}\n\n"
@ -226,5 +318,5 @@ if __name__ == "__main__":
config_hpp = os.path.join(current_dir, os.path.pardir, 'include', 'LightGBM', 'config.h')
config_out_cpp = os.path.join(current_dir, os.path.pardir, 'src', 'io', 'config_auto.cpp')
params_rst = os.path.join(current_dir, os.path.pardir, 'docs', 'Parameters.rst')
sections, descriptions = GenParameterCode(config_hpp, config_out_cpp)
GenParameterDescription(sections, descriptions, params_rst)
sections, descriptions = gen_parameter_code(config_hpp, config_out_cpp)
gen_parameter_description(sections, descriptions, params_rst)

Просмотреть файл

@ -151,8 +151,8 @@ Examples
Refer to the walk through examples in `Python guide folder <https://github.com/Microsoft/LightGBM/tree/master/examples/python-guide>`_.
Developments
------------
Development Guide
-----------------
The code style of Python-package follows `PEP 8 <https://www.python.org/dev/peps/pep-0008/>`_. If you would like to make a contribution and are not familiar with PEP 8, please check the PEP 8 style guide first. Otherwise, the check won't pass. You should be careful about:
@ -166,6 +166,8 @@ The code style of Python-package follows `PEP 8 <https://www.python.org/dev/peps
E501 (line too long) and W503 (line break occurred before a binary operator) can be ignored.
Documentation strings (docstrings) are written in the NumPy style.
.. |License| image:: https://img.shields.io/badge/license-MIT-blue.svg
:target: https://github.com/Microsoft/LightGBM/blob/master/LICENSE
.. |Python Versions| image:: https://img.shields.io/pypi/pyversions/lightgbm.svg

Просмотреть файл

@ -1,7 +1,7 @@
# coding: utf-8
"""LightGBM, Light Gradient Boosting Machine.
Contributors: https://github.com/Microsoft/LightGBM/graphs/contributors
Contributors: https://github.com/Microsoft/LightGBM/graphs/contributors.
"""
from __future__ import absolute_import

Просмотреть файл

@ -1,7 +1,7 @@
# coding: utf-8
# pylint: disable = invalid-name, C0111, C0301
# pylint: disable = R0912, R0913, R0914, W0105, W0201, W0212
"""Wrapper c_api of LightGBM"""
"""Wrapper for C API of LightGBM."""
from __future__ import absolute_import
import copy
@ -22,7 +22,7 @@ from .libpath import find_lib_path
def _load_lib():
"""Load LightGBM Library."""
"""Load LightGBM library."""
lib_path = find_lib_path()
if len(lib_path) == 0:
return None
@ -35,18 +35,19 @@ _LIB = _load_lib()
def _safe_call(ret):
"""Check the return value of C API call
"""Check the return value from C API call.
Parameters
----------
ret : int
return value from API calls
The return value from C API calls.
"""
if ret != 0:
raise LightGBMError(decode_string(_LIB.LGBM_GetLastError()))
def is_numeric(obj):
"""Check is a number or not, include numpy number etc."""
"""Check whether object is a number or not, including numpy numbers, etc."""
try:
float(obj)
return True
@ -57,18 +58,17 @@ def is_numeric(obj):
def is_numpy_1d_array(data):
"""Check is 1d numpy array"""
"""Check whether data is a 1-D numpy array."""
return isinstance(data, np.ndarray) and len(data.shape) == 1
def is_1d_list(data):
"""Check is 1d list"""
return isinstance(data, list) and \
(not data or is_numeric(data[0]))
"""Check whether data is a 1-D list."""
return isinstance(data, list) and (not data or is_numeric(data[0]))
def list_to_1d_numpy(data, dtype=np.float32, name='list'):
"""convert to 1d numpy array"""
"""Convert data to 1-D numpy array."""
if is_numpy_1d_array(data):
if data.dtype == dtype:
return data
@ -84,8 +84,7 @@ def list_to_1d_numpy(data, dtype=np.float32, name='list'):
def cfloat32_array_to_numpy(cptr, length):
"""Convert a ctypes float pointer array to a numpy array.
"""
"""Convert a ctypes float pointer array to a numpy array."""
if isinstance(cptr, ctypes.POINTER(ctypes.c_float)):
return np.fromiter(cptr, dtype=np.float32, count=length)
else:
@ -93,8 +92,7 @@ def cfloat32_array_to_numpy(cptr, length):
def cfloat64_array_to_numpy(cptr, length):
"""Convert a ctypes double pointer array to a numpy array.
"""
"""Convert a ctypes double pointer array to a numpy array."""
if isinstance(cptr, ctypes.POINTER(ctypes.c_double)):
return np.fromiter(cptr, dtype=np.float64, count=length)
else:
@ -102,8 +100,7 @@ def cfloat64_array_to_numpy(cptr, length):
def cint32_array_to_numpy(cptr, length):
"""Convert a ctypes float pointer array to a numpy array.
"""
"""Convert a ctypes int pointer array to a numpy array."""
if isinstance(cptr, ctypes.POINTER(ctypes.c_int32)):
return np.fromiter(cptr, dtype=np.int32, count=length)
else:
@ -111,16 +108,17 @@ def cint32_array_to_numpy(cptr, length):
def c_str(string):
"""Convert a python string to cstring."""
"""Convert a Python string to C string."""
return ctypes.c_char_p(string.encode('utf-8'))
def c_array(ctype, values):
"""Convert a python array to c array."""
"""Convert a Python array to C array."""
return (ctype * len(values))(*values)
def param_dict_to_str(data):
"""Convert Python dictionary to string, which is passed to C API."""
if data is None or not data:
return ""
pairs = []
@ -156,28 +154,29 @@ class _TempFile(object):
class LightGBMError(Exception):
"""Error throwed by LightGBM"""
"""Error thrown by LightGBM."""
pass
MAX_INT32 = (1 << 31) - 1
"""marco definition of data type in c_api of LightGBM"""
"""Macro definition of data type in C API of LightGBM"""
C_API_DTYPE_FLOAT32 = 0
C_API_DTYPE_FLOAT64 = 1
C_API_DTYPE_INT32 = 2
C_API_DTYPE_INT64 = 3
"""Matric is row major in python"""
"""Matrix is row major in Python"""
C_API_IS_ROW_MAJOR = 1
"""marco definition of prediction type in c_api of LightGBM"""
"""Macro definition of prediction type in C API of LightGBM"""
C_API_PREDICT_NORMAL = 0
C_API_PREDICT_RAW_SCORE = 1
C_API_PREDICT_LEAF_INDEX = 2
C_API_PREDICT_CONTRIB = 3
"""data type of data field"""
"""Data type of data field"""
FIELD_TYPE_MAPPER = {"label": C_API_DTYPE_FLOAT32,
"weight": C_API_DTYPE_FLOAT32,
"init_score": C_API_DTYPE_FLOAT64,
@ -185,12 +184,12 @@ FIELD_TYPE_MAPPER = {"label": C_API_DTYPE_FLOAT32,
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int',
'int64': 'int', 'uint8': 'int', 'uint16': 'int',
'uint32': 'int', 'uint64': 'int', 'float16': 'float',
'float32': 'float', 'float64': 'float', 'bool': 'int'}
'uint32': 'int', 'uint64': 'int', 'bool': 'int',
'float16': 'float', 'float32': 'float', 'float64': 'float'}
def convert_from_sliced_object(data):
"""fix the memory of multi-dimensional sliced object"""
"""Fix the memory of multi-dimensional sliced object."""
if data.base is not None and isinstance(data, np.ndarray) and isinstance(data.base, np.ndarray):
if not data.flags.c_contiguous:
warnings.warn("Usage of np.ndarray subset (sliced data) is not recommended "
@ -200,7 +199,7 @@ def convert_from_sliced_object(data):
def c_float_array(data):
"""get pointer of float numpy array / list"""
"""Get pointer of float numpy array / list."""
if is_1d_list(data):
data = np.array(data, copy=False)
if is_numpy_1d_array(data):
@ -221,7 +220,7 @@ def c_float_array(data):
def c_int_array(data):
"""get pointer of int numpy array / list"""
"""Get pointer of int numpy array / list."""
if is_1d_list(data):
data = np.array(data, copy=False)
if is_numpy_1d_array(data):
@ -314,22 +313,27 @@ def _load_pandas_categorical(file_name):
class _InnerPredictor(object):
"""_InnerPredictor of LightGBM.
Not exposed to user.
Used only for prediction, usually used for continued training.
Note
----
Can be converted from Booster, but cannot be converted to Booster.
"""
A _InnerPredictor of LightGBM.
Only used for prediction, usually used for continued-train
Note: Can convert from Booster, but cannot convert to Booster
"""
def __init__(self, model_file=None, booster_handle=None, pred_parameter=None):
"""Initialize the _InnerPredictor. Not exposed to user
"""Initialize the _InnerPredictor.
Parameters
----------
model_file : string
model_file : string or None, optional (default=None)
Path to the model file.
booster_handle : Handle of Booster
use handle to init
pred_parameter: dict
Other parameters for the prediciton
booster_handle : object or None, optional (default=None)
Handle of Booster.
pred_parameter : dict or None, optional (default=None)
Other parameters for the prediction.
"""
self.handle = ctypes.c_void_p()
self.__is_manage_handle = True
@ -382,30 +386,31 @@ class _InnerPredictor(object):
def predict(self, data, num_iteration=-1,
raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False,
is_reshape=True):
"""
Predict logic
"""Predict logic.
Parameters
----------
data : string, numpy array, pandas DataFrame or scipy.sparse
Data source for prediction
When data type is string, it represents the path of txt file
num_iteration : int
Used iteration for prediction
raw_score : bool
True for predict raw score
pred_leaf : bool
True for predict leaf index
pred_contrib : bool
True for predict feature contributions
data_has_header : bool
Used for txt data, True if txt data has header
is_reshape : bool
Reshape to (nrow, ncol) if true
Data source for prediction.
When data type is string, it represents the path of txt file.
num_iteration : int, optional (default=-1)
Iteration used for prediction.
raw_score : bool, optional (default=False)
Whether to predict raw scores.
pred_leaf : bool, optional (default=False)
Whether to predict leaf index.
pred_contrib : bool, optional (default=False)
Whether to predict feature contributions.
data_has_header : bool, optional (default=False)
Whether data has header.
Used only for txt data.
is_reshape : bool, optional (default=True)
Whether to reshape to (nrow, ncol).
Returns
-------
Prediction result
result : numpy array
Prediction result.
"""
if isinstance(data, Dataset):
raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
@ -465,9 +470,7 @@ class _InnerPredictor(object):
return preds
def __get_num_preds(self, num_iteration, nrow, predict_type):
"""
Get size of prediction result
"""
"""Get size of prediction result."""
if nrow > MAX_INT32:
raise LightGBMError('LightGBM cannot perform prediction for data'
'with number of rows greater than MAX_INT32 (%d).\n'
@ -483,9 +486,7 @@ class _InnerPredictor(object):
return n_preds.value
def __pred_for_np2d(self, mat, num_iteration, predict_type):
"""
Predict for a 2-D numpy matrix.
"""
"""Predict for a 2-D numpy matrix."""
if len(mat.shape) != 2:
raise ValueError('Input numpy.ndarray or list must be 2 dimensional')
@ -534,9 +535,7 @@ class _InnerPredictor(object):
return inner_predict(mat, num_iteration, predict_type)
def __pred_for_csr(self, csr, num_iteration, predict_type):
"""
Predict for a csr data
"""
"""Predict for a CSR data."""
def inner_predict(csr, num_iteration, predict_type, preds=None):
nrow = len(csr.indptr) - 1
n_preds = self.__get_num_preds(num_iteration, nrow, predict_type)
@ -587,9 +586,7 @@ class _InnerPredictor(object):
return inner_predict(csr, num_iteration, predict_type)
def __pred_for_csc(self, csc, num_iteration, predict_type):
"""
Predict for a csc data
"""
"""Predict for a CSC data."""
nrow = csc.shape[0]
if nrow > MAX_INT32:
return self.__pred_for_csr(csc.tocsr(), num_iteration, predict_type)
@ -625,18 +622,19 @@ class _InnerPredictor(object):
class Dataset(object):
"""Dataset in LightGBM."""
def __init__(self, data, label=None, reference=None,
weight=None, group=None, init_score=None, silent=False,
feature_name='auto', categorical_feature='auto', params=None,
free_raw_data=True):
"""Construct Dataset.
"""Initialize Dataset.
Parameters
----------
data : string, numpy array, pandas DataFrame, scipy.sparse or list of numpy arrays
Data source of Dataset.
If string, it represents the path to txt file.
label : list, numpy 1-D array, pandas one-column DataFrame/Series or None, optional (default=None)
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
Label of the data.
reference : Dataset or None, optional (default=None)
If this is Dataset for validation, training data should be used as reference.
@ -660,7 +658,7 @@ class Dataset(object):
Large values could be memory consuming. Consider using consecutive integers starting from zero.
All negative values in categorical features will be treated as missing values.
params : dict or None, optional (default=None)
Other parameters.
Other parameters for Dataset.
free_raw_data : bool, optional (default=True)
If True, raw data is freed after constructing inner Dataset.
"""
@ -810,9 +808,7 @@ class Dataset(object):
return self.set_feature_name(feature_name)
def __init_from_np2d(self, mat, params_str, ref_dataset):
"""
Initialize data from a 2-D numpy matrix.
"""
"""Initialize data from a 2-D numpy matrix."""
if len(mat.shape) != 2:
raise ValueError('Input numpy.ndarray must be 2 dimensional')
@ -836,9 +832,7 @@ class Dataset(object):
return self
def __init_from_list_np2d(self, mats, params_str, ref_dataset):
"""
Initialize data from list of 2-D numpy matrices.
"""
"""Initialize data from a list of 2-D numpy matrices."""
ncol = mats[0].shape[1]
nrow = np.zeros((len(mats),), np.int32)
if mats[0].dtype == np.float64:
@ -885,9 +879,7 @@ class Dataset(object):
return self
def __init_from_csr(self, csr, params_str, ref_dataset):
"""
Initialize data from a CSR matrix.
"""
"""Initialize data from a CSR matrix."""
if len(csr.indices) != len(csr.data):
raise ValueError('Length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data)))
self.handle = ctypes.c_void_p()
@ -913,9 +905,7 @@ class Dataset(object):
return self
def __init_from_csc(self, csc, params_str, ref_dataset):
"""
Initialize data from a csc matrix.
"""
"""Initialize data from a CSC matrix."""
if len(csc.indices) != len(csc.data):
raise ValueError('Length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data)))
self.handle = ctypes.c_void_p()
@ -996,7 +986,7 @@ class Dataset(object):
data : string, numpy array, pandas DataFrame, scipy.sparse or list of numpy arrays
Data source of Dataset.
If string, it represents the path to txt file.
label : list, numpy 1-D array, pandas one-column DataFrame/Series or None, optional (default=None)
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
Label of the data.
weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
Weight for each instance.
@ -1007,7 +997,7 @@ class Dataset(object):
silent : bool, optional (default=False)
Whether to print messages during construction.
params : dict or None, optional (default=None)
Other parameters.
Other parameters for validation Dataset.
Returns
-------
@ -1029,7 +1019,7 @@ class Dataset(object):
used_indices : list of int
Indices used to create the subset.
params : dict or None, optional (default=None)
Other parameters.
These parameters will be passed to Dataset constructor.
Returns
-------
@ -1193,9 +1183,10 @@ class Dataset(object):
"set free_raw_data=False when construct Dataset to avoid this.")
def _set_predictor(self, predictor):
"""
Set predictor for continued training, not recommended for user to call this function.
Please set init_model in engine.train or engine.cv
"""Set predictor for continued training.
It is not recommended for user to call this function.
Please use init_model argument in engine.train() or engine.cv() instead.
"""
if predictor is self._predictor:
return self
@ -1259,11 +1250,11 @@ class Dataset(object):
return self
def set_label(self, label):
"""Set label of Dataset
"""Set label of Dataset.
Parameters
----------
label : list, numpy 1-D array, pandas one-column DataFrame/Series or None
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None
The label information to be set into Dataset.
Returns
@ -1420,8 +1411,11 @@ class Dataset(object):
raise LightGBMError("Cannot get num_feature before construct dataset")
def get_ref_chain(self, ref_limit=100):
"""Get a chain of Dataset objects, starting with r, then going to r.reference if exists,
then to r.reference.reference, etc. until we hit ``ref_limit`` or a reference loop.
"""Get a chain of Dataset objects.
Starts with r, then goes to r.reference (if exists),
then to r.reference.reference, etc.
until we hit ``ref_limit`` or a reference loop.
Parameters
----------
@ -1449,6 +1443,7 @@ class Dataset(object):
class Booster(object):
"""Booster in LightGBM."""
def __init__(self, params=None, train_set=None, model_file=None, silent=False):
"""Initialize the Booster.
@ -1732,7 +1727,6 @@ class Booster(object):
is_finished : bool
Whether the update was successfully finished.
"""
# need reset training data
if train_set is not None and train_set is not self.train_set:
if not isinstance(train_set, Dataset):
@ -1762,18 +1756,19 @@ class Booster(object):
return self.__boost(grad, hess)
def __boost(self, grad, hess):
"""
Boost Booster for one iteration with customized gradient statistics.
"""Boost Booster for one iteration with customized gradient statistics.
Note: For multi-class task, the score is group by class_id first, then group by row_id.
If you want to get i-th row score in j-th class, the access way is score[j * num_data + i]
and you should group grad and hess in this way as well.
Note
----
For multi-class task, the score is grouped by class_id first, then grouped by row_id.
If you want to get i-th row score in j-th class, the access way is score[j * num_data + i]
and you should group grad and hess in this way as well.
Parameters
----------
grad : 1d numpy array or list
grad : 1-D numpy array or 1-D list
The first order derivative (gradient).
hess : 1d numpy or 1d list
hess : 1-D numpy array or 1-D list
The second order derivative (Hessian).
Returns
@ -1863,10 +1858,10 @@ class Booster(object):
Name of the data.
feval : callable or None, optional (default=None)
Customized evaluation function.
Should accept two parameters: preds, train_data.
Should accept two parameters: preds, train_data,
and return (eval_name, eval_result, is_higher_better) or list of such tuples.
For multi-class task, the preds are grouped by class_id first, then grouped by row_id.
If you want to get i-th row preds in j-th class, the access way is preds[j * num_data + i].
Note: should return (eval_name, eval_result, is_higher_better) or list of such tuples.
Returns
-------
@ -1897,10 +1892,10 @@ class Booster(object):
----------
feval : callable or None, optional (default=None)
Customized evaluation function.
Should accept two parameters: preds, train_data.
Should accept two parameters: preds, train_data,
and return (eval_name, eval_result, is_higher_better) or list of such tuples.
For multi-class task, the preds are grouped by class_id first, then grouped by row_id.
If you want to get i-th row preds in j-th class, the access way is preds[j * num_data + i].
Note: should return (eval_name, eval_result, is_higher_better) or list of such tuples.
Returns
-------
@ -1916,10 +1911,10 @@ class Booster(object):
----------
feval : callable or None, optional (default=None)
Customized evaluation function.
Should accept two parameters: preds, train_data.
Should accept two parameters: preds, train_data,
and return (eval_name, eval_result, is_higher_better) or list of such tuples.
For multi-class task, the preds are grouped by class_id first, then grouped by row_id.
If you want to get i-th row preds in j-th class, the access way is preds[j * num_data + i].
Note: should return (eval_name, eval_result, is_higher_better) or list of such tuples.
Returns
-------
@ -1964,10 +1959,10 @@ class Booster(object):
Parameters
----------
start_iteration : int, optional (default=0)
Index of the iteration that will start to shuffle.
The first iteration that will be shuffled.
end_iteration : int, optional (default=-1)
The last iteration that will be shuffled.
If <= 0, means the last iteration.
If <= 0, means the last available iteration.
Returns
-------
@ -2044,7 +2039,7 @@ class Booster(object):
ctypes.byref(tmp_out_len),
ptr_string_buffer))
actual_len = tmp_out_len.value
'''if buffer length is not long enough, re-allocate a buffer'''
# if buffer length is not long enough, re-allocate a buffer
if actual_len > buffer_len:
string_buffer = ctypes.create_string_buffer(actual_len)
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
@ -2088,7 +2083,7 @@ class Booster(object):
ctypes.byref(tmp_out_len),
ptr_string_buffer))
actual_len = tmp_out_len.value
'''if buffer length is not long enough, reallocate a buffer'''
# if buffer length is not long enough, reallocate a buffer
if actual_len > buffer_len:
string_buffer = ctypes.create_string_buffer(actual_len)
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
@ -2103,7 +2098,7 @@ class Booster(object):
def predict(self, data, num_iteration=None,
raw_score=False, pred_leaf=False, pred_contrib=False,
data_has_header=False, is_reshape=True, pred_parameter=None, **kwargs):
data_has_header=False, is_reshape=True, **kwargs):
"""Make a prediction.
Parameters
@ -2133,7 +2128,8 @@ class Booster(object):
Used only if data is string.
is_reshape : bool, optional (default=True)
If True, result is reshaped to [nrow, ncol].
**kwargs : other parameters for the prediction
**kwargs
Other parameters for the prediction.
Returns
-------
@ -2155,12 +2151,13 @@ class Booster(object):
data : string, numpy array, pandas DataFrame or scipy.sparse
Data source for refit.
If string, it represents the path to txt file.
label : list, numpy 1-D array or pandas one-column DataFrame/Series
label : list, numpy 1-D array or pandas Series / one-column DataFrame
Label for refit.
decay_rate : float, optional (default=0.9)
Decay rate of refit,
will use ``leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output`` to refit trees.
**kwargs : other parameters for refit
**kwargs
Other parameters for refit.
These parameters will be passed to ``predict`` method.
Returns
@ -2214,7 +2211,7 @@ class Booster(object):
return ret.value
def _to_predictor(self, pred_parameter=None):
"""Convert to predictor"""
"""Convert to predictor."""
predictor = _InnerPredictor(booster_handle=self.handle, pred_parameter=pred_parameter)
predictor.pandas_categorical = self.pandas_categorical
return predictor
@ -2254,7 +2251,7 @@ class Booster(object):
raise ValueError("Length of feature names doesn't equal with num_feature")
return [string_buffers[i].value.decode() for i in range_(num_feature)]
def feature_importance(self, importance_type='split', iteration=-1):
def feature_importance(self, importance_type='split', iteration=None):
"""Get feature importances.
Parameters
@ -2263,12 +2260,18 @@ class Booster(object):
How the importance is calculated.
If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature.
iteration : int or None, optional (default=None)
Limit number of iterations in the feature importance calculation.
If None, if the best iteration exists, it is used; otherwise, all trees are used.
If <= 0, all trees are used (no limits).
Returns
-------
result : numpy array
Array with feature importances.
"""
if iteration is None:
iteration = self.best_iteration
if importance_type == "split":
importance_type_int = 0
elif importance_type == "gain":
@ -2287,9 +2290,7 @@ class Booster(object):
return result
def __inner_eval(self, data_name, data_idx, feval=None):
"""
Evaluate training or validation data
"""
"""Evaluate training or validation data."""
if data_idx >= self.__num_dataset:
raise ValueError("Data_idx should be smaller than number of dataset")
self.__get_eval_info()
@ -2322,9 +2323,7 @@ class Booster(object):
return ret
def __inner_predict(self, data_idx):
"""
Predict for training and validation dataset
"""
"""Predict for training and validation dataset."""
if data_idx >= self.__num_dataset:
raise ValueError("Data_idx should be smaller than number of dataset")
if self.__inner_predict_buffer[data_idx] is None:
@ -2348,9 +2347,7 @@ class Booster(object):
return self.__inner_predict_buffer[data_idx]
def __get_eval_info(self):
"""
Get inner evaluation count and names
"""
"""Get inner evaluation count and names."""
if self.__need_reload_eval_info:
self.__need_reload_eval_info = False
out_num_eval = ctypes.c_int(0)
@ -2392,7 +2389,7 @@ class Booster(object):
return self.__attr.get(key, None)
def set_attr(self, **kwargs):
"""Set the attribute of the Booster.
"""Set attributes to the Booster.
Parameters
----------
@ -2403,7 +2400,7 @@ class Booster(object):
Returns
-------
self : Booster
Booster with set attribute.
Booster with set attributes.
"""
for key, value in kwargs.items():
if value is not None:

Просмотреть файл

@ -1,5 +1,6 @@
# coding: utf-8
# pylint: disable = invalid-name, W0105, C0301
"""Callbacks library."""
from __future__ import absolute_import
import collections
@ -9,14 +10,18 @@ from .compat import range_
class EarlyStopException(Exception):
"""Exception of early stopping.
"""Exception of early stopping."""
Parameters
----------
best_iteration : int
The best iteration stopped.
"""
def __init__(self, best_iteration, best_score):
"""Create early stopping exception.
Parameters
----------
best_iteration : int
The best iteration stopped.
best_score : float
The score of the best iteration.
"""
super(EarlyStopException, self).__init__()
self.best_iteration = best_iteration
self.best_score = best_score
@ -34,7 +39,7 @@ CallbackEnv = collections.namedtuple(
def _format_eval_result(value, show_stdv=True):
"""format metric string"""
"""Format metric string."""
if len(value) == 4:
return '%s\'s %s: %g' % (value[0], value[1], value[2])
elif len(value) == 5:
@ -61,13 +66,12 @@ def print_evaluation(period=1, show_stdv=True):
callback : function
The callback that prints the evaluation results every ``period`` iteration(s).
"""
def callback(env):
"""internal function"""
def _callback(env):
if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
print('[%d]\t%s' % (env.iteration + 1, result))
callback.order = 10
return callback
_callback.order = 10
return _callback
def record_evaluation(eval_result):
@ -87,19 +91,17 @@ def record_evaluation(eval_result):
raise TypeError('Eval_result should be a dictionary')
eval_result.clear()
def init(env):
"""internal function"""
def _init(env):
for data_name, _, _, _ in env.evaluation_result_list:
eval_result.setdefault(data_name, collections.defaultdict(list))
def callback(env):
"""internal function"""
def _callback(env):
if not eval_result:
init(env)
_init(env)
for data_name, eval_name, result, _ in env.evaluation_result_list:
eval_result[data_name][eval_name].append(result)
callback.order = 20
return callback
_callback.order = 20
return _callback
def reset_parameter(**kwargs):
@ -111,7 +113,7 @@ def reset_parameter(**kwargs):
Parameters
----------
**kwargs: value should be list or function
**kwargs : value should be list or function
List of parameters for each boosting round
or a customized function that calculates the parameter in terms of
current number of round (e.g. yields learning rate decay).
@ -123,8 +125,7 @@ def reset_parameter(**kwargs):
callback : function
The callback that resets the parameter after the first iteration.
"""
def callback(env):
"""internal function"""
def _callback(env):
new_parameters = {}
for key, value in kwargs.items():
if key in ['num_class', 'num_classes',
@ -143,9 +144,9 @@ def reset_parameter(**kwargs):
if new_parameters:
env.model.reset_parameter(new_parameters)
env.params.update(new_parameters)
callback.before_iteration = True
callback.order = 10
return callback
_callback.before_iteration = True
_callback.order = 10
return _callback
def early_stopping(stopping_rounds, verbose=True):
@ -164,7 +165,6 @@ def early_stopping(stopping_rounds, verbose=True):
----------
stopping_rounds : int
The possible number of rounds without the trend occurrence.
verbose : bool, optional (default=True)
Whether to print message with early stopping information.
@ -178,8 +178,7 @@ def early_stopping(stopping_rounds, verbose=True):
best_score_list = []
cmp_op = []
def init(env):
"""internal function"""
def _init(env):
if not env.evaluation_result_list:
raise ValueError('For early stopping, '
'at least one dataset and eval metric is required for evaluation')
@ -198,10 +197,9 @@ def early_stopping(stopping_rounds, verbose=True):
best_score.append(float('inf'))
cmp_op.append(lt)
def callback(env):
"""internal function"""
def _callback(env):
if not cmp_op:
init(env)
_init(env)
for i in range_(len(env.evaluation_result_list)):
score = env.evaluation_result_list[i][2]
if cmp_op[i](score, best_score[i]):
@ -218,5 +216,5 @@ def early_stopping(stopping_rounds, verbose=True):
print('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
raise EarlyStopException(best_iter[i], best_score_list[i])
callback.order = 30
return callback
_callback.order = 30
return _callback

Просмотреть файл

@ -1,6 +1,6 @@
# coding: utf-8
# pylint: disable = C0103
"""Compatibility"""
"""Compatibility library."""
from __future__ import absolute_import
import inspect
@ -10,7 +10,7 @@ import numpy as np
is_py3 = (sys.version_info[0] == 3)
"""compatibility between python2 and python3"""
"""Compatibility between Python2 and Python3"""
if is_py3:
zip_ = zip
string_type = str
@ -19,10 +19,11 @@ if is_py3:
range_ = range
def argc_(func):
"""return number of arguments of a function"""
"""Count the number of arguments of a function."""
return len(inspect.signature(func).parameters)
def decode_string(bytestring):
"""Decode C bytestring to ordinary string."""
return bytestring.decode('utf-8')
else:
from itertools import izip as zip_
@ -32,10 +33,11 @@ else:
range_ = xrange
def argc_(func):
"""return number of arguments of a function"""
"""Count the number of arguments of a function."""
return len(inspect.getargspec(func).args)
def decode_string(bytestring):
"""Decode C bytestring to ordinary string."""
return bytestring
"""json"""
@ -48,6 +50,7 @@ except (ImportError, SyntaxError):
def json_default_with_numpy(obj):
"""Convert numpy classes to JSON serializable objects."""
if isinstance(obj, (np.integer, np.floating, np.bool_)):
return obj.item()
elif isinstance(obj, np.ndarray):
@ -64,9 +67,13 @@ except ImportError:
PANDAS_INSTALLED = False
class Series(object):
"""Dummy class for pandas.Series."""
pass
class DataFrame(object):
"""Dummy class for pandas.DataFrame."""
pass
"""matplotlib"""
@ -131,4 +138,6 @@ except ImportError:
# DeprecationWarning is not shown by default, so let's create our own with higher level
class LGBMDeprecationWarning(UserWarning):
"""Custom deprecation warning."""
pass

Просмотреть файл

@ -1,6 +1,6 @@
# coding: utf-8
# pylint: disable = invalid-name, W0105
"""Training Library containing training routines of LightGBM."""
"""Library with training routines of LightGBM."""
from __future__ import absolute_import
import collections
@ -30,21 +30,21 @@ def train(params, train_set, num_boost_round=100,
params : dict
Parameters for training.
train_set : Dataset
Data to be trained.
num_boost_round: int, optional (default=100)
Data to be trained on.
num_boost_round : int, optional (default=100)
Number of boosting iterations.
valid_sets: list of Datasets or None, optional (default=None)
List of data to be evaluated during training.
valid_names: list of string or None, optional (default=None)
valid_sets : list of Datasets or None, optional (default=None)
List of data to be evaluated on during training.
valid_names : list of strings or None, optional (default=None)
Names of ``valid_sets``.
fobj : callable or None, optional (default=None)
Customized objective function.
feval : callable or None, optional (default=None)
Customized evaluation function.
Should accept two parameters: preds, train_data.
Should accept two parameters: preds, train_data,
and return (eval_name, eval_result, is_higher_better) or list of such tuples.
For multi-class task, the preds are grouped by class_id first, then grouped by row_id.
If you want to get i-th row preds in j-th class, the access way is preds[j * num_data + i].
Note: should return (eval_name, eval_result, is_higher_better) or list of such tuples.
To ignore the default metric corresponding to the used objective,
set the ``metric`` parameter to the string ``"None"`` in ``params``.
init_model : string, Booster or None, optional (default=None)
@ -60,23 +60,24 @@ def train(params, train_set, num_boost_round=100,
All values in categorical features should be less than int32 max value (2147483647).
Large values could be memory consuming. Consider using consecutive integers starting from zero.
All negative values in categorical features will be treated as missing values.
early_stopping_rounds: int or None, optional (default=None)
early_stopping_rounds : int or None, optional (default=None)
Activates early stopping. The model will train until the validation score stops improving.
Validation score needs to improve at least every ``early_stopping_rounds`` round(s)
to continue training.
Requires at least one validation data and one metric.
If there's more than one, will check all of them. But the training data is ignored anyway.
If early stopping occurs, the model will add ``best_iteration`` field.
evals_result: dict or None, optional (default=None)
evals_result : dict or None, optional (default=None)
This dictionary is used to store all evaluation results of all the items in ``valid_sets``.
Example
-------
With a ``valid_sets`` = [valid_set, train_set],
``valid_names`` = ['eval', 'train']
and a ``params`` = ('metric':'logloss')
returns: {'train': {'logloss': ['0.48253', '0.35953', ...]},
and a ``params`` = {'metric': 'logloss'}
returns {'train': {'logloss': ['0.48253', '0.35953', ...]},
'eval': {'logloss': ['0.480385', '0.357756', ...]}}.
verbose_eval : bool or int, optional (default=True)
Requires at least one validation data.
If True, the eval metric on the valid set is printed at each boosting stage.
@ -85,9 +86,10 @@ def train(params, train_set, num_boost_round=100,
Example
-------
With ``verbose_eval`` = 4 and at least one item in evals,
With ``verbose_eval`` = 4 and at least one item in ``valid_sets``,
an evaluation metric is printed every 4 (instead of 1) boosting stages.
learning_rates: list, callable or None, optional (default=None)
learning_rates : list, callable or None, optional (default=None)
List of learning rates for each boosting round
or a customized function that calculates ``learning_rate``
in terms of current number of round (e.g. yields learning rate decay).
@ -238,31 +240,30 @@ def train(params, train_set, num_boost_round=100,
return booster
class CVBooster(object):
""""Auxiliary data struct to hold all boosters of CV."""
class _CVBooster(object):
"""Auxiliary data struct to hold all boosters of CV."""
def __init__(self):
self.boosters = []
self.best_iteration = -1
def append(self, booster):
"""add a booster to CVBooster"""
"""Add a booster to _CVBooster."""
self.boosters.append(booster)
def __getattr__(self, name):
"""redirect methods call of CVBooster"""
def handlerFunction(*args, **kwargs):
"""call methods with each booster, and concatenate their results"""
"""Redirect methods call of _CVBooster."""
def handler_function(*args, **kwargs):
"""Call methods with each booster, and concatenate their results."""
ret = []
for booster in self.boosters:
ret.append(getattr(booster, name)(*args, **kwargs))
return ret
return handlerFunction
return handler_function
def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratified=True, shuffle=True):
"""
Make an n-fold list of Booster from random indices.
"""
"""Make an n-fold list of Booster from random indices."""
full_data = full_data.construct()
num_data = full_data.num_data()
if folds is not None:
@ -301,7 +302,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
train_id = [np.concatenate([test_id[i] for i in range_(nfold) if k != i]) for k in range_(nfold)]
folds = zip_(train_id, test_id)
ret = CVBooster()
ret = _CVBooster()
for train_idx, test_idx in folds:
train_set = full_data.subset(train_idx)
valid_set = full_data.subset(test_idx)
@ -317,9 +318,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
def _agg_cv_result(raw_results):
"""
Aggregate cross-validation results.
"""
"""Aggregate cross-validation results."""
cvmap = collections.defaultdict(list)
metric_type = {}
for one_result in raw_results:
@ -356,7 +355,7 @@ def cv(params, train_set, num_boost_round=100,
Number of folds in CV.
stratified : bool, optional (default=True)
Whether to perform stratified sampling.
shuffle: bool, optional (default=True)
shuffle : bool, optional (default=True)
Whether to shuffle before splitting data.
metrics : string, list of strings or None, optional (default=None)
Evaluation metrics to be monitored while CV.
@ -365,10 +364,10 @@ def cv(params, train_set, num_boost_round=100,
Custom objective function.
feval : callable or None, optional (default=None)
Customized evaluation function.
Should accept two parameters: preds, train_data.
Should accept two parameters: preds, train_data,
and return (eval_name, eval_result, is_higher_better) or list of such tuples.
For multi-class task, the preds are grouped by class_id first, then grouped by row_id.
If you want to get i-th row preds in j-th class, the access way is preds[j * num_data + i].
Note: should return (eval_name, eval_result, is_higher_better) or list of such tuples.
To ignore the default metric corresponding to the used objective,
set ``metrics`` to the string ``"None"``.
init_model : string, Booster or None, optional (default=None)
@ -384,12 +383,12 @@ def cv(params, train_set, num_boost_round=100,
All values in categorical features should be less than int32 max value (2147483647).
Large values could be memory consuming. Consider using consecutive integers starting from zero.
All negative values in categorical features will be treated as missing values.
early_stopping_rounds: int or None, optional (default=None)
early_stopping_rounds : int or None, optional (default=None)
Activates early stopping.
CV score needs to improve at least every ``early_stopping_rounds`` round(s)
to continue.
Requires at least one metric. If there's more than one, will check all of them.
Last entry in evaluation history is the one from best iteration.
Last entry in evaluation history is the one from the best iteration.
fpreproc : callable or None, optional (default=None)
Preprocessing function that takes (dtrain, dtest, params)
and returns transformed versions of those.
@ -400,7 +399,7 @@ def cv(params, train_set, num_boost_round=100,
If int, progress will be displayed at every given ``verbose_eval`` boosting stage.
show_stdv : bool, optional (default=True)
Whether to display the standard deviation in progress.
Results are not affected by this parameter, and always contains std.
Results are not affected by this parameter, and always contain std.
seed : int, optional (default=0)
Seed used to generate the folds (passed to numpy.random.seed).
callbacks : list of callables or None, optional (default=None)

Просмотреть файл

@ -1,5 +1,5 @@
# coding: utf-8
"""Find the path to lightgbm dynamic library files."""
"""Find the path to LightGBM dynamic library files."""
import os
from platform import system
@ -7,17 +7,19 @@ from platform import system
def find_lib_path():
"""Find the path to LightGBM library files.
Returns
-------
lib_path: list(string)
List of all found library path to LightGBM
lib_path : list of strings
List of all found library paths to LightGBM.
"""
if os.environ.get('LIGHTGBM_BUILD_DOC', False):
# we don't need lib_lightgbm while building docs
return []
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dll_path = [curr_path, os.path.join(curr_path, '../../'),
dll_path = [curr_path,
os.path.join(curr_path, '../../'),
os.path.join(curr_path, 'compile'),
os.path.join(curr_path, '../compile'),
os.path.join(curr_path, '../../lib/')]
@ -32,5 +34,5 @@ def find_lib_path():
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if not lib_path:
dll_path = [os.path.realpath(p) for p in dll_path]
raise Exception('Cannot find lightgbm library in following paths: ' + '\n'.join(dll_path))
raise Exception('Cannot find lightgbm library file in following paths:\n' + '\n'.join(dll_path))
return lib_path

Просмотреть файл

@ -1,6 +1,6 @@
# coding: utf-8
# pylint: disable = C0103
"""Plotting Library."""
"""Plotting library."""
from __future__ import absolute_import
import warnings
@ -15,8 +15,8 @@ from .compat import (MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED, LGBMDeprecationWa
from .sklearn import LGBMModel
def check_not_tuple_of_2_elements(obj, obj_name='obj'):
"""check object is not tuple or does not have 2 elements"""
def _check_not_tuple_of_2_elements(obj, obj_name='obj'):
"""Check object is not tuple or does not have 2 elements."""
if not isinstance(obj, tuple) or len(obj) != 2:
raise TypeError('%s must be a tuple of 2 elements.' % obj_name)
@ -63,7 +63,7 @@ def plot_importance(booster, ax=None, height=0.2,
Figure size.
grid : bool, optional (default=True)
Whether to add a grid for axes.
**kwargs : other parameters
**kwargs
Other parameters passed to ``ax.barh()``.
Returns
@ -96,7 +96,7 @@ def plot_importance(booster, ax=None, height=0.2,
if ax is None:
if figsize is not None:
check_not_tuple_of_2_elements(figsize, 'figsize')
_check_not_tuple_of_2_elements(figsize, 'figsize')
_, ax = plt.subplots(1, 1, figsize=figsize)
ylocs = np.arange(len(values))
@ -109,13 +109,13 @@ def plot_importance(booster, ax=None, height=0.2,
ax.set_yticklabels(labels)
if xlim is not None:
check_not_tuple_of_2_elements(xlim, 'xlim')
_check_not_tuple_of_2_elements(xlim, 'xlim')
else:
xlim = (0, max(values) * 1.1)
ax.set_xlim(xlim)
if ylim is not None:
check_not_tuple_of_2_elements(ylim, 'ylim')
_check_not_tuple_of_2_elements(ylim, 'ylim')
else:
ylim = (-1, len(values))
ax.set_ylim(ylim)
@ -194,7 +194,7 @@ def plot_metric(booster, metric=None, dataset_names=None,
if ax is None:
if figsize is not None:
check_not_tuple_of_2_elements(figsize, 'figsize')
_check_not_tuple_of_2_elements(figsize, 'figsize')
_, ax = plt.subplots(1, 1, figsize=figsize)
if dataset_names is None:
@ -229,13 +229,13 @@ def plot_metric(booster, metric=None, dataset_names=None,
ax.legend(loc='best')
if xlim is not None:
check_not_tuple_of_2_elements(xlim, 'xlim')
_check_not_tuple_of_2_elements(xlim, 'xlim')
else:
xlim = (0, num_iteration)
ax.set_xlim(xlim)
if ylim is not None:
check_not_tuple_of_2_elements(ylim, 'ylim')
_check_not_tuple_of_2_elements(ylim, 'ylim')
else:
range_result = max_result - min_result
ylim = (min_result - range_result * 0.2, max_result + range_result * 0.2)
@ -270,7 +270,7 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs):
if precision is not None and not isinstance(value, string_type) else str(value)
def add(root, parent=None, decision=None):
"""recursively add node or edge"""
"""Recursively add node or edge."""
if 'split_index' in root: # non-leaf
name = 'split{0}'.format(root['split_index'])
if feature_names is not None:
@ -322,7 +322,7 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
Parameters
----------
booster : Booster or LGBMModel
Booster or LGBMModel instance.
Booster or LGBMModel instance to be converted.
tree_index : int, optional (default=0)
The index of a target tree to convert.
show_info : list of strings or None, optional (default=None)
@ -330,7 +330,7 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision.
**kwargs : other parameters
**kwargs
Other parameters passed to ``Digraph`` constructor.
Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.
@ -407,7 +407,7 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision.
**kwargs : other parameters
**kwargs
Other parameters passed to ``Digraph`` constructor.
Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.
@ -433,7 +433,7 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
if ax is None:
if figsize is not None:
check_not_tuple_of_2_elements(figsize, 'figsize')
_check_not_tuple_of_2_elements(figsize, 'figsize')
_, ax = plt.subplots(1, 1, figsize=figsize)
graph = create_tree_digraph(booster=booster, tree_index=tree_index,

Просмотреть файл

@ -1,6 +1,6 @@
# coding: utf-8
# pylint: disable = invalid-name, W0105, C0111, C0301
"""Scikit-Learn Wrapper interface for LightGBM."""
"""Scikit-learn wrapper interface for LightGBM."""
from __future__ import absolute_import
import numpy as np
@ -16,18 +16,22 @@ from .engine import train
def _objective_function_wrapper(func):
"""Decorate an objective function
Note: for multi-class task, the y_pred is group by class_id first, then group by row_id.
If you want to get i-th row y_pred in j-th class, the access way is y_pred[j * num_data + i]
and you should group grad and hess in this way as well.
"""Decorate an objective function.
Note
----
For multi-class task, the y_pred is grouped by class_id first, then grouped by row_id.
If you want to get i-th row y_pred in j-th class, the access way is y_pred[j * num_data + i]
and you should group grad and hess in this way as well.
Parameters
----------
func : callable
Expects a callable with signature ``func(y_true, y_pred)`` or ``func(y_true, y_pred, group)``:
y_true : array-like of shape = [n_samples]
The target values.
y_pred : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class)
y_pred : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
The predicted values.
group : array-like
Group/query data, used for ranking task.
@ -38,14 +42,13 @@ def _objective_function_wrapper(func):
The new objective function as expected by ``lightgbm.engine.train``.
The signature is ``new_func(preds, dataset)``:
preds : array-like of shape = [n_samples] or shape = [n_samples * n_classes]
The predicted values.
dataset : ``dataset``
The training set from which the labels will be extracted using
``dataset.get_label()``.
preds : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
The predicted values.
dataset : Dataset
The training set from which the labels will be extracted using ``dataset.get_label()``.
"""
def inner(preds, dataset):
"""internal function"""
"""Call passed function with appropriate arguments."""
labels = dataset.get_label()
argc = argc_(func)
if argc == 2:
@ -76,24 +79,27 @@ def _objective_function_wrapper(func):
def _eval_function_wrapper(func):
"""Decorate an eval function
Note: for multi-class task, the y_pred is group by class_id first, then group by row_id.
If you want to get i-th row y_pred in j-th class, the access way is y_pred[j * num_data + i].
"""Decorate an eval function.
Note
----
For multi-class task, the y_pred is grouped by class_id first, then grouped by row_id.
If you want to get i-th row y_pred in j-th class, the access way is y_pred[j * num_data + i].
Parameters
----------
func : callable
Expects a callable with following functions:
``func(y_true, y_pred)``,
``func(y_true, y_pred, weight)``
or ``func(y_true, y_pred, weight, group)``
and return (eval_name->str, eval_result->float, is_bigger_better->Bool):
Expects a callable with following signatures:
``func(y_true, y_pred)``,
``func(y_true, y_pred, weight)``
or ``func(y_true, y_pred, weight, group)``
and returns (eval_name->string, eval_result->float, is_bigger_better->bool):
y_true : array-like of shape = [n_samples]
The target values.
y_pred : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class)
y_pred : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
The predicted values.
weight : array_like of shape = [n_samples]
weight : array-like of shape = [n_samples]
The weight of samples.
group : array-like
Group/query data, used for ranking task.
@ -104,14 +110,13 @@ def _eval_function_wrapper(func):
The new eval function as expected by ``lightgbm.engine.train``.
The signature is ``new_func(preds, dataset)``:
preds : array-like of shape = [n_samples] or shape = [n_samples * n_classes]
The predicted values.
dataset : ``dataset``
The training set from which the labels will be extracted using
``dataset.get_label()``.
preds : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
The predicted values.
dataset : Dataset
The training set from which the labels will be extracted using ``dataset.get_label()``.
"""
def inner(preds, dataset):
"""internal function"""
"""Call passed function with appropriate arguments."""
labels = dataset.get_label()
argc = argc_(func)
if argc == 2:
@ -128,18 +133,18 @@ def _eval_function_wrapper(func):
class LGBMModel(_LGBMModelBase):
"""Implementation of the scikit-learn API for LightGBM."""
def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
learning_rate=0.1, n_estimators=100,
subsample_for_bin=200000, objective=None, class_weight=None,
min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
subsample=1., subsample_freq=0, colsample_bytree=1.,
reg_alpha=0., reg_lambda=0., random_state=None,
n_jobs=-1, silent=True, importance_type='split', **kwargs):
"""Construct a gradient boosting model.
r"""Construct a gradient boosting model.
Parameters
----------
boosting_type : string, optional (default="gbdt")
boosting_type : string, optional (default='gbdt')
'gbdt', traditional Gradient Boosting Decision Tree.
'dart', Dropouts meet Multiple Additive Regression Trees.
'goss', Gradient-based One-Side Sampling.
@ -168,14 +173,14 @@ class LGBMModel(_LGBMModelBase):
The 'balanced' mode uses the values of y to automatically adjust weights
inversely proportional to class frequencies in the input data as ``n_samples / (n_classes * np.bincount(y))``.
If None, all classes are supposed to have weight one.
Note that these weights will be multiplied with ``sample_weight`` (passed through the fit method)
Note that these weights will be multiplied with ``sample_weight`` (passed through the ``fit`` method)
if ``sample_weight`` is specified.
min_split_gain : float, optional (default=0.)
Minimum loss reduction required to make a further partition on a leaf node of the tree.
min_child_weight : float, optional (default=1e-3)
Minimum sum of instance weight(hessian) needed in a child(leaf).
Minimum sum of instance weight (hessian) needed in a child (leaf).
min_child_samples : int, optional (default=20)
Minimum number of data need in a child(leaf).
Minimum number of data needed in a child (leaf).
subsample : float, optional (default=1.)
Subsample ratio of the training instance.
subsample_freq : int, optional (default=0)
@ -195,14 +200,15 @@ class LGBMModel(_LGBMModelBase):
Whether to print messages while running boosting.
importance_type : string, optional (default='split')
The type of feature importance to be filled into ``feature_importances_``.
If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature.
**kwargs : other parameters
If 'split', result contains numbers of times the feature is used in a model.
If 'gain', result contains total gains of splits which use the feature.
**kwargs
Other parameters for the model.
Check http://lightgbm.readthedocs.io/en/latest/Parameters.html for more parameters.
Note
----
\\*\\*kwargs is not supported in sklearn, it may cause unexpected issues.
\*\*kwargs is not supported in sklearn, it may cause unexpected issues.
Attributes
----------
@ -227,8 +233,8 @@ class LGBMModel(_LGBMModelBase):
Note
----
A custom objective function can be provided for the ``objective``
parameter. In this case, it should have the signature
A custom objective function can be provided for the ``objective`` parameter.
In this case, it should have the signature
``objective(y_true, y_pred) -> grad, hess`` or
``objective(y_true, y_pred, group) -> grad, hess``:
@ -282,12 +288,37 @@ class LGBMModel(_LGBMModelBase):
self.set_params(**kwargs)
def get_params(self, deep=True):
"""Get parameters for this estimator.
Parameters
----------
deep : bool, optional (default=True)
If True, will return the parameters for this estimator and
contained subobjects that are estimators.
Returns
-------
params : dict
Parameter names mapped to their values.
"""
params = super(LGBMModel, self).get_params(deep=deep)
params.update(self._other_params)
return params
# minor change to support `**kwargs`
def set_params(self, **params):
"""Set the parameters of this estimator.
Parameters
----------
**params
Parameter names with their new values.
Returns
-------
self : object
Returns self.
"""
for key, value in params.items():
setattr(self, key, value)
if hasattr(self, '_' + key):
@ -340,10 +371,10 @@ class LGBMModel(_LGBMModelBase):
If there's more than one, will check all of them. But the training data is ignored anyway.
verbose : bool, optional (default=True)
If True and an evaluation set is used, writes the evaluation progress.
feature_name : list of strings or 'auto', optional (default="auto")
feature_name : list of strings or 'auto', optional (default='auto')
Feature names.
If 'auto' and data is pandas DataFrame, data columns names are used.
categorical_feature : list of strings or int, or 'auto', optional (default="auto")
categorical_feature : list of strings or int, or 'auto', optional (default='auto')
Categorical features.
If list of int, interpreted as indices.
If list of strings, interpreted as feature names (need to specify ``feature_name`` as well).
@ -362,15 +393,15 @@ class LGBMModel(_LGBMModelBase):
Note
----
Custom eval function expects a callable with following functions:
Custom eval function expects a callable with following signatures:
``func(y_true, y_pred)``, ``func(y_true, y_pred, weight)`` or
``func(y_true, y_pred, weight, group)``.
Returns (eval_name, eval_result, is_bigger_better) or
list of (eval_name, eval_result, is_bigger_better)
``func(y_true, y_pred, weight, group)``
and returns (eval_name, eval_result, is_bigger_better) or
list of (eval_name, eval_result, is_bigger_better):
y_true : array-like of shape = [n_samples]
The target values.
y_pred : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class)
y_pred : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
The predicted values.
weight : array-like of shape = [n_samples]
The weight of samples.
@ -539,7 +570,8 @@ class LGBMModel(_LGBMModelBase):
like SHAP interaction values,
you can install shap package (https://github.com/slundberg/shap).
**kwargs : other parameters for the prediction
**kwargs
Other parameters for the prediction.
Returns
-------
@ -629,7 +661,7 @@ class LGBMRegressor(LGBMModel, _LGBMRegressorBase):
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_init_score=None, eval_metric=None, early_stopping_rounds=None,
verbose=True, feature_name='auto', categorical_feature='auto', callbacks=None):
"""Docstring is inherited from the LGBMModel."""
super(LGBMRegressor, self).fit(X, y, sample_weight=sample_weight,
init_score=init_score, eval_set=eval_set,
eval_names=eval_names,
@ -656,6 +688,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
eval_class_weight=None, eval_init_score=None, eval_metric=None,
early_stopping_rounds=None, verbose=True,
feature_name='auto', categorical_feature='auto', callbacks=None):
"""Docstring is inherited from the LGBMModel."""
_LGBMAssertAllFinite(y)
_LGBMCheckClassificationTargets(y)
self._le = _LGBMLabelEncoder().fit(y)
@ -704,6 +737,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
def predict(self, X, raw_score=False, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs):
"""Docstring is inherited from the LGBMModel."""
result = self.predict_proba(X, raw_score, num_iteration,
pred_leaf, pred_contrib, **kwargs)
if raw_score or pred_leaf or pred_contrib:
@ -739,7 +773,8 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
like SHAP interaction values,
you can install shap package (https://github.com/slundberg/shap).
**kwargs : other parameters for the prediction
**kwargs
Other parameters for the prediction.
Returns
-------
@ -781,6 +816,7 @@ class LGBMRanker(LGBMModel):
eval_init_score=None, eval_group=None, eval_metric=None,
eval_at=[1], early_stopping_rounds=None, verbose=True,
feature_name='auto', categorical_feature='auto', callbacks=None):
"""Docstring is inherited from the LGBMModel."""
# check group data
if group is None:
raise ValueError("Should set group for ranking task")

Просмотреть файл

@ -16,7 +16,8 @@ def find_lib_path():
return []
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dll_path = [curr_path, os.path.join(curr_path, '../../'),
dll_path = [curr_path,
os.path.join(curr_path, '../../'),
os.path.join(curr_path, '../../python-package/lightgbm/compile'),
os.path.join(curr_path, '../../python-package/compile'),
os.path.join(curr_path, '../../lib/')]
@ -31,7 +32,7 @@ def find_lib_path():
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if not lib_path:
dll_path = [os.path.realpath(p) for p in dll_path]
raise Exception('Cannot find lightgbm library in following paths: ' + '\n'.join(dll_path))
raise Exception('Cannot find lightgbm library file in following paths:\n' + '\n'.join(dll_path))
return lib_path