[python-package] add plot tree (#262)

* add plot tree

* add docs

* add example

* add test

* fix test

* fix decision type

* add show_info

* use feature name if available
This commit is contained in:
wxchan 2017-01-25 19:03:00 +08:00 коммит произвёл Guolin Ke
Родитель 5c5dce3752
Коммит 8980fc7220
6 изменённых файлов: 204 добавлений и 36 удалений

Просмотреть файл

@ -23,12 +23,12 @@ script:
- mkdir build && cd build && cmake .. && make -j
- cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py
- cd $TRAVIS_BUILD_DIR/python-package && python setup.py install
- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_engine.py && python test_sklearn.py
- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_engine.py && python test_sklearn.py && python test_plotting.py
- cd $TRAVIS_BUILD_DIR && pep8 --ignore=E501 .
- rm -rf build && mkdir build && cd build && cmake -DUSE_MPI=ON ..&& make -j
- cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py
- cd $TRAVIS_BUILD_DIR/python-package && python setup.py install
- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_engine.py && python test_sklearn.py
- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_engine.py && python test_sklearn.py && python test_plotting.py
notifications:
email: false

Просмотреть файл

@ -928,22 +928,22 @@ The methods of each Class is in alphabetical order.
##Plotting
####plot_importance(booster, ax=None, height=0.2, xlim=None, ylim=None, title='Feature importance', xlabel='Feature importance', ylabel='Features', importance_type='split', max_num_features=None, ignore_zero=True, grid=True, **kwargs):
####plot_importance(booster, ax=None, height=0.2, xlim=None, ylim=None, title='Feature importance', xlabel='Feature importance', ylabel='Features', importance_type='split', max_num_features=None, ignore_zero=True, figsize=None, grid=True, **kwargs):
Plot model feature importances.
Parameters
----------
booster : Booster, LGBMModel or array
Booster or LGBMModel instance, or array of feature importances
Booster or LGBMModel instance, or array of feature importances.
ax : matplotlib Axes
Target axes instance. If None, new figure and axes will be created.
height : float
Bar height, passed to ax.barh()
xlim : tuple
Tuple passed to axes.xlim()
ylim : tuple
Tuple passed to axes.ylim()
Bar height, passed to ax.barh().
xlim : tuple of 2 elements
Tuple passed to axes.xlim().
ylim : tuple of 2 elements
Tuple passed to axes.ylim().
title : str
Axes title. Pass None to disable.
xlabel : str
@ -951,18 +951,47 @@ The methods of each Class is in alphabetical order.
ylabel : str
Y axis title label. Pass None to disable.
importance_type : str
How the importance is calculated: "split" or "gain"
"split" is the number of times a feature is used in a model
"gain" is the total gain of splits which use the feature
How the importance is calculated: "split" or "gain".
"split" is the number of times a feature is used in a model.
"gain" is the total gain of splits which use the feature.
max_num_features : int
Max number of top features displayed on plot.
If None or smaller than 1, all features will be displayed.
ignore_zero : bool
Ignore features with zero importance
Ignore features with zero importance.
figsize : tuple of 2 elements
Figure size.
grid : bool
Whether add grid for axes
Whether to add a grid for axes.
**kwargs :
Other keywords passed to ax.barh()
Other keywords passed to ax.barh().
Returns
-------
ax : matplotlib Axes
####plot_tree(booster, ax=None, tree_index=0, figsize=None, graph_attr=None, node_attr=None, edge_attr=None, show_info=None):
Plot specified tree.
Parameters
----------
booster : Booster, LGBMModel
Booster or LGBMModel instance.
ax : matplotlib Axes
Target axes instance. If None, new figure and axes will be created.
tree_index : int, default 0
Specify tree index of target tree.
figsize : tuple of 2 elements
Figure size.
graph_attr : dict
Mapping of (attribute, value) pairs for the graph.
node_attr : dict
Mapping of (attribute, value) pairs set for all nodes.
edge_attr : dict
Mapping of (attribute, value) pairs set for all edges.
show_info : list
Information shown on nodes.
Options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
Returns
-------

Просмотреть файл

@ -20,6 +20,7 @@ lgb_train = lgb.Dataset(X_train, y_train)
# specify your configurations as a dict
params = {
'num_leaves': 5,
'verbose': 0
}
@ -27,9 +28,16 @@ print('Start training...')
# train
gbm = lgb.train(params,
lgb_train,
num_boost_round=10)
num_boost_round=100,
feature_name=['f' + str(i + 1) for i in range(28)],
categorical_feature=[21])
print('Plot feature importances...')
# plot feature importances
ax = lgb.plot_importance(gbm, max_num_features=10)
plt.show()
print('Plot 84th tree...')
# plot tree
lgb.plot_tree(gbm, tree_index=83, figsize=(20, 8), show_info=['split_gain'])
plt.show()

Просмотреть файл

@ -14,7 +14,7 @@ try:
except ImportError:
pass
try:
from .plotting import plot_importance
from .plotting import plot_importance, plot_tree
except ImportError:
pass
@ -25,4 +25,4 @@ __all__ = ['Dataset', 'Booster',
'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
'plot_importance']
'plot_importance', 'plot_tree']

Просмотреть файл

@ -3,17 +3,24 @@
"""Plotting Library."""
from __future__ import absolute_import
from io import BytesIO
import numpy as np
from .basic import Booster, is_numpy_1d_array
from .sklearn import LGBMModel
def check_not_tuple_of_2_elements(obj):
    """Return True unless ``obj`` is a tuple holding exactly two items.

    Used as a validation predicate: a True result means the value is NOT
    acceptable as a 2-element tuple (e.g. for axis limits or figure size).
    """
    is_pair = isinstance(obj, tuple) and len(obj) == 2
    return not is_pair
def plot_importance(booster, ax=None, height=0.2,
xlim=None, ylim=None, title='Feature importance',
xlabel='Feature importance', ylabel='Features',
importance_type='split', max_num_features=None,
ignore_zero=True, grid=True, **kwargs):
ignore_zero=True, figsize=None, grid=True, **kwargs):
"""Plot model feature importances.
Parameters
@ -24,9 +31,9 @@ def plot_importance(booster, ax=None, height=0.2,
Target axes instance. If None, new figure and axes will be created.
height : float
Bar height, passed to ax.barh()
xlim : tuple
xlim : tuple of 2 elements
Tuple passed to axes.xlim()
ylim : tuple
ylim : tuple of 2 elements
Tuple passed to axes.ylim()
title : str
Axes title. Pass None to disable.
@ -43,6 +50,8 @@ def plot_importance(booster, ax=None, height=0.2,
If None or smaller than 1, all features will be displayed.
ignore_zero : bool
Ignore features with zero importance
figsize : tuple of 2 elements
Figure size
grid : bool
Whether add grid for axes
**kwargs :
@ -55,7 +64,7 @@ def plot_importance(booster, ax=None, height=0.2,
try:
import matplotlib.pyplot as plt
except ImportError:
raise ImportError('You must install matplotlib for plotting library')
raise ImportError('You must install matplotlib to plot importance.')
if isinstance(booster, LGBMModel):
importance = booster.booster_.feature_importance(importance_type=importance_type)
@ -64,10 +73,10 @@ def plot_importance(booster, ax=None, height=0.2,
elif is_numpy_1d_array(booster) or isinstance(booster, list):
importance = booster
else:
raise ValueError('booster must be Booster or array instance')
raise TypeError('booster must be Booster, LGBMModel or array instance.')
if not len(importance):
raise ValueError('Booster feature_importances are empty')
raise ValueError('Booster feature_importances are empty.')
tuples = sorted(enumerate(importance), key=lambda x: x[1])
if ignore_zero:
@ -77,7 +86,9 @@ def plot_importance(booster, ax=None, height=0.2,
labels, values = zip(*tuples)
if ax is None:
_, ax = plt.subplots(1, 1)
if figsize is not None and check_not_tuple_of_2_elements(figsize):
raise TypeError('figsize must be a tuple of 2 elements.')
_, ax = plt.subplots(1, 1, figsize=figsize)
ylocs = np.arange(len(values))
ax.barh(ylocs, values, align='center', height=height, **kwargs)
@ -89,15 +100,15 @@ def plot_importance(booster, ax=None, height=0.2,
ax.set_yticklabels(labels)
if xlim is not None:
if not isinstance(xlim, tuple) or len(xlim) != 2:
raise ValueError('xlim must be a tuple of 2 elements')
if check_not_tuple_of_2_elements(xlim):
raise TypeError('xlim must be a tuple of 2 elements.')
else:
xlim = (0, max(values) * 1.1)
ax.set_xlim(xlim)
if ylim is not None:
if not isinstance(ylim, tuple) or len(ylim) != 2:
raise ValueError('ylim must be a tuple of 2 elements')
if check_not_tuple_of_2_elements(ylim):
raise TypeError('ylim must be a tuple of 2 elements.')
else:
ylim = (-1, len(values))
ax.set_ylim(ylim)
@ -110,3 +121,118 @@ def plot_importance(booster, ax=None, height=0.2,
ax.set_ylabel(ylabel)
ax.grid(grid)
return ax
def _to_graphviz(graph, tree_info, show_info, feature_names):
    """Convert specified tree to graphviz instance.

    Parameters
    ----------
    graph : graphviz Digraph (or compatible object)
        Must provide ``node(name, label=...)`` and ``edge(tail, head, label)``.
        It is populated in place.
    tree_info : dict
        Tree description; the root node dict is read from its
        'tree_structure' key.
    show_info : list
        Extra fields appended to node labels. 'split_gain', 'internal_value'
        and 'internal_count' apply to split nodes; 'leaf_count' applies to
        leaves. Other entries are ignored.
    feature_names : list or None
        If not None, split nodes are labelled with the feature name looked up
        by 'split_feature'; otherwise with the feature index.

    Returns
    -------
    graph : the same object passed in as ``graph``.

    Raises
    ------
    ValueError
        If a split node's 'decision_type' is neither 'no_greater' nor 'is'.
    """
    internal_fields = {'split_gain', 'internal_value', 'internal_count'}

    def add(root, parent=None, decision=None):
        """recursively add node or edge"""
        if 'split_index' in root:  # non-leaf
            name = 'split' + str(root['split_index'])
            if feature_names is not None:
                label = 'split_feature_name:' + str(feature_names[root['split_feature']])
            else:
                label = 'split_feature_index:' + str(root['split_feature'])
            label += '\nthreshold:' + str(root['threshold'])
            for info in show_info:
                if info in internal_fields:
                    label += '\n' + info + ':' + str(root[info])
            graph.node(name, label=label)
            if root['decision_type'] == 'no_greater':
                l_dec, r_dec = '<=', '>'
            elif root['decision_type'] == 'is':
                l_dec, r_dec = 'is', "isn't"
            else:
                raise ValueError('Invalid decision type in tree model.')
            add(root['left_child'], name, l_dec)
            add(root['right_child'], name, r_dec)
        else:  # leaf
            # Leaf node ids are prefixed with 'leaf' (was the typo 'left').
            name = 'leaf' + str(root['leaf_index'])
            label = 'leaf_value:' + str(root['leaf_value'])
            if 'leaf_count' in show_info:
                label += '\nleaf_count:' + str(root['leaf_count'])
            graph.node(name, label=label)
        # The root has no parent, so no incoming edge is drawn for it.
        if parent is not None:
            graph.edge(parent, name, decision)

    add(tree_info['tree_structure'])
    return graph
def plot_tree(booster, ax=None, tree_index=0, figsize=None,
              graph_attr=None, node_attr=None, edge_attr=None,
              show_info=None):
    """Plot specified tree.

    Parameters
    ----------
    booster : Booster, LGBMModel
        Booster or LGBMModel instance.
    ax : matplotlib Axes
        Target axes instance. If None, new figure and axes will be created.
    tree_index : int, default 0
        Specify tree index of target tree.
    figsize : tuple of 2 elements
        Figure size. Only used (and validated) when ``ax`` is None.
    graph_attr : dict
        Mapping of (attribute, value) pairs for the graph.
    node_attr : dict
        Mapping of (attribute, value) pairs set for all nodes.
    edge_attr : dict
        Mapping of (attribute, value) pairs set for all edges.
    show_info : list
        Information shown on nodes.
        Options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ImportError
        If matplotlib or graphviz cannot be imported.
    TypeError
        If ``figsize`` is given but is not a tuple of 2 elements, or if
        ``booster`` is neither a Booster nor an LGBMModel.
    IndexError
        If ``tree_index`` is not smaller than the number of trees.
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.image as image
    except ImportError:
        raise ImportError('You must install matplotlib to plot tree.')

    try:
        from graphviz import Digraph
    except ImportError:
        raise ImportError('You must install graphviz to plot tree.')

    if ax is None:
        if figsize is not None and check_not_tuple_of_2_elements(figsize):
            # The message names the offending argument (previously said 'xlim').
            raise TypeError('figsize must be a tuple of 2 elements.')
        _, ax = plt.subplots(1, 1, figsize=figsize)

    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError('booster must be Booster or LGBMModel.')

    model = booster.dump_model()
    tree_infos = model['tree_info']
    if 'feature_names' in model:
        feature_names = model['feature_names']
    else:
        feature_names = None

    # Only the upper bound is checked here; negative values are not rejected.
    if tree_index < len(tree_infos):
        tree_info = tree_infos[tree_index]
    else:
        raise IndexError('tree_index is out of range.')

    graph = Digraph(graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr)

    if show_info is None:
        show_info = []
    ret = _to_graphviz(graph, tree_info, show_info, feature_names)

    # Render to PNG bytes in memory and hand them to matplotlib as an image.
    s = BytesIO(ret.pipe(format='png'))
    img = image.imread(s)

    ax.imshow(img)
    ax.axis('off')
    return ax

Просмотреть файл

@ -7,15 +7,16 @@ from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
try:
from matplotlib.axes import Axes
MATPLOTLIB_INSTALLED = True
import matplotlib
matplotlib.use('Agg')
matplotlib_installed = True
except ImportError:
MATPLOTLIB_INSTALLED = False
matplotlib_installed = False
class TestBasic(unittest.TestCase):
@unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib not installed')
@unittest.skipIf(not matplotlib_installed, 'matplotlib not installed')
def test_plot_importance(self):
X_train, _, y_train, _ = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
train_data = lgb.Dataset(X_train, y_train)
@ -27,7 +28,7 @@ class TestBasic(unittest.TestCase):
}
gbm0 = lgb.train(params, train_data, num_boost_round=10)
ax0 = lgb.plot_importance(gbm0)
self.assertIsInstance(ax0, Axes)
self.assertIsInstance(ax0, matplotlib.axes.Axes)
self.assertEqual(ax0.get_title(), 'Feature importance')
self.assertEqual(ax0.get_xlabel(), 'Feature importance')
self.assertEqual(ax0.get_ylabel(), 'Features')
@ -37,7 +38,7 @@ class TestBasic(unittest.TestCase):
gbm1.fit(X_train, y_train)
ax1 = lgb.plot_importance(gbm1, color='r', title='t', xlabel='x', ylabel='y')
self.assertIsInstance(ax1, Axes)
self.assertIsInstance(ax1, matplotlib.axes.Axes)
self.assertEqual(ax1.get_title(), 't')
self.assertEqual(ax1.get_xlabel(), 'x')
self.assertEqual(ax1.get_ylabel(), 'y')
@ -48,7 +49,7 @@ class TestBasic(unittest.TestCase):
ax2 = lgb.plot_importance(gbm0.feature_importance(),
color=['r', 'y', 'g', 'b'],
title=None, xlabel=None, ylabel=None)
self.assertIsInstance(ax2, Axes)
self.assertIsInstance(ax2, matplotlib.axes.Axes)
self.assertEqual(ax2.get_title(), '')
self.assertEqual(ax2.get_xlabel(), '')
self.assertEqual(ax2.get_ylabel(), '')
@ -58,6 +59,10 @@ class TestBasic(unittest.TestCase):
self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.)) # g
self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.)) # b
@unittest.skip('Graphviz are not executables on Travis')
def test_plot_tree(self):
pass
print("----------------------------------------------------------------------")
print("running test_plotting.py")