зеркало из https://github.com/microsoft/LightGBM.git
python-package: add graphviz.Digraph parameters (#400)
* python-package: add graphviz.Digraph parameters * examples: add a plottig example with graphviz * fix tree index in print
This commit is contained in:
Родитель
223b164e86
Коммит
053d9d8a43
|
@ -1046,24 +1046,45 @@ The methods of each Class is in alphabetical order.
|
|||
-------
|
||||
ax : matplotlib Axes
|
||||
|
||||
#### create_tree_digraph(booster, tree_index=0, graph_attr=None, node_attr=None, edge_attr=None, show_info=None):
|
||||
#### create_tree_digraph(booster, tree_index=0, show_info=None, name=None, comment=None, filename=None, directory=None, format=None, engine=None, encoding=None, graph_attr=None, node_attr=None, edge_attr=None, body=None, strict=False):
|
||||
Create a digraph of specified tree.
|
||||
|
||||
See:
|
||||
- http://graphviz.readthedocs.io/en/stable/api.html#digraph
|
||||
|
||||
Parameters
|
||||
----------
|
||||
booster : Booster, LGBMModel
|
||||
Booster or LGBMModel instance.
|
||||
tree_index : int, default 0
|
||||
Specify tree index of target tree.
|
||||
show_info : list
|
||||
Information shows on nodes.
|
||||
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
|
||||
name : str
|
||||
Graph name used in the source code.
|
||||
comment : str
|
||||
Comment added to the first line of the source.
|
||||
filename : str
|
||||
Filename for saving the source (defaults to name + '.gv').
|
||||
directory : str
|
||||
(Sub)directory for source saving and rendering.
|
||||
format : str
|
||||
Rendering output format ('pdf', 'png', ...).
|
||||
engine : str
|
||||
Layout command used ('dot', 'neato', ...).
|
||||
encoding : str
|
||||
Encoding for saving the source.
|
||||
graph_attr : dict
|
||||
Mapping of (attribute, value) pairs for the graph.
|
||||
node_attr : dict
|
||||
Mapping of (attribute, value) pairs set for all nodes.
|
||||
edge_attr : dict
|
||||
Mapping of (attribute, value) pairs set for all edges.
|
||||
show_info : list
|
||||
Information shows on nodes.
|
||||
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
|
||||
body : list of str
|
||||
Iterable of lines to add to the graph body.
|
||||
strict : bool
|
||||
Iterable of lines to add to the graph body.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
|
@ -53,3 +53,7 @@ plt.show()
|
|||
print('Plot 84th tree...') # one tree use categorical feature to split
|
||||
ax = lgb.plot_tree(gbm, tree_index=83, figsize=(20, 8), show_info=['split_gain'])
|
||||
plt.show()
|
||||
|
||||
print('Plot 84th tree with graphviz...')
|
||||
graph = lgb.create_tree_digraph(gbm, tree_index=83, name='Tree84')
|
||||
graph.render(view=True)
|
||||
|
|
|
@ -242,8 +242,14 @@ def plot_metric(booster, metric=None, dataset_names=None,
|
|||
|
||||
|
||||
def _to_graphviz(tree_info, show_info, feature_names,
|
||||
graph_attr=None, node_attr=None, edge_attr=None):
|
||||
"""Convert specified tree to graphviz instance."""
|
||||
name=None, comment=None, filename=None, directory=None,
|
||||
format=None, engine=None, encoding=None, graph_attr=None,
|
||||
node_attr=None, edge_attr=None, body=None, strict=False):
|
||||
"""Convert specified tree to graphviz instance.
|
||||
|
||||
See:
|
||||
- http://graphviz.readthedocs.io/en/stable/api.html#digraph
|
||||
"""
|
||||
try:
|
||||
from graphviz import Digraph
|
||||
except ImportError:
|
||||
|
@ -279,31 +285,56 @@ def _to_graphviz(tree_info, show_info, feature_names,
|
|||
if parent is not None:
|
||||
graph.edge(parent, name, decision)
|
||||
|
||||
graph = Digraph(graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr)
|
||||
graph = Digraph(name=name, comment=comment, filename=filename, directory=directory,
|
||||
format=format, engine=engine, encoding=encoding, graph_attr=graph_attr,
|
||||
node_attr=node_attr, edge_attr=edge_attr, body=body, strict=strict)
|
||||
add(tree_info['tree_structure'])
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def create_tree_digraph(booster, tree_index=0, graph_attr=None,
|
||||
node_attr=None, edge_attr=None, show_info=None):
|
||||
def create_tree_digraph(booster, tree_index=0, show_info=None,
|
||||
name=None, comment=None, filename=None, directory=None,
|
||||
format=None, engine=None, encoding=None, graph_attr=None,
|
||||
node_attr=None, edge_attr=None, body=None, strict=False):
|
||||
"""Create a digraph of specified tree.
|
||||
|
||||
See:
|
||||
- http://graphviz.readthedocs.io/en/stable/api.html#digraph
|
||||
|
||||
Parameters
|
||||
----------
|
||||
booster : Booster, LGBMModel
|
||||
Booster or LGBMModel instance.
|
||||
tree_index : int, default 0
|
||||
Specify tree index of target tree.
|
||||
show_info : list
|
||||
Information shows on nodes.
|
||||
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
|
||||
name : str
|
||||
Graph name used in the source code.
|
||||
comment : str
|
||||
Comment added to the first line of the source.
|
||||
filename : str
|
||||
Filename for saving the source (defaults to name + '.gv').
|
||||
directory : str
|
||||
(Sub)directory for source saving and rendering.
|
||||
format : str
|
||||
Rendering output format ('pdf', 'png', ...).
|
||||
engine : str
|
||||
Layout command used ('dot', 'neato', ...).
|
||||
encoding : str
|
||||
Encoding for saving the source.
|
||||
graph_attr : dict
|
||||
Mapping of (attribute, value) pairs for the graph.
|
||||
node_attr : dict
|
||||
Mapping of (attribute, value) pairs set for all nodes.
|
||||
edge_attr : dict
|
||||
Mapping of (attribute, value) pairs set for all edges.
|
||||
show_info : list
|
||||
Information shows on nodes.
|
||||
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
|
||||
body : list of str
|
||||
Iterable of lines to add to the graph body.
|
||||
strict : bool
|
||||
Iterable of lines to add to the graph body.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -330,7 +361,9 @@ def create_tree_digraph(booster, tree_index=0, graph_attr=None,
|
|||
show_info = []
|
||||
|
||||
graph = _to_graphviz(tree_info, show_info, feature_names,
|
||||
graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr)
|
||||
name=name, comment=comment, filename=filename, directory=directory,
|
||||
format=format, engine=engine, encoding=encoding, graph_attr=graph_attr,
|
||||
node_attr=node_attr, edge_attr=edge_attr, body=body, strict=strict)
|
||||
|
||||
return graph
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче