зеркало из https://github.com/microsoft/LightGBM.git
[python-package] add plot tree (#262)
* add plot tree * add docs * add example * add test * fix test * fix decision type * add show_info * use feature name if available
This commit is contained in:
Родитель
5c5dce3752
Коммит
8980fc7220
|
@ -23,12 +23,12 @@ script:
|
|||
- mkdir build && cd build && cmake .. && make -j
|
||||
- cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py
|
||||
- cd $TRAVIS_BUILD_DIR/python-package && python setup.py install
|
||||
- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_engine.py && python test_sklearn.py
|
||||
- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_engine.py && python test_sklearn.py && python test_plotting.py
|
||||
- cd $TRAVIS_BUILD_DIR && pep8 --ignore=E501 .
|
||||
- rm -rf build && mkdir build && cd build && cmake -DUSE_MPI=ON ..&& make -j
|
||||
- cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py
|
||||
- cd $TRAVIS_BUILD_DIR/python-package && python setup.py install
|
||||
- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_engine.py && python test_sklearn.py
|
||||
- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_engine.py && python test_sklearn.py && python test_plotting.py
|
||||
|
||||
notifications:
|
||||
email: false
|
||||
|
|
|
@ -928,22 +928,22 @@ The methods of each Class is in alphabetical order.
|
|||
|
||||
##Plotting
|
||||
|
||||
####plot_importance(booster, ax=None, height=0.2, xlim=None, ylim=None, title='Feature importance', xlabel='Feature importance', ylabel='Features', importance_type='split', max_num_features=None, ignore_zero=True, grid=True, **kwargs):
|
||||
####plot_importance(booster, ax=None, height=0.2, xlim=None, ylim=None, title='Feature importance', xlabel='Feature importance', ylabel='Features', importance_type='split', max_num_features=None, ignore_zero=True, figsize=None, grid=True, **kwargs):
|
||||
|
||||
Plot model feature importances.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
booster : Booster, LGBMModel or array
|
||||
Booster or LGBMModel instance, or array of feature importances
|
||||
Booster or LGBMModel instance, or array of feature importances.
|
||||
ax : matplotlib Axes
|
||||
Target axes instance. If None, new figure and axes will be created.
|
||||
height : float
|
||||
Bar height, passed to ax.barh()
|
||||
xlim : tuple
|
||||
Tuple passed to axes.xlim()
|
||||
ylim : tuple
|
||||
Tuple passed to axes.ylim()
|
||||
Bar height, passed to ax.barh().
|
||||
xlim : tuple of 2 elements
|
||||
Tuple passed to axes.xlim().
|
||||
ylim : tuple of 2 elements
|
||||
Tuple passed to axes.ylim().
|
||||
title : str
|
||||
Axes title. Pass None to disable.
|
||||
xlabel : str
|
||||
|
@ -951,18 +951,47 @@ The methods of each Class is in alphabetical order.
|
|||
ylabel : str
|
||||
Y axis title label. Pass None to disable.
|
||||
importance_type : str
|
||||
How the importance is calculated: "split" or "gain"
|
||||
"split" is the number of times a feature is used in a model
|
||||
"gain" is the total gain of splits which use the feature
|
||||
How the importance is calculated: "split" or "gain".
|
||||
"split" is the number of times a feature is used in a model.
|
||||
"gain" is the total gain of splits which use the feature.
|
||||
max_num_features : int
|
||||
Max number of top features displayed on plot.
|
||||
If None or smaller than 1, all features will be displayed.
|
||||
ignore_zero : bool
|
||||
Ignore features with zero importance
|
||||
Ignore features with zero importance.
|
||||
figsize : tuple of 2 elements
|
||||
Figure size.
|
||||
grid : bool
|
||||
Whether add grid for axes
|
||||
Whether add grid for axes.
|
||||
**kwargs :
|
||||
Other keywords passed to ax.barh()
|
||||
Other keywords passed to ax.barh().
|
||||
|
||||
Returns
|
||||
-------
|
||||
ax : matplotlib Axes
|
||||
|
||||
####plot_tree(booster, ax=None, tree_index=0, figsize=None, graph_attr=None, node_attr=None, edge_attr=None, show_info=None):
|
||||
Plot specified tree.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
booster : Booster, LGBMModel
|
||||
Booster or LGBMModel instance.
|
||||
ax : matplotlib Axes
|
||||
Target axes instance. If None, new figure and axes will be created.
|
||||
tree_index : int, default 0
|
||||
Specify tree index of target tree.
|
||||
figsize : tuple of 2 elements
|
||||
Figure size.
|
||||
graph_attr: dict
|
||||
Mapping of (attribute, value) pairs for the graph.
|
||||
node_attr: dict
|
||||
Mapping of (attribute, value) pairs set for all nodes.
|
||||
edge_attr: dict
|
||||
Mapping of (attribute, value) pairs set for all edges.
|
||||
show_info : list
|
||||
Information shows on nodes.
|
||||
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
|
@ -20,6 +20,7 @@ lgb_train = lgb.Dataset(X_train, y_train)
|
|||
|
||||
# specify your configurations as a dict
|
||||
params = {
|
||||
'num_leaves': 5,
|
||||
'verbose': 0
|
||||
}
|
||||
|
||||
|
@ -27,9 +28,16 @@ print('Start training...')
|
|||
# train
|
||||
gbm = lgb.train(params,
|
||||
lgb_train,
|
||||
num_boost_round=10)
|
||||
num_boost_round=100,
|
||||
feature_name=['f' + str(i + 1) for i in range(28)],
|
||||
categorical_feature=[21])
|
||||
|
||||
print('Plot feature importances...')
|
||||
# plot feature importances
|
||||
ax = lgb.plot_importance(gbm, max_num_features=10)
|
||||
plt.show()
|
||||
|
||||
print('Plot 84th tree...')
|
||||
# plot tree
|
||||
lgb.plot_tree(gbm, tree_index=83, figsize=(20, 8), show_info=['split_gain'])
|
||||
plt.show()
|
||||
|
|
|
@ -14,7 +14,7 @@ try:
|
|||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from .plotting import plot_importance
|
||||
from .plotting import plot_importance, plot_tree
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
@ -25,4 +25,4 @@ __all__ = ['Dataset', 'Booster',
|
|||
'train', 'cv',
|
||||
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
|
||||
'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
|
||||
'plot_importance']
|
||||
'plot_importance', 'plot_tree']
|
||||
|
|
|
@ -3,17 +3,24 @@
|
|||
"""Plotting Library."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .basic import Booster, is_numpy_1d_array
|
||||
from .sklearn import LGBMModel
|
||||
|
||||
|
||||
def check_not_tuple_of_2_elements(obj):
|
||||
"""check object is not tuple or does not have 2 elements"""
|
||||
return not isinstance(obj, tuple) or len(obj) != 2
|
||||
|
||||
|
||||
def plot_importance(booster, ax=None, height=0.2,
|
||||
xlim=None, ylim=None, title='Feature importance',
|
||||
xlabel='Feature importance', ylabel='Features',
|
||||
importance_type='split', max_num_features=None,
|
||||
ignore_zero=True, grid=True, **kwargs):
|
||||
ignore_zero=True, figsize=None, grid=True, **kwargs):
|
||||
"""Plot model feature importances.
|
||||
|
||||
Parameters
|
||||
|
@ -24,9 +31,9 @@ def plot_importance(booster, ax=None, height=0.2,
|
|||
Target axes instance. If None, new figure and axes will be created.
|
||||
height : float
|
||||
Bar height, passed to ax.barh()
|
||||
xlim : tuple
|
||||
xlim : tuple of 2 elements
|
||||
Tuple passed to axes.xlim()
|
||||
ylim : tuple
|
||||
ylim : tuple of 2 elements
|
||||
Tuple passed to axes.ylim()
|
||||
title : str
|
||||
Axes title. Pass None to disable.
|
||||
|
@ -43,6 +50,8 @@ def plot_importance(booster, ax=None, height=0.2,
|
|||
If None or smaller than 1, all features will be displayed.
|
||||
ignore_zero : bool
|
||||
Ignore features with zero importance
|
||||
figsize : tuple of 2 elements
|
||||
Figure size
|
||||
grid : bool
|
||||
Whether add grid for axes
|
||||
**kwargs :
|
||||
|
@ -55,7 +64,7 @@ def plot_importance(booster, ax=None, height=0.2,
|
|||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError:
|
||||
raise ImportError('You must install matplotlib for plotting library')
|
||||
raise ImportError('You must install matplotlib to plot importance.')
|
||||
|
||||
if isinstance(booster, LGBMModel):
|
||||
importance = booster.booster_.feature_importance(importance_type=importance_type)
|
||||
|
@ -64,10 +73,10 @@ def plot_importance(booster, ax=None, height=0.2,
|
|||
elif is_numpy_1d_array(booster) or isinstance(booster, list):
|
||||
importance = booster
|
||||
else:
|
||||
raise ValueError('booster must be Booster or array instance')
|
||||
raise TypeError('booster must be Booster, LGBMModel or array instance.')
|
||||
|
||||
if not len(importance):
|
||||
raise ValueError('Booster feature_importances are empty')
|
||||
raise ValueError('Booster feature_importances are empty.')
|
||||
|
||||
tuples = sorted(enumerate(importance), key=lambda x: x[1])
|
||||
if ignore_zero:
|
||||
|
@ -77,7 +86,9 @@ def plot_importance(booster, ax=None, height=0.2,
|
|||
labels, values = zip(*tuples)
|
||||
|
||||
if ax is None:
|
||||
_, ax = plt.subplots(1, 1)
|
||||
if figsize is not None and check_not_tuple_of_2_elements(figsize):
|
||||
raise TypeError('figsize must be a tuple of 2 elements.')
|
||||
_, ax = plt.subplots(1, 1, figsize=figsize)
|
||||
|
||||
ylocs = np.arange(len(values))
|
||||
ax.barh(ylocs, values, align='center', height=height, **kwargs)
|
||||
|
@ -89,15 +100,15 @@ def plot_importance(booster, ax=None, height=0.2,
|
|||
ax.set_yticklabels(labels)
|
||||
|
||||
if xlim is not None:
|
||||
if not isinstance(xlim, tuple) or len(xlim) != 2:
|
||||
raise ValueError('xlim must be a tuple of 2 elements')
|
||||
if check_not_tuple_of_2_elements(xlim):
|
||||
raise TypeError('xlim must be a tuple of 2 elements.')
|
||||
else:
|
||||
xlim = (0, max(values) * 1.1)
|
||||
ax.set_xlim(xlim)
|
||||
|
||||
if ylim is not None:
|
||||
if not isinstance(ylim, tuple) or len(ylim) != 2:
|
||||
raise ValueError('ylim must be a tuple of 2 elements')
|
||||
if check_not_tuple_of_2_elements(ylim):
|
||||
raise TypeError('ylim must be a tuple of 2 elements.')
|
||||
else:
|
||||
ylim = (-1, len(values))
|
||||
ax.set_ylim(ylim)
|
||||
|
@ -110,3 +121,118 @@ def plot_importance(booster, ax=None, height=0.2,
|
|||
ax.set_ylabel(ylabel)
|
||||
ax.grid(grid)
|
||||
return ax
|
||||
|
||||
|
||||
def _to_graphviz(graph, tree_info, show_info, feature_names):
|
||||
"""Convert specified tree to graphviz instance."""
|
||||
|
||||
def add(root, parent=None, decision=None):
|
||||
"""recursively add node or edge"""
|
||||
if 'split_index' in root: # non-leaf
|
||||
name = 'split' + str(root['split_index'])
|
||||
if feature_names is not None:
|
||||
label = 'split_feature_name:' + str(feature_names[root['split_feature']])
|
||||
else:
|
||||
label = 'split_feature_index:' + str(root['split_feature'])
|
||||
label += '\nthreshold:' + str(root['threshold'])
|
||||
for info in show_info:
|
||||
if info in {'split_gain', 'internal_value', 'internal_count'}:
|
||||
label += '\n' + info + ':' + str(root[info])
|
||||
graph.node(name, label=label)
|
||||
if root['decision_type'] == 'no_greater':
|
||||
l_dec, r_dec = '<=', '>'
|
||||
elif root['decision_type'] == 'is':
|
||||
l_dec, r_dec = 'is', "isn't"
|
||||
else:
|
||||
raise ValueError('Invalid decision type in tree model.')
|
||||
add(root['left_child'], name, l_dec)
|
||||
add(root['right_child'], name, r_dec)
|
||||
else: # leaf
|
||||
name = 'left' + str(root['leaf_index'])
|
||||
label = 'leaf_value:' + str(root['leaf_value'])
|
||||
if 'leaf_count' in show_info:
|
||||
label += '\nleaf_count:' + str(root['leaf_count'])
|
||||
graph.node(name, label=label)
|
||||
if parent is not None:
|
||||
graph.edge(parent, name, decision)
|
||||
|
||||
add(tree_info['tree_structure'])
|
||||
return graph
|
||||
|
||||
|
||||
def plot_tree(booster, ax=None, tree_index=0, figsize=None,
|
||||
graph_attr=None, node_attr=None, edge_attr=None,
|
||||
show_info=None):
|
||||
"""Plot specified tree.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
booster : Booster, LGBMModel
|
||||
Booster or LGBMModel instance.
|
||||
ax : matplotlib Axes
|
||||
Target axes instance. If None, new figure and axes will be created.
|
||||
tree_index : int, default 0
|
||||
Specify tree index of target tree.
|
||||
figsize : tuple
|
||||
Figure size.
|
||||
graph_attr : dict
|
||||
Mapping of (attribute, value) pairs for the graph.
|
||||
node_attr : dict
|
||||
Mapping of (attribute, value) pairs set for all nodes.
|
||||
edge_attr : dict
|
||||
Mapping of (attribute, value) pairs set for all edges.
|
||||
show_info : list
|
||||
Information shows on nodes.
|
||||
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ax : matplotlib Axes
|
||||
"""
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.image as image
|
||||
except ImportError:
|
||||
raise ImportError('You must install matplotlib to plot tree.')
|
||||
|
||||
try:
|
||||
from graphviz import Digraph
|
||||
except ImportError:
|
||||
raise ImportError('You must install graphviz to plot tree.')
|
||||
|
||||
if ax is None:
|
||||
if figsize is not None and check_not_tuple_of_2_elements(figsize):
|
||||
raise TypeError('xlim must be a tuple of 2 elements.')
|
||||
_, ax = plt.subplots(1, 1, figsize=figsize)
|
||||
|
||||
if isinstance(booster, LGBMModel):
|
||||
booster = booster.booster_
|
||||
elif not isinstance(booster, Booster):
|
||||
raise TypeError('booster must be Booster or LGBMModel.')
|
||||
|
||||
model = booster.dump_model()
|
||||
tree_infos = model['tree_info']
|
||||
if 'feature_names' in model:
|
||||
feature_names = model['feature_names']
|
||||
else:
|
||||
feature_names = None
|
||||
|
||||
if tree_index < len(tree_infos):
|
||||
tree_info = tree_infos[tree_index]
|
||||
else:
|
||||
raise IndexError('tree_index is out of range.')
|
||||
|
||||
graph = Digraph(graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr)
|
||||
|
||||
if show_info is None:
|
||||
show_info = []
|
||||
ret = _to_graphviz(graph, tree_info, show_info, feature_names)
|
||||
|
||||
s = BytesIO()
|
||||
s.write(ret.pipe(format='png'))
|
||||
s.seek(0)
|
||||
img = image.imread(s)
|
||||
|
||||
ax.imshow(img)
|
||||
ax.axis('off')
|
||||
return ax
|
||||
|
|
|
@ -7,15 +7,16 @@ from sklearn.datasets import load_breast_cancer
|
|||
from sklearn.model_selection import train_test_split
|
||||
|
||||
try:
|
||||
from matplotlib.axes import Axes
|
||||
MATPLOTLIB_INSTALLED = True
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
matplotlib_installed = True
|
||||
except ImportError:
|
||||
MATPLOTLIB_INSTALLED = False
|
||||
matplotlib_installed = False
|
||||
|
||||
|
||||
class TestBasic(unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib not installed')
|
||||
@unittest.skipIf(not matplotlib_installed, 'matplotlib not installed')
|
||||
def test_plot_importance(self):
|
||||
X_train, _, y_train, _ = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
|
||||
train_data = lgb.Dataset(X_train, y_train)
|
||||
|
@ -27,7 +28,7 @@ class TestBasic(unittest.TestCase):
|
|||
}
|
||||
gbm0 = lgb.train(params, train_data, num_boost_round=10)
|
||||
ax0 = lgb.plot_importance(gbm0)
|
||||
self.assertIsInstance(ax0, Axes)
|
||||
self.assertIsInstance(ax0, matplotlib.axes.Axes)
|
||||
self.assertEqual(ax0.get_title(), 'Feature importance')
|
||||
self.assertEqual(ax0.get_xlabel(), 'Feature importance')
|
||||
self.assertEqual(ax0.get_ylabel(), 'Features')
|
||||
|
@ -37,7 +38,7 @@ class TestBasic(unittest.TestCase):
|
|||
gbm1.fit(X_train, y_train)
|
||||
|
||||
ax1 = lgb.plot_importance(gbm1, color='r', title='t', xlabel='x', ylabel='y')
|
||||
self.assertIsInstance(ax1, Axes)
|
||||
self.assertIsInstance(ax1, matplotlib.axes.Axes)
|
||||
self.assertEqual(ax1.get_title(), 't')
|
||||
self.assertEqual(ax1.get_xlabel(), 'x')
|
||||
self.assertEqual(ax1.get_ylabel(), 'y')
|
||||
|
@ -48,7 +49,7 @@ class TestBasic(unittest.TestCase):
|
|||
ax2 = lgb.plot_importance(gbm0.feature_importance(),
|
||||
color=['r', 'y', 'g', 'b'],
|
||||
title=None, xlabel=None, ylabel=None)
|
||||
self.assertIsInstance(ax2, Axes)
|
||||
self.assertIsInstance(ax2, matplotlib.axes.Axes)
|
||||
self.assertEqual(ax2.get_title(), '')
|
||||
self.assertEqual(ax2.get_xlabel(), '')
|
||||
self.assertEqual(ax2.get_ylabel(), '')
|
||||
|
@ -58,6 +59,10 @@ class TestBasic(unittest.TestCase):
|
|||
self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.)) # g
|
||||
self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.)) # b
|
||||
|
||||
@unittest.skip('Graphviz are not executables on Travis')
|
||||
def test_plot_tree(self):
|
||||
pass
|
||||
|
||||
|
||||
print("----------------------------------------------------------------------")
|
||||
print("running test_plotting.py")
|
||||
|
|
Загрузка…
Ссылка в новой задаче