[python] Allow to register custom logger in Python-package (#3820)

* centralize Python-package logging in one place

* continue

* fix test name

* removed unused import

* enhance test

* fix lint

* hotfix test

* workaround for GPU test

* remove custom logger from Dask-package

* replace one log func with flags by multiple funcs
This commit is contained in:
Nikita Titov 2021-01-24 23:37:45 +03:00 committed by GitHub
Parent ac706e10e4
Commit b7ccdaf066
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
9 changed files: 203 additions and 53 deletions

View file

@@ -55,3 +55,11 @@ Plotting
     plot_metric
     plot_tree
     create_tree_digraph
+
+Utilities
+---------
+
+.. autosummary::
+    :toctree: pythonapi/
+
+    register_logger

View file

@@ -3,7 +3,7 @@
 Contributors: https://github.com/microsoft/LightGBM/graphs/contributors.
 """
-from .basic import Booster, Dataset
+from .basic import Booster, Dataset, register_logger
 from .callback import (early_stopping, print_evaluation, record_evaluation,
                        reset_parameter)
 from .engine import cv, train, CVBooster
@@ -28,6 +28,7 @@ if os.path.isfile(os.path.join(dir_path, 'VERSION.txt')):
         __version__ = version_file.read().strip()

 __all__ = ['Dataset', 'Booster', 'CVBooster',
+           'register_logger',
            'train', 'cv',
            'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
            'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',

View file

@@ -5,8 +5,10 @@ import ctypes
 import json
 import os
 import warnings
-from tempfile import NamedTemporaryFile
 from collections import OrderedDict
 from functools import wraps
+from logging import Logger
+from tempfile import NamedTemporaryFile

 import numpy as np
 import scipy.sparse
@@ -15,9 +17,64 @@ from .compat import PANDAS_INSTALLED, DataFrame, Series, is_dtype_sparse, DataTable
 from .libpath import find_lib_path


+class _DummyLogger:
+    def info(self, msg):
+        print(msg)
+
+    def warning(self, msg):
+        warnings.warn(msg, stacklevel=3)
+
+
+_LOGGER = _DummyLogger()
+
+
+def register_logger(logger):
+    """Register custom logger.
+
+    Parameters
+    ----------
+    logger : logging.Logger
+        Custom logger.
+    """
+    if not isinstance(logger, Logger):
+        raise TypeError("Logger should inherit logging.Logger class")
+    global _LOGGER
+    _LOGGER = logger
+
+
+def _normalize_native_string(func):
+    """Join log messages from native library which come by chunks."""
+    msg_normalized = []
+
+    @wraps(func)
+    def wrapper(msg):
+        nonlocal msg_normalized
+        if msg.strip() == '':
+            msg = ''.join(msg_normalized)
+            msg_normalized = []
+            return func(msg)
+        else:
+            msg_normalized.append(msg)
+
+    return wrapper
+
+
+def _log_info(msg):
+    _LOGGER.info(msg)
+
+
+def _log_warning(msg):
+    _LOGGER.warning(msg)
+
+
+@_normalize_native_string
+def _log_native(msg):
+    _LOGGER.info(msg)
+
+
 def _log_callback(msg):
-    """Redirect logs from native library into Python console."""
-    print("{0:s}".format(msg.decode('utf-8')), end='')
+    """Redirect logs from native library into Python."""
+    _log_native("{0:s}".format(msg.decode('utf-8')))


 def _load_lib():
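Why the chunk-joining wrapper is needed: the native library delivers one logical log record as several chunks and marks the end of a record with a whitespace-only chunk, so `_normalize_native_string` buffers chunks and flushes them as a single message. A minimal standalone sketch of that buffering behavior (the names `normalize_native_string` and `emit` are illustrative, not part of the package):

from functools import wraps

def normalize_native_string(func):
    """Buffer chunks until a whitespace-only chunk signals the end of a record."""
    buffer = []

    @wraps(func)
    def wrapper(msg):
        nonlocal buffer
        if msg.strip() == '':
            msg = ''.join(buffer)  # flush: join all buffered chunks into one record
            buffer = []
            return func(msg)
        else:
            buffer.append(msg)  # accumulate until the terminating chunk arrives

    return wrapper

@normalize_native_string
def emit(msg):
    print(repr(msg))

for chunk in ["[LightGBM] [Info] ", "Total Bins 0", "\n"]:
    emit(chunk)
# prints the single joined record '[LightGBM] [Info] Total Bins 0' once, on the "\n" chunk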
@@ -329,8 +386,8 @@ def convert_from_sliced_object(data):
     """Fix the memory of multi-dimensional sliced object."""
     if isinstance(data, np.ndarray) and isinstance(data.base, np.ndarray):
         if not data.flags.c_contiguous:
-            warnings.warn("Usage of np.ndarray subset (sliced data) is not recommended "
-                          "due to it will double the peak memory cost in LightGBM.")
+            _log_warning("Usage of np.ndarray subset (sliced data) is not recommended "
+                         "due to it will double the peak memory cost in LightGBM.")
             return np.copy(data)
     return data
@@ -620,7 +677,7 @@ class _InnerPredictor:
             preds, nrow = self.__pred_for_np2d(data.to_numpy(), start_iteration, num_iteration, predict_type)
         else:
             try:
-                warnings.warn('Converting data to scipy sparse matrix.')
+                _log_warning('Converting data to scipy sparse matrix.')
                 csr = scipy.sparse.csr_matrix(data)
             except BaseException:
                 raise TypeError('Cannot predict data for type {}'.format(type(data).__name__))
@@ -1103,9 +1160,9 @@ class Dataset:
                               .co_varnames[:getattr(self.__class__, '_lazy_init').__code__.co_argcount])
         for key, _ in params.items():
             if key in args_names:
-                warnings.warn('{0} keyword has been found in `params` and will be ignored.\n'
-                              'Please use {0} argument of the Dataset constructor to pass this parameter.'
-                              .format(key))
+                _log_warning('{0} keyword has been found in `params` and will be ignored.\n'
+                             'Please use {0} argument of the Dataset constructor to pass this parameter.'
+                             .format(key))
         # user can set verbose with params, it has higher priority
         if not any(verbose_alias in params for verbose_alias in _ConfigAliases.get("verbosity")) and silent:
             params["verbose"] = -1
@@ -1126,7 +1183,7 @@ class Dataset:
         if categorical_indices:
             for cat_alias in _ConfigAliases.get("categorical_feature"):
                 if cat_alias in params:
-                    warnings.warn('{} in param dict is overridden.'.format(cat_alias))
+                    _log_warning('{} in param dict is overridden.'.format(cat_alias))
                     params.pop(cat_alias, None)
             params['categorical_column'] = sorted(categorical_indices)
@@ -1172,7 +1229,7 @@ class Dataset:
             self.set_group(group)
         if isinstance(predictor, _InnerPredictor):
             if self._predictor is None and init_score is not None:
-                warnings.warn("The init_score will be overridden by the prediction of init_model.")
+                _log_warning("The init_score will be overridden by the prediction of init_model.")
             self._set_init_score_by_predictor(predictor, data)
         elif init_score is not None:
             self.set_init_score(init_score)
@@ -1314,7 +1371,7 @@ class Dataset:
         if self.reference is not None:
             reference_params = self.reference.get_params()
             if self.get_params() != reference_params:
-                warnings.warn('Overriding the parameters from Reference Dataset.')
+                _log_warning('Overriding the parameters from Reference Dataset.')
                 self._update_params(reference_params)
             if self.used_indices is None:
                 # create valid
@@ -1583,11 +1640,11 @@ class Dataset:
                 self.categorical_feature = categorical_feature
                 return self._free_handle()
             elif categorical_feature == 'auto':
-                warnings.warn('Using categorical_feature in Dataset.')
+                _log_warning('Using categorical_feature in Dataset.')
                 return self
             else:
-                warnings.warn('categorical_feature in Dataset is overridden.\n'
-                              'New categorical_feature is {}'.format(sorted(list(categorical_feature))))
+                _log_warning('categorical_feature in Dataset is overridden.\n'
+                             'New categorical_feature is {}'.format(sorted(list(categorical_feature))))
                 self.categorical_feature = categorical_feature
                 return self._free_handle()
         else:
@@ -1840,8 +1897,8 @@ class Dataset:
             elif isinstance(self.data, DataTable):
                 self.data = self.data[self.used_indices, :]
             else:
-                warnings.warn("Cannot subset {} type of raw data.\n"
-                              "Returning original raw data".format(type(self.data).__name__))
+                _log_warning("Cannot subset {} type of raw data.\n"
+                             "Returning original raw data".format(type(self.data).__name__))
             self.need_slice = False
         if self.data is None:
             raise LightGBMError("Cannot call `get_data` after freed raw data, "
@@ -2011,10 +2068,10 @@ class Dataset:
                                                            old_self_data_type)
             err_msg += ("Set free_raw_data=False when construct Dataset to avoid this"
                         if was_none else "Freeing raw data")
-            warnings.warn(err_msg)
+            _log_warning(err_msg)
         self.feature_name = self.get_feature_name()
-        warnings.warn("Reseting categorical features.\n"
-                      "You can set new categorical features via ``set_categorical_feature`` method")
+        _log_warning("Reseting categorical features.\n"
+                     "You can set new categorical features via ``set_categorical_feature`` method")
         self.categorical_feature = "auto"
         self.pandas_categorical = None
         return self
@@ -2834,7 +2891,7 @@ class Booster:
                 self.handle,
                 ctypes.byref(out_num_class)))
             if verbose:
-                print('Finished loading model, total used %d iterations' % int(out_num_iterations.value))
+                _log_info('Finished loading model, total used %d iterations' % int(out_num_iterations.value))
             self.__num_class = out_num_class.value
             self.pandas_categorical = _load_pandas_categorical(model_str=model_str)
             return self
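With this in place, users can route every LightGBM message, Python-side and native, through the standard logging module. A minimal usage sketch, assuming a build that includes this change (the logger name is arbitrary):

import logging

import lightgbm as lgb

logger = logging.getLogger("my_lightgbm_logger")  # any logging.Logger instance works
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())

lgb.register_logger(logger)  # non-Logger objects raise TypeError, per the isinstance check above

Until register_logger is called, _DummyLogger preserves the historical behavior: print for info messages and warnings.warn for warnings.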

View file

@@ -1,10 +1,9 @@
 # coding: utf-8
 """Callbacks library."""
 import collections
-import warnings
 from operator import gt, lt

-from .basic import _ConfigAliases
+from .basic import _ConfigAliases, _log_info, _log_warning


 class EarlyStopException(Exception):
@@ -67,7 +66,7 @@ def print_evaluation(period=1, show_stdv=True):
     def _callback(env):
         if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
             result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
-            print('[%d]\t%s' % (env.iteration + 1, result))
+            _log_info('[%d]\t%s' % (env.iteration + 1, result))
     _callback.order = 10
     return _callback
@@ -180,15 +179,14 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
         enabled[0] = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
                              in _ConfigAliases.get("boosting"))
         if not enabled[0]:
-            warnings.warn('Early stopping is not available in dart mode')
+            _log_warning('Early stopping is not available in dart mode')
             return
         if not env.evaluation_result_list:
             raise ValueError('For early stopping, '
                              'at least one dataset and eval metric is required for evaluation')

         if verbose:
-            msg = "Training until validation scores don't improve for {} rounds"
-            print(msg.format(stopping_rounds))
+            _log_info("Training until validation scores don't improve for {} rounds".format(stopping_rounds))

         # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
         first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
@@ -205,10 +203,10 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
     def _final_iteration_check(env, eval_name_splitted, i):
         if env.iteration == env.end_iteration - 1:
             if verbose:
-                print('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
+                _log_info('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
                     best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                 if first_metric_only:
-                    print("Evaluated only: {}".format(eval_name_splitted[-1]))
+                    _log_info("Evaluated only: {}".format(eval_name_splitted[-1]))
             raise EarlyStopException(best_iter[i], best_score_list[i])

     def _callback(env):
@@ -232,10 +230,10 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
                 continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
             elif env.iteration - best_iter[i] >= stopping_rounds:
                 if verbose:
-                    print('Early stopping, best iteration is:\n[%d]\t%s' % (
+                    _log_info('Early stopping, best iteration is:\n[%d]\t%s' % (
                         best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                     if first_metric_only:
-                        print("Evaluated only: {}".format(eval_name_splitted[-1]))
+                        _log_info("Evaluated only: {}".format(eval_name_splitted[-1]))
                 raise EarlyStopException(best_iter[i], best_score_list[i])
             _final_iteration_check(env, eval_name_splitted, i)
     _callback.order = 30

View file

@@ -6,7 +6,6 @@ Dask.Array and Dask.DataFrame collections.
 It is based on dask-lightgbm, which was based on dask-xgboost.
 """
-import logging
 import socket
 from collections import defaultdict
 from copy import deepcopy
@@ -22,11 +21,9 @@ from dask import dataframe as dd
 from dask import delayed
 from dask.distributed import Client, default_client, get_worker, wait

-from .basic import _ConfigAliases, _LIB, _safe_call
+from .basic import _ConfigAliases, _LIB, _log_warning, _safe_call
 from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker

-logger = logging.getLogger(__name__)
-

 def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
     """Find an open port.
@@ -257,10 +254,10 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
         'voting_parallel'
     }
     if tree_learner is None:
-        logger.warning('Parameter tree_learner not set. Using "data" as default')
+        _log_warning('Parameter tree_learner not set. Using "data" as default')
         params['tree_learner'] = 'data'
     elif tree_learner.lower() not in allowed_tree_learners:
-        logger.warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner)
+        _log_warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner)
         params['tree_learner'] = 'data'

     local_listen_port = 12400
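Because the Dask module now warns through the shared _log_warning rather than its own module-level logger, a registered logger's level controls these messages too. A hedged sketch of suppressing the tree_learner fallback warnings (logger name illustrative):

import logging

import lightgbm as lgb

quiet_logger = logging.getLogger("quiet_lightgbm")
quiet_logger.addHandler(logging.NullHandler())
quiet_logger.setLevel(logging.ERROR)  # WARNING-level records are dropped below this threshold

lgb.register_logger(quiet_logger)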

View file

@@ -2,13 +2,12 @@
 """Library with training routines of LightGBM."""
 import collections
 import copy
-import warnings
 from operator import attrgetter

 import numpy as np

 from . import callback
-from .basic import Booster, Dataset, LightGBMError, _ConfigAliases, _InnerPredictor
+from .basic import Booster, Dataset, LightGBMError, _ConfigAliases, _InnerPredictor, _log_warning
 from .compat import SKLEARN_INSTALLED, _LGBMGroupKFold, _LGBMStratifiedKFold
@@ -146,12 +145,12 @@ def train(params, train_set, num_boost_round=100,
     for alias in _ConfigAliases.get("num_iterations"):
         if alias in params:
             num_boost_round = params.pop(alias)
-            warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
+            _log_warning("Found `{}` in params. Will use it instead of argument".format(alias))
     params["num_iterations"] = num_boost_round
     for alias in _ConfigAliases.get("early_stopping_round"):
         if alias in params:
             early_stopping_rounds = params.pop(alias)
-            warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
+            _log_warning("Found `{}` in params. Will use it instead of argument".format(alias))
     params["early_stopping_round"] = early_stopping_rounds
     first_metric_only = params.get('first_metric_only', False)
@@ -525,12 +524,12 @@ def cv(params, train_set, num_boost_round=100,
         params['objective'] = 'none'
     for alias in _ConfigAliases.get("num_iterations"):
         if alias in params:
-            warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
+            _log_warning("Found `{}` in params. Will use it instead of argument".format(alias))
             num_boost_round = params.pop(alias)
     params["num_iterations"] = num_boost_round
     for alias in _ConfigAliases.get("early_stopping_round"):
         if alias in params:
-            warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
+            _log_warning("Found `{}` in params. Will use it instead of argument".format(alias))
             early_stopping_rounds = params.pop(alias)
     params["early_stopping_round"] = early_stopping_rounds
     first_metric_only = params.get('first_metric_only', False)

View file

@@ -1,12 +1,11 @@
 # coding: utf-8
 """Plotting library."""
-import warnings
 from copy import deepcopy
 from io import BytesIO

 import numpy as np

-from .basic import Booster
+from .basic import Booster, _log_warning
 from .compat import MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED
 from .sklearn import LGBMModel
@@ -326,8 +325,7 @@ def plot_metric(booster, metric=None, dataset_names=None,
     num_metric = len(metrics_for_one)
     if metric is None:
         if num_metric > 1:
-            msg = "More than one metric available, picking one to plot."
-            warnings.warn(msg, stacklevel=2)
+            _log_warning("More than one metric available, picking one to plot.")
             metric, results = metrics_for_one.popitem()
     else:
         if metric not in metrics_for_one:

View file

@@ -1,13 +1,12 @@
 # coding: utf-8
 """Scikit-learn wrapper interface for LightGBM."""
 import copy
-import warnings
 from inspect import signature

 import numpy as np

-from .basic import Dataset, LightGBMError, _ConfigAliases
+from .basic import Dataset, LightGBMError, _ConfigAliases, _log_warning
 from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase,
                      LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase,
                      _LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckSampleWeight,
@@ -931,9 +930,9 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
         """
         result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, **kwargs)
         if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib):
-            warnings.warn("Cannot compute class probabilities or labels "
-                          "due to the usage of customized objective function.\n"
-                          "Returning raw scores instead.")
+            _log_warning("Cannot compute class probabilities or labels "
+                         "due to the usage of customized objective function.\n"
+                         "Returning raw scores instead.")
             return result
         elif self._n_classes > 2 or raw_score or pred_leaf or pred_contrib:
             return result

View file

@@ -0,0 +1,93 @@
+# coding: utf-8
+import logging
+
+import numpy as np
+
+import lightgbm as lgb
+
+
+def test_register_logger(tmp_path):
+    logger = logging.getLogger("LightGBM")
+    logger.setLevel(logging.DEBUG)
+    formatter = logging.Formatter('%(levelname)s | %(message)s')
+    log_filename = str(tmp_path / "LightGBM_test_logger.log")
+    file_handler = logging.FileHandler(log_filename, mode="w", encoding="utf-8")
+    file_handler.setLevel(logging.DEBUG)
+    file_handler.setFormatter(formatter)
+    logger.addHandler(file_handler)
+
+    def dummy_metric(_, __):
+        logger.debug('In dummy_metric')
+        return 'dummy_metric', 1, True
+
+    lgb.register_logger(logger)
+
+    X = np.array([[1, 2, 3],
+                  [1, 2, 4],
+                  [1, 2, 4],
+                  [1, 2, 3]],
+                 dtype=np.float32)
+    y = np.array([0, 1, 1, 0])
+    lgb_data = lgb.Dataset(X, y)
+
+    eval_records = {}
+    lgb.train({'objective': 'binary', 'metric': ['auc', 'binary_error']},
+              lgb_data, num_boost_round=10, feval=dummy_metric,
+              valid_sets=[lgb_data], evals_result=eval_records,
+              categorical_feature=[1], early_stopping_rounds=4, verbose_eval=2)
+
+    lgb.plot_metric(eval_records)
+
+    expected_log = r"""
+WARNING | categorical_feature in Dataset is overridden.
+New categorical_feature is [1]
+INFO | [LightGBM] [Warning] There are no meaningful features, as all feature values are constant.
+INFO | [LightGBM] [Info] Number of positive: 2, number of negative: 2
+INFO | [LightGBM] [Info] Total Bins 0
+INFO | [LightGBM] [Info] Number of data points in the train set: 4, number of used features: 0
+INFO | [LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
+INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
+DEBUG | In dummy_metric
+INFO | Training until validation scores don't improve for 4 rounds
+INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
+DEBUG | In dummy_metric
+INFO | [2] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
+DEBUG | In dummy_metric
+INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
+DEBUG | In dummy_metric
+INFO | [4] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
+DEBUG | In dummy_metric
+INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
+DEBUG | In dummy_metric
+INFO | [6] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
+DEBUG | In dummy_metric
+INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
+DEBUG | In dummy_metric
+INFO | [8] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
+DEBUG | In dummy_metric
+INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
+DEBUG | In dummy_metric
+INFO | [10] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+INFO | Did not meet early stopping. Best iteration is:
+[1] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+WARNING | More than one metric available, picking one to plot.
+""".strip()
+
+    gpu_lines = [
+        "INFO | [LightGBM] [Info] This is the GPU trainer",
+        "INFO | [LightGBM] [Info] Using GPU Device:",
+        "INFO | [LightGBM] [Info] Compiling OpenCL Kernel with 16 bins...",
+        "INFO | [LightGBM] [Info] GPU programs have been built",
+        "INFO | [LightGBM] [Warning] GPU acceleration is disabled because no non-trivial dense features can be found"
+    ]
+    with open(log_filename, "rt", encoding="utf-8") as f:
+        actual_log = f.read().strip()
+        actual_log_wo_gpu_stuff = []
+        for line in actual_log.split("\n"):
+            if not any(line.startswith(gpu_line) for gpu_line in gpu_lines):
+                actual_log_wo_gpu_stuff.append(line)
+
+    assert "\n".join(actual_log_wo_gpu_stuff) == expected_log