[python-package] add plot metrics (#266)

* add plot metrics

* move 'raise Exception' to check_not_tuple_of_2_elements

* rename 'plot_metrics' to 'plot_metric'

* fix misleading message/docs

* change 'Metrics' in title to 'Metric'

* fix misleading comment
This commit is contained in:
wxchan 2017-01-28 19:10:02 +08:00 коммит произвёл Guolin Ke
Родитель d43a6a3c37
Коммит 58565547e8
5 изменённых файлов: 231 добавлений и 21 удалений

Просмотреть файл

@ -970,6 +970,42 @@ The methods of each Class is in alphabetical order.
-------
ax : matplotlib Axes
####plot_metric(booster, metric=None, dataset_names=None, ax=None, xlim=None, ylim=None, title='Metric during training', xlabel='Iterations', ylabel='auto', figsize=None, grid=True):
Plot one metric during training.
Parameters
----------
booster : dict or LGBMModel
Evals_result recorded by lightgbm.train() or LGBMModel instance
metric : str or None
The metric name to plot.
Only one metric supported because different metrics have various scales.
Pass None to pick `first` one (according to dict hashcode).
dataset_names : None or list of str
List of the dataset names to plot.
Pass None to plot all datasets.
ax : matplotlib Axes
Target axes instance. If None, new figure and axes will be created.
xlim : tuple of 2 elements
Tuple passed to axes.xlim()
ylim : tuple of 2 elements
Tuple passed to axes.ylim()
title : str
Axes title. Pass None to disable.
xlabel : str
X axis title label. Pass None to disable.
ylabel : str
Y axis title label. Pass None to disable. Pass 'auto' to use `metric`.
figsize : tuple of 2 elements
Figure size
grid : bool
Whether add grid for axes
Returns
-------
ax : matplotlib Axes
####plot_tree(booster, ax=None, tree_index=0, figsize=None, graph_attr=None, node_attr=None, edge_attr=None, show_info=None):
Plot specified tree.

Просмотреть файл

@ -11,33 +11,45 @@ except ImportError:
# load or create your dataset
print('Load data...')
df_train = pd.read_csv('../regression/regression.train', header=None, sep='\t')
df_test = pd.read_csv('../regression/regression.test', header=None, sep='\t')
y_train = df_train[0]
X_train = df_train.drop(0, axis=1)
y_test = df_test[0]
X_test = df_test.drop(0, axis=1)
# create dataset for lightgbm
lgb_train = lgb.Dataset(X_train, y_train)
lgb_test = lgb.Dataset(X_test, y_test, reference=lgb_train)
# specify your configurations as a dict
params = {
'num_leaves': 5,
'metric': ('l1', 'l2'),
'verbose': 0
}
evals_result = {} # to record eval results for plotting
print('Start training...')
# train
gbm = lgb.train(params,
lgb_train,
num_boost_round=100,
valid_sets=[lgb_train, lgb_test],
feature_name=['f' + str(i + 1) for i in range(28)],
categorical_feature=[21])
categorical_feature=[21],
evals_result=evals_result,
verbose_eval=10)
print('Plot metrics during training...')
ax = lgb.plot_metric(evals_result, metric='l1')
plt.show()
print('Plot feature importances...')
# plot feature importances
ax = lgb.plot_importance(gbm, max_num_features=10)
plt.show()
print('Plot 84th tree...')
# plot tree
lgb.plot_tree(gbm, tree_index=83, figsize=(20, 8), show_info=['split_gain'])
print('Plot 84th tree...') # one tree use categorical feature to split
ax = lgb.plot_tree(gbm, tree_index=83, figsize=(20, 8), show_info=['split_gain'])
plt.show()

Просмотреть файл

@ -6,15 +6,17 @@ Contributors: https://github.com/Microsoft/LightGBM/graphs/contributors
from __future__ import absolute_import
from .basic import Dataset, Booster
from .engine import train, cv
from .callback import print_evaluation, record_evaluation, reset_parameter, early_stopping
from .basic import Booster, Dataset
from .callback import (early_stopping, print_evaluation, record_evaluation,
reset_parameter)
from .engine import cv, train
try:
from .sklearn import LGBMModel, LGBMRegressor, LGBMClassifier, LGBMRanker
except ImportError:
pass
try:
from .plotting import plot_importance, plot_tree
from .plotting import plot_importance, plot_metric, plot_tree
except ImportError:
pass
@ -25,4 +27,4 @@ __all__ = ['Dataset', 'Booster',
'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
'plot_importance', 'plot_tree']
'plot_importance', 'plot_metric', 'plot_tree']

Просмотреть файл

@ -3,6 +3,7 @@
"""Plotting Library."""
from __future__ import absolute_import
from copy import deepcopy
from io import BytesIO
import numpy as np
@ -11,9 +12,10 @@ from .basic import Booster, is_numpy_1d_array
from .sklearn import LGBMModel
def check_not_tuple_of_2_elements(obj):
def check_not_tuple_of_2_elements(obj, obj_name='obj'):
"""check object is not tuple or does not have 2 elements"""
return not isinstance(obj, tuple) or len(obj) != 2
if not isinstance(obj, tuple) or len(obj) != 2:
raise TypeError('%s must be a tuple of 2 elements.' % obj_name)
def plot_importance(booster, ax=None, height=0.2,
@ -86,8 +88,8 @@ def plot_importance(booster, ax=None, height=0.2,
labels, values = zip(*tuples)
if ax is None:
if figsize is not None and check_not_tuple_of_2_elements(figsize):
raise TypeError('figsize must be a tuple of 2 elements.')
if figsize is not None:
check_not_tuple_of_2_elements(figsize, 'figsize')
_, ax = plt.subplots(1, 1, figsize=figsize)
ylocs = np.arange(len(values))
@ -100,15 +102,13 @@ def plot_importance(booster, ax=None, height=0.2,
ax.set_yticklabels(labels)
if xlim is not None:
if check_not_tuple_of_2_elements(xlim):
raise TypeError('xlim must be a tuple of 2 elements.')
check_not_tuple_of_2_elements(xlim, 'xlim')
else:
xlim = (0, max(values) * 1.1)
ax.set_xlim(xlim)
if ylim is not None:
if check_not_tuple_of_2_elements(ylim):
raise TypeError('ylim must be a tuple of 2 elements.')
check_not_tuple_of_2_elements(ylim, 'ylim')
else:
ylim = (-1, len(values))
ax.set_ylim(ylim)
@ -123,6 +123,123 @@ def plot_importance(booster, ax=None, height=0.2,
return ax
def plot_metric(booster, metric=None, dataset_names=None,
ax=None, xlim=None, ylim=None,
title='Metric during training',
xlabel='Iterations', ylabel='auto',
figsize=None, grid=True):
"""Plot one metric during training.
Parameters
----------
booster : dict or LGBMModel
Evals_result recorded by lightgbm.train() or LGBMModel instance
metric : str or None
The metric name to plot.
Only one metric supported because different metrics have various scales.
Pass None to pick `first` one (according to dict hashcode).
dataset_names : None or list of str
List of the dataset names to plot.
Pass None to plot all datasets.
ax : matplotlib Axes
Target axes instance. If None, new figure and axes will be created.
xlim : tuple of 2 elements
Tuple passed to axes.xlim()
ylim : tuple of 2 elements
Tuple passed to axes.ylim()
title : str
Axes title. Pass None to disable.
xlabel : str
X axis title label. Pass None to disable.
ylabel : str
Y axis title label. Pass None to disable. Pass 'auto' to use `metric`.
figsize : tuple of 2 elements
Figure size
grid : bool
Whether add grid for axes
Returns
-------
ax : matplotlib Axes
"""
try:
import matplotlib.pyplot as plt
except ImportError:
raise ImportError('You must install matplotlib to plot metric.')
if isinstance(booster, LGBMModel):
eval_results = deepcopy(booster.evals_result_)
elif isinstance(booster, dict):
eval_results = deepcopy(booster)
else:
raise TypeError('booster must be dict or LGBMModel.')
num_data = len(eval_results)
if not num_data:
raise ValueError('eval results cannot be empty.')
if ax is None:
if figsize is not None:
check_not_tuple_of_2_elements(figsize, 'figsize')
_, ax = plt.subplots(1, 1, figsize=figsize)
if dataset_names is None:
dataset_names = iter(eval_results.keys())
elif not isinstance(dataset_names, (list, tuple, set)) or not dataset_names:
raise ValueError('dataset_names should be iterable and cannot be empty')
else:
dataset_names = iter(dataset_names)
name = next(dataset_names) # take one as sample
metrics_for_one = eval_results[name]
num_metric = len(metrics_for_one)
if metric is None:
if num_metric > 1:
print('Warning: more than one metric available, picking one to plot.')
metric, results = metrics_for_one.popitem()
else:
if metric not in metrics_for_one:
raise KeyError('No given metric in eval results.')
results = metrics_for_one[metric]
num_iteration, max_result, min_result = len(results), max(results), min(results)
x_ = range(num_iteration)
ax.plot(x_, results, label=name)
for name in dataset_names:
metrics_for_one = eval_results[name]
results = metrics_for_one[metric]
max_result, min_result = max(max(results), max_result), min(min(results), min_result)
ax.plot(x_, results, label=name)
ax.legend(loc='best')
if xlim is not None:
check_not_tuple_of_2_elements(xlim, 'xlim')
else:
xlim = (0, num_iteration)
ax.set_xlim(xlim)
if ylim is not None:
check_not_tuple_of_2_elements(ylim, 'ylim')
else:
range_result = max_result - min_result
ylim = (min_result - range_result * 0.2, max_result + range_result * 0.2)
ax.set_ylim(ylim)
if ylabel == 'auto':
ylabel = metric
if title is not None:
ax.set_title(title)
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
ax.grid(grid)
return ax
def _to_graphviz(graph, tree_info, show_info, feature_names):
"""Convert specified tree to graphviz instance."""
@ -173,7 +290,7 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
Target axes instance. If None, new figure and axes will be created.
tree_index : int, default 0
Specify tree index of target tree.
figsize : tuple
figsize : tuple of 2 elements
Figure size.
graph_attr : dict
Mapping of (attribute, value) pairs for the graph.
@ -201,8 +318,8 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
raise ImportError('You must install graphviz to plot tree.')
if ax is None:
if figsize is not None and check_not_tuple_of_2_elements(figsize):
raise TypeError('xlim must be a tuple of 2 elements.')
if figsize is not None:
check_not_tuple_of_2_elements(figsize, 'figsize')
_, ax = plt.subplots(1, 1, figsize=figsize)
if isinstance(booster, LGBMModel):

Просмотреть файл

@ -63,6 +63,49 @@ class TestBasic(unittest.TestCase):
def test_plot_tree(self):
pass
@unittest.skipIf(not matplotlib_installed, 'matplotlib not installed')
def test_plot_metrics(self):
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
train_data = lgb.Dataset(X_train, y_train)
test_data = lgb.Dataset(X_test, y_test, reference=train_data)
params = {
"objective": "binary",
"metric": {"binary_logloss", "binary_error"},
"verbose": -1,
"num_leaves": 3
}
evals_result0 = {}
gbm0 = lgb.train(params, train_data,
valid_sets=[train_data, test_data],
valid_names=['v1', 'v2'],
num_boost_round=10,
evals_result=evals_result0,
verbose_eval=False)
ax0 = lgb.plot_metric(evals_result0)
self.assertIsInstance(ax0, matplotlib.axes.Axes)
self.assertEqual(ax0.get_title(), 'Metric during training')
self.assertEqual(ax0.get_xlabel(), 'Iterations')
self.assertIn(ax0.get_ylabel(), {'binary_logloss', 'binary_error'})
ax0 = lgb.plot_metric(evals_result0, metric='binary_error')
ax0 = lgb.plot_metric(evals_result0, metric='binary_logloss', dataset_names=['v2'])
evals_result1 = {}
gbm1 = lgb.train(params, train_data,
num_boost_round=10,
evals_result=evals_result1,
verbose_eval=False)
self.assertRaises(ValueError, lgb.plot_metric, evals_result1)
gbm2 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
gbm2.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=False)
ax2 = lgb.plot_metric(gbm2, title=None, xlabel=None, ylabel=None)
self.assertIsInstance(ax2, matplotlib.axes.Axes)
self.assertEqual(ax2.get_title(), '')
self.assertEqual(ax2.get_xlabel(), '')
self.assertEqual(ax2.get_ylabel(), '')
print("----------------------------------------------------------------------")
print("running test_plotting.py")