python-package: add graphviz.Digraph parameters (#400)

* python-package: add graphviz.Digraph parameters

* examples: add a plottig example with graphviz

* fix tree index in print
This commit is contained in:
Tsukasa OMOTO 2017-04-12 22:36:51 +09:00 коммит произвёл Guolin Ke
Родитель 223b164e86
Коммит 053d9d8a43
3 изменённых файлов: 71 добавлений и 13 удалений

Просмотреть файл

@ -1046,24 +1046,45 @@ The methods of each Class is in alphabetical order.
-------
ax : matplotlib Axes
#### create_tree_digraph(booster, tree_index=0, graph_attr=None, node_attr=None, edge_attr=None, show_info=None):
#### create_tree_digraph(booster, tree_index=0, show_info=None, name=None, comment=None, filename=None, directory=None, format=None, engine=None, encoding=None, graph_attr=None, node_attr=None, edge_attr=None, body=None, strict=False):
Create a digraph of specified tree.
See:
- http://graphviz.readthedocs.io/en/stable/api.html#digraph
Parameters
----------
booster : Booster, LGBMModel
Booster or LGBMModel instance.
tree_index : int, default 0
Specify tree index of target tree.
show_info : list
Information shows on nodes.
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
name : str
Graph name used in the source code.
comment : str
Comment added to the first line of the source.
filename : str
Filename for saving the source (defaults to name + '.gv').
directory : str
(Sub)directory for source saving and rendering.
format : str
Rendering output format ('pdf', 'png', ...).
engine : str
Layout command used ('dot', 'neato', ...).
encoding : str
Encoding for saving the source.
graph_attr : dict
Mapping of (attribute, value) pairs for the graph.
node_attr : dict
Mapping of (attribute, value) pairs set for all nodes.
edge_attr : dict
Mapping of (attribute, value) pairs set for all edges.
show_info : list
Information shows on nodes.
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
body : list of str
Iterable of lines to add to the graph body.
strict : bool
Iterable of lines to add to the graph body.
Returns
-------

Просмотреть файл

@ -53,3 +53,7 @@ plt.show()
print('Plot 84th tree...') # one tree use categorical feature to split
ax = lgb.plot_tree(gbm, tree_index=83, figsize=(20, 8), show_info=['split_gain'])
plt.show()
print('Plot 84th tree with graphviz...')
graph = lgb.create_tree_digraph(gbm, tree_index=83, name='Tree84')
graph.render(view=True)

Просмотреть файл

@ -242,8 +242,14 @@ def plot_metric(booster, metric=None, dataset_names=None,
def _to_graphviz(tree_info, show_info, feature_names,
graph_attr=None, node_attr=None, edge_attr=None):
"""Convert specified tree to graphviz instance."""
name=None, comment=None, filename=None, directory=None,
format=None, engine=None, encoding=None, graph_attr=None,
node_attr=None, edge_attr=None, body=None, strict=False):
"""Convert specified tree to graphviz instance.
See:
- http://graphviz.readthedocs.io/en/stable/api.html#digraph
"""
try:
from graphviz import Digraph
except ImportError:
@ -279,31 +285,56 @@ def _to_graphviz(tree_info, show_info, feature_names,
if parent is not None:
graph.edge(parent, name, decision)
graph = Digraph(graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr)
graph = Digraph(name=name, comment=comment, filename=filename, directory=directory,
format=format, engine=engine, encoding=encoding, graph_attr=graph_attr,
node_attr=node_attr, edge_attr=edge_attr, body=body, strict=strict)
add(tree_info['tree_structure'])
return graph
def create_tree_digraph(booster, tree_index=0, graph_attr=None,
node_attr=None, edge_attr=None, show_info=None):
def create_tree_digraph(booster, tree_index=0, show_info=None,
name=None, comment=None, filename=None, directory=None,
format=None, engine=None, encoding=None, graph_attr=None,
node_attr=None, edge_attr=None, body=None, strict=False):
"""Create a digraph of specified tree.
See:
- http://graphviz.readthedocs.io/en/stable/api.html#digraph
Parameters
----------
booster : Booster, LGBMModel
Booster or LGBMModel instance.
tree_index : int, default 0
Specify tree index of target tree.
show_info : list
Information shows on nodes.
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
name : str
Graph name used in the source code.
comment : str
Comment added to the first line of the source.
filename : str
Filename for saving the source (defaults to name + '.gv').
directory : str
(Sub)directory for source saving and rendering.
format : str
Rendering output format ('pdf', 'png', ...).
engine : str
Layout command used ('dot', 'neato', ...).
encoding : str
Encoding for saving the source.
graph_attr : dict
Mapping of (attribute, value) pairs for the graph.
node_attr : dict
Mapping of (attribute, value) pairs set for all nodes.
edge_attr : dict
Mapping of (attribute, value) pairs set for all edges.
show_info : list
Information shows on nodes.
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
body : list of str
Iterable of lines to add to the graph body.
strict : bool
Iterable of lines to add to the graph body.
Returns
-------
@ -330,7 +361,9 @@ def create_tree_digraph(booster, tree_index=0, graph_attr=None,
show_info = []
graph = _to_graphviz(tree_info, show_info, feature_names,
graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr)
name=name, comment=comment, filename=filename, directory=directory,
format=format, engine=engine, encoding=encoding, graph_attr=graph_attr,
node_attr=node_attr, edge_attr=edge_attr, body=body, strict=strict)
return graph