Revise plotting functions
There is now one common plotting function in dowhy.utils.plotting. Before, there were two different plotting functions: one in the gcm module and one on the causal_graph object. This change also improves the fallback networkx plot to look a bit fancier than before.

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
This commit is contained in:
Родитель
0b5e2c3efa
Коммит
cff72a84bf
0
docs/source/example_notebooks/dowhy-conditional-treatment-effects.ipynb
Executable file → Normal file
0
docs/source/example_notebooks/dowhy-conditional-treatment-effects.ipynb
Executable file → Normal file
|
@ -487,4 +487,4 @@
|
|||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
|
|
|
@ -414,4 +414,4 @@
|
|||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
|
|
|
@ -115,7 +115,9 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gcm.util.plot(causal_graph, figure_size=[20, 20])"
|
||||
"from dowhy.utils import plot\n",
|
||||
"\n",
|
||||
"plot(causal_graph, figure_size=[20, 20])"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -330,4 +332,4 @@
|
|||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
|
|
|
@ -74,11 +74,12 @@
|
|||
"source": [
|
||||
"import networkx as nx\n",
|
||||
"import dowhy.gcm as gcm\n",
|
||||
"from dowhy.utils import plot\n",
|
||||
"\n",
|
||||
"causal_model = gcm.InvertibleStructuralCausalModel(nx.DiGraph([('Treatment', 'Vision'), ('Condition', 'Vision')]))\n",
|
||||
"gcm.auto.assign_causal_mechanisms(causal_model, medical_data)\n",
|
||||
"\n",
|
||||
"gcm.util.plot(causal_model.graph)\n",
|
||||
"plot(causal_model.graph)\n",
|
||||
"\n",
|
||||
"gcm.fit(causal_model, medical_data)"
|
||||
]
|
||||
|
@ -109,11 +110,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In cases where we want to examine a hypothetical outcome if an event had not happened or if it had happened\n",
|
||||
"differently, we employ the so-called Counterfactual logic based on structural causal models. Given:\n",
|
||||
|
@ -129,11 +126,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"counterfactual_data1 = gcm.counterfactual_samples(causal_model,\n",
|
||||
|
@ -295,4 +288,4 @@
|
|||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,11 +39,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
|
@ -54,11 +50,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let us also take a look at the pair-wise scatter plots and histograms of the variables."
|
||||
]
|
||||
|
@ -66,11 +58,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"axes = pd.plotting.scatter_matrix(normal_data, figsize=(10, 10), c='#ff0d57', alpha=0.2, hist_kwds={'color':['#1E88E5']});\n",
|
||||
|
@ -82,22 +70,14 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In the matrix above, the plots on the diagonal line are histograms of variables, whereas those outside of the diagonal are scatter plots of pair of variables. The histograms of services without a dependency, namely `Customer DB`, `Product DB`, `Order DB` and `Shipping Cost Service`, have shapes similar to one half of a Gaussian distribution. The scatter plots of various pairs of variables (e.g., `API` and `www`, `www` and `Website`, `Order Service` and `Order DB`) show linear relations. We shall use this information shortly to assign generative causal models to nodes in the causal graph."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Setting up the causal graph\n",
|
||||
"\n",
|
||||
|
@ -110,15 +90,12 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import networkx as nx\n",
|
||||
"from dowhy import gcm\n",
|
||||
"from dowhy.utils import plot, bar_plot\n",
|
||||
"\n",
|
||||
"causal_graph = nx.DiGraph([('www', 'Website'),\n",
|
||||
" ('Auth Service', 'www'),\n",
|
||||
|
@ -141,16 +118,12 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gcm.util.plot(causal_graph, figure_size=[13, 13])"
|
||||
"plot(causal_graph, figure_size=[13, 13])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<div class=\"alert alert-block alert-info\">\n",
|
||||
"Here, we are interested in the causal relationships between latencies of services rather than the order of calling the services.\n",
|
||||
|
@ -159,11 +132,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We will use the information from the pair-wise scatter plots and histograms to manually assign causal models. In particular, we assign half-Normal distributions to the root nodes (i.e., `Customer DB`, `Product DB`, `Order DB` and `Shipping Cost Service`). For non-root nodes, we assign linear additive noise models (which scatter plots of many parent-child pairs indicate) with empirical distribution of noise terms."
|
||||
]
|
||||
|
@ -215,9 +184,6 @@
|
|||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
|
@ -240,9 +206,6 @@
|
|||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
|
@ -273,9 +236,6 @@
|
|||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
|
@ -293,11 +253,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<div class=\"alert alert-block alert-info\">\n",
|
||||
"By default, a quantile-based anomaly score is used that estimates the negative log-probability of a sample being\n",
|
||||
|
@ -313,14 +269,11 @@
|
|||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gcm.util.bar_plot(median_attribs, uncertainty_attribs, 'Attribution Score')"
|
||||
"bar_plot(median_attribs, uncertainty_attribs, 'Attribution Score')"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -332,11 +285,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Scenario 2: Observing permanent degradation of latencies\n",
|
||||
"\n",
|
||||
|
@ -348,11 +297,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"outlier_data = pd.read_csv(\"rca_microservice_architecture_anomaly_1000.csv\")\n",
|
||||
|
@ -361,11 +306,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We are interested in the increased latency of `Website` on average for 1000 requests which the customers directly experienced."
|
||||
]
|
||||
|
@ -373,11 +314,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"outlier_data['Website'].mean() - normal_data['Website'].mean()"
|
||||
|
@ -385,22 +322,14 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The _Website_ is slower on average (by almost 2 seconds) than usual. Why?"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Attributing permanent degradation of latencies at a target service to other services\n",
|
||||
"\n",
|
||||
|
@ -410,11 +339,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
|
@ -427,7 +352,7 @@
|
|||
" difference_estimation_func=lambda x, y: np.mean(y) - np.mean(x)),\n",
|
||||
" num_bootstrap_resamples = 10)\n",
|
||||
"\n",
|
||||
"gcm.util.bar_plot(median_attribs, uncertainty_attribs, 'Attribution Score')"
|
||||
"bar_plot(median_attribs, uncertainty_attribs, 'Attribution Score')"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -439,11 +364,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Scenario 3: Simulating the intervention of shifting resources\n",
|
||||
"\n",
|
||||
|
@ -457,11 +378,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"median_mean_latencies, uncertainty_mean_latencies = gcm.confidence_intervals(\n",
|
||||
|
@ -490,7 +407,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"avg_website_latency_before = outlier_data.mean().to_dict()['Website']\n",
|
||||
"gcm.util.bar_plot(dict(before=avg_website_latency_before, after=median_mean_latencies['Website']),\n",
|
||||
"bar_plot(dict(before=avg_website_latency_before, after=median_mean_latencies['Website']),\n",
|
||||
" dict(before=np.array([avg_website_latency_before, avg_website_latency_before]), after=uncertainty_mean_latencies['Website']),\n",
|
||||
" ylabel='Avg. Website Latency',\n",
|
||||
" figure_size=(3, 2),\n",
|
||||
|
@ -518,11 +435,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from scipy.stats import truncexpon, halfnorm\n",
|
||||
|
@ -595,11 +508,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The anomalous data is generated in the following way:"
|
||||
]
|
||||
|
@ -607,11 +516,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def unobserved_intrinsic_latencies_anomalous(num_samples):\n",
|
||||
|
@ -634,11 +539,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here, we significantly increased the average time of the *Caching Service* by two seconds, which coincides with our\n",
|
||||
"results from the RCA. Note that a high latency in *Caching Service* would lead to a constantly higher latency in upstream\n",
|
||||
|
@ -668,4 +569,4 @@
|
|||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,9 +40,6 @@
|
|||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
|
@ -61,9 +58,6 @@
|
|||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
|
@ -78,9 +72,6 @@
|
|||
"metadata": {
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
|
@ -94,9 +85,6 @@
|
|||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
|
@ -166,9 +154,6 @@
|
|||
"metadata": {
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
|
@ -184,21 +169,19 @@
|
|||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import networkx as nx\n",
|
||||
"import dowhy.gcm as gcm\n",
|
||||
"from dowhy.utils import plot\n",
|
||||
"\n",
|
||||
"causal_graph = nx.DiGraph([('demand', 'submitted'),\n",
|
||||
" ('constraint', 'submitted'),\n",
|
||||
" ('submitted', 'confirmed'), \n",
|
||||
" ('confirmed', 'received')])\n",
|
||||
"gcm.util.plot(causal_graph)"
|
||||
"plot(causal_graph)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -293,9 +276,6 @@
|
|||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
},
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
|
|
|
@ -6,6 +6,7 @@ import networkx as nx
|
|||
|
||||
from dowhy.utils.api import parse_state
|
||||
from dowhy.utils.graph_operations import daggity_to_dot
|
||||
from dowhy.utils.plotting import plot
|
||||
|
||||
|
||||
class CausalGraph:
|
||||
|
@ -97,40 +98,8 @@ class CausalGraph:
|
|||
# Adding node attributes
|
||||
self._graph = self.add_node_attributes(observed_node_names)
|
||||
|
||||
def view_graph(self, layout="dot", size=(8, 6), file_name="causal_model"):
|
||||
out_filename = "{}.png".format(file_name)
|
||||
try:
|
||||
import pygraphviz as pgv
|
||||
|
||||
agraph = nx.drawing.nx_agraph.to_agraph(self._graph)
|
||||
agraph.graph_attr.update(size="{},{}!".format(size[0], size[0]))
|
||||
agraph.draw(out_filename, format="png", prog=layout)
|
||||
except:
|
||||
self.logger.warning(
|
||||
"Warning: Pygraphviz cannot be loaded. Check that graphviz and pygraphviz are installed."
|
||||
)
|
||||
self.logger.info("Using Matplotlib for plotting")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.figure(figsize=size)
|
||||
solid_edges = [(n1, n2) for n1, n2, e in self._graph.edges(data=True) if "style" not in e]
|
||||
dashed_edges = [
|
||||
(n1, n2) for n1, n2, e in self._graph.edges(data=True) if ("style" in e and e["style"] == "dashed")
|
||||
]
|
||||
plt.clf()
|
||||
|
||||
pos = nx.layout.shell_layout(self._graph)
|
||||
nx.draw_networkx_nodes(self._graph, pos, node_color="yellow", node_size=400)
|
||||
nx.draw_networkx_edges(self._graph, pos, edgelist=solid_edges, arrowstyle="-|>", arrowsize=12)
|
||||
nx.draw_networkx_edges(
|
||||
self._graph, pos, edgelist=dashed_edges, arrowstyle="-|>", style="dashed", arrowsize=12
|
||||
)
|
||||
|
||||
labels = nx.draw_networkx_labels(self._graph, pos)
|
||||
|
||||
plt.axis("off")
|
||||
plt.savefig(out_filename)
|
||||
plt.draw()
|
||||
def view_graph(self, layout=None, size=None, file_name="causal_model"):
|
||||
plot(self._graph, layout_prog=layout, figure_size=size, filename=file_name + ".png")
|
||||
|
||||
def build_graph(self, common_cause_names, instrument_names, effect_modifier_names, mediator_names):
|
||||
"""Creates nodes and edges based on variable names and their semantics.
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
import logging
|
||||
from copy import deepcopy
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from matplotlib import pyplot
|
||||
from networkx.drawing import nx_pydot
|
||||
|
||||
from dowhy.utils import plotting
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -20,173 +19,33 @@ def plot(
|
|||
figure_size: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Convenience function to plot causal graphs. This function uses different backends based on what's
|
||||
available on the system. The best result is achieved when using Graphviz as the backend. This requires both
|
||||
the Python pygraphviz package (``pip install pygraphviz``) and the shared system library (e.g. ``brew install
|
||||
graphviz`` or ``apt-get install graphviz``). When graphviz is not available, it will fall back to the
|
||||
networkx backend.
|
||||
|
||||
:param causal_graph: The graph to be plotted
|
||||
:param causal_strengths: An optional dictionary with Edge -> float entries.
|
||||
:param colors: An optional dictionary with color specifications for edges or nodes.
|
||||
:param filename: An optional filename if the output should be plotted into a file.
|
||||
:param display_plot: Optionally specify if the plot should be displayed or not (default to True).
|
||||
:param figure_size: A tuple to define the width and height (as a tuple) of the pyplot. This is used to parameter to
|
||||
modify pyplot's 'figure.figsize' parameter. If None is given, the current/default value is used.
|
||||
:param kwargs: Remaining parameters will be passed through to the backend verbatim.
|
||||
|
||||
**Example usage**::
|
||||
|
||||
>>> plot(nx.DiGraph([('X', 'Y')])) # plots X -> Y
|
||||
>>> plot(nx.DiGraph([('X', 'Y')]), causal_strengths={('X', 'Y'): 0.43}) # annotates arrow with 0.43
|
||||
>>> plot(nx.DiGraph([('X', 'Y')]), colors={('X', 'Y'): 'red', 'X': 'green'}) # colors X -> Y red and X green
|
||||
"""
|
||||
try:
|
||||
from dowhy.gcm.util.pygraphviz import _plot_causal_graph_graphviz
|
||||
|
||||
try:
|
||||
_plot_causal_graph_graphviz(
|
||||
causal_graph,
|
||||
causal_strengths=causal_strengths,
|
||||
colors=colors,
|
||||
filename=filename,
|
||||
display_plot=display_plot,
|
||||
figure_size=figure_size,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as error:
|
||||
_logger.info(
|
||||
"There was an error when trying to plot the graph via graphviz, falling back to networkx "
|
||||
"plotting. If graphviz is not installed, consider installing it for better looking plots. The"
|
||||
" error is:" + str(error)
|
||||
)
|
||||
_plot_causal_graph_networkx(
|
||||
causal_graph,
|
||||
causal_strengths=causal_strengths,
|
||||
colors=colors,
|
||||
filename=filename,
|
||||
display_plot=display_plot,
|
||||
figure_size=figure_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
_logger.info(
|
||||
"Pygraphviz installation not found, falling back to networkx plotting. "
|
||||
"For better looking plots, consider installing pygraphviz. Note This requires both the Python "
|
||||
"pygraphviz package (``pip install pygraphviz``) and the shared system library (e.g. "
|
||||
"``brew install graphviz`` or ``apt-get install graphviz``)"
|
||||
)
|
||||
_plot_causal_graph_networkx(
|
||||
causal_graph,
|
||||
causal_strengths=causal_strengths,
|
||||
colors=colors,
|
||||
filename=filename,
|
||||
display_plot=display_plot,
|
||||
figure_size=figure_size,
|
||||
**kwargs,
|
||||
)
|
||||
"""Deprecated, please use dowhy.utils.plotting.plot() instead."""
|
||||
warnings.warn(
|
||||
"The plot method is deprecated. Use the plot function from dowhy.utils.plotting instead!", DeprecationWarning
|
||||
)
|
||||
plotting.plot(
|
||||
causal_graph=causal_graph,
|
||||
causal_strengths=causal_strengths,
|
||||
colors=colors,
|
||||
filename=filename,
|
||||
display_plot=display_plot,
|
||||
figure_size=figure_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def plot_adjacency_matrix(
|
||||
adjacency_matrix: pd.DataFrame, is_directed: bool, filename: Optional[str] = None, display_plot: bool = True
|
||||
) -> None:
|
||||
plot(
|
||||
nx.from_pandas_adjacency(adjacency_matrix, nx.DiGraph() if is_directed else nx.Graph()),
|
||||
display_plot=display_plot,
|
||||
filename=filename,
|
||||
"""Deprecated, please use dowhy.utils.plotting.plot_adjacency_matrix() instead."""
|
||||
warnings.warn(
|
||||
"The plot method is deprecated. Use the plot_adjacency_matrix function from dowhy.utils.plotting instead!",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
|
||||
def _plot_causal_graph_networkx(
|
||||
causal_graph: nx.Graph,
|
||||
pydot_layout_prog: Optional[str] = None,
|
||||
causal_strengths: Optional[Dict[Tuple[Any, Any], float]] = None,
|
||||
colors: Optional[Dict[Union[Any, Tuple[Any, Any]], str]] = None,
|
||||
filename: Optional[str] = None,
|
||||
display_plot: bool = True,
|
||||
label_wrap_length: int = 3,
|
||||
figure_size: Optional[List[int]] = None,
|
||||
) -> None:
|
||||
if "graph" not in causal_graph.graph:
|
||||
causal_graph.graph["graph"] = {"rankdir": "TD"}
|
||||
|
||||
if pydot_layout_prog is not None:
|
||||
layout = nx_pydot.pydot_layout(causal_graph, prog=pydot_layout_prog)
|
||||
else:
|
||||
layout = nx.spring_layout(causal_graph)
|
||||
|
||||
if causal_strengths is None:
|
||||
causal_strengths = {}
|
||||
else:
|
||||
causal_strengths = deepcopy(causal_strengths)
|
||||
if colors is None:
|
||||
colors = {}
|
||||
else:
|
||||
colors = deepcopy(colors)
|
||||
|
||||
max_strength = 0.0
|
||||
for (source, target, strength) in causal_graph.edges(data="CAUSAL_STRENGTH", default=1):
|
||||
if (source, target) not in causal_strengths:
|
||||
causal_strengths[(source, target)] = strength
|
||||
max_strength = max(max_strength, abs(causal_strengths[(source, target)]))
|
||||
if (source, target) not in colors:
|
||||
colors[(source, target)] = "black"
|
||||
|
||||
for edge in causal_graph.edges:
|
||||
if edge[0] == edge[1]:
|
||||
raise ValueError(
|
||||
"Node %s has a self-cycle, i.e. a node pointing to itself. Plotting self-cycles is "
|
||||
"currently only supported for plots using Graphviz! Consider installing the corresponding"
|
||||
"requirements." % edge[0]
|
||||
)
|
||||
|
||||
# Wrapping labels if they are too long
|
||||
labels = {}
|
||||
for node in causal_graph.nodes:
|
||||
if node not in colors:
|
||||
colors[node] = "lightblue"
|
||||
node_name_splits = str(node).split(" ")
|
||||
for i in range(1, len(node_name_splits)):
|
||||
if len(node_name_splits[i - 1]) > label_wrap_length:
|
||||
node_name_splits[i] = "\n" + node_name_splits[i]
|
||||
else:
|
||||
node_name_splits[i] = " " + node_name_splits[i]
|
||||
|
||||
labels[node] = "".join(node_name_splits)
|
||||
|
||||
if figure_size is not None:
|
||||
org_fig_size = pyplot.rcParams["figure.figsize"]
|
||||
pyplot.rcParams["figure.figsize"] = figure_size
|
||||
|
||||
figure = pyplot.figure()
|
||||
|
||||
nx.draw(
|
||||
causal_graph,
|
||||
pos=layout,
|
||||
node_color=[colors[node] for node in causal_graph.nodes()],
|
||||
edge_color=[colors[(s, t)] for (s, t) in causal_graph.edges()],
|
||||
linewidths=0.25,
|
||||
labels=labels,
|
||||
font_size=8,
|
||||
font_weight="bold",
|
||||
node_size=2000,
|
||||
width=[_calc_arrow_width(causal_strengths[(s, t)], max_strength) for (s, t) in causal_graph.edges()],
|
||||
plotting.plot_adjacency_matrix(
|
||||
adjacency_matrix=adjacency_matrix, is_directed=is_directed, filename=filename, display_plot=display_plot
|
||||
)
|
||||
|
||||
if display_plot:
|
||||
pyplot.show()
|
||||
|
||||
if figure_size is not None:
|
||||
pyplot.rcParams["figure.figsize"] = org_fig_size
|
||||
|
||||
if filename is not None:
|
||||
figure.savefig(filename)
|
||||
|
||||
|
||||
def _calc_arrow_width(strength: float, max_strength: float):
|
||||
return 0.2 + 4.0 * float(abs(strength)) / float(max_strength)
|
||||
|
||||
|
||||
def bar_plot(
|
||||
values: Dict[str, float],
|
||||
|
@ -200,43 +59,20 @@ def bar_plot(
|
|||
xticks_rotation: int = 90,
|
||||
sort_names: bool = True,
|
||||
) -> None:
|
||||
"""Convenience function to make a bar plot of the given values with uncertainty bars, if provided. Useful for all
|
||||
kinds of attribution results (including confidence intervals).
|
||||
|
||||
:param values: A dictionary where the keys are the labels and the values are the values to be plotted.
|
||||
:param uncertainties: A dictionary of attributes to be added to the error bars.
|
||||
:param ylabel: The label for the y-axis.
|
||||
:param filename: An optional filename if the output should be plotted into a file.
|
||||
:param display_plot: Optionally specify if the plot should be displayed or not (default to True).
|
||||
:param figure_size: The size of the figure to be plotted.
|
||||
:param bar_width: The width of the bars.
|
||||
:param xticks: Explicitly specify the labels for the bars on the x-axis.
|
||||
:param xticks_rotation: Specify the rotation of the labels on the x-axis.
|
||||
:param sort_names: If True, the names in the plot are sorted alphabetically. If False, the order as given in values
|
||||
are used.
|
||||
"""
|
||||
if sort_names:
|
||||
values = {k: values[k] for k in sorted(values)}
|
||||
|
||||
if uncertainties is None:
|
||||
uncertainties = {node: [values[node], values[node]] for node in values}
|
||||
|
||||
figure, ax = pyplot.subplots(figsize=figure_size)
|
||||
ci_plus = [uncertainties[node][1] - values[node] for node in values.keys()]
|
||||
ci_minus = [values[node] - uncertainties[node][0] for node in values.keys()]
|
||||
yerr = np.array([ci_minus, ci_plus])
|
||||
yerr[abs(yerr) < 10**-7] = 0
|
||||
pyplot.bar(values.keys(), values.values(), yerr=yerr, ecolor="#1E88E5", color="#ff0d57", width=bar_width)
|
||||
pyplot.ylabel(ylabel)
|
||||
pyplot.xticks(rotation=xticks_rotation)
|
||||
|
||||
ax.spines["right"].set_visible(False)
|
||||
ax.spines["top"].set_visible(False)
|
||||
if xticks:
|
||||
pyplot.xticks(list(uncertainties.keys()), xticks)
|
||||
|
||||
if display_plot:
|
||||
pyplot.show()
|
||||
|
||||
if filename is not None:
|
||||
figure.savefig(filename)
|
||||
"""Deprecated, please use dowhy.utils.plotting.bar_plot() instead."""
|
||||
warnings.warn(
|
||||
"The plot method is deprecated. Use the bar_plot function from dowhy.utils.plotting instead!",
|
||||
DeprecationWarning,
|
||||
)
|
||||
plotting.bar_plot(
|
||||
values=values,
|
||||
uncertainties=uncertainties,
|
||||
ylabel=ylabel,
|
||||
filename=filename,
|
||||
display_plot=display_plot,
|
||||
figure_size=figure_size,
|
||||
bar_width=bar_width,
|
||||
xticks=xticks,
|
||||
xticks_rotation=xticks_rotation,
|
||||
sort_names=sort_names,
|
||||
)
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
from .plotting import bar_plot, plot, plot_adjacency_matrix
|
|
@ -1,16 +1,15 @@
|
|||
import os
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import pygraphviz
|
||||
from matplotlib import image, pyplot
|
||||
|
||||
|
||||
def _plot_causal_graph_graphviz(
|
||||
def plot_causal_graph_graphviz(
|
||||
causal_graph: nx.Graph,
|
||||
layout_prog: Optional[str] = None,
|
||||
display_causal_strengths: bool = True,
|
||||
causal_strengths: Optional[Dict[Tuple[Any, Any], float]] = None,
|
||||
colors: Optional[Dict[Union[Any, Tuple[Any, Any]], str]] = None,
|
||||
|
@ -27,6 +26,9 @@ def _plot_causal_graph_graphviz(
|
|||
else:
|
||||
colors = deepcopy(colors)
|
||||
|
||||
if layout_prog is None:
|
||||
layout_prog = "dot"
|
||||
|
||||
max_strength = 0.0
|
||||
for (source, target, strength) in causal_graph.edges(data="CAUSAL_STRENGTH", default=None):
|
||||
if (source, target) not in causal_strengths:
|
||||
|
@ -54,6 +56,8 @@ def _plot_causal_graph_graphviz(
|
|||
else:
|
||||
tmp_label = str(" %s" % str(int(causal_strength * 100) / 100))
|
||||
|
||||
from dowhy.utils.plotting import _calc_arrow_width
|
||||
|
||||
pygraphviz_graph.add_edge(
|
||||
str(source),
|
||||
str(target),
|
||||
|
@ -64,7 +68,7 @@ def _plot_causal_graph_graphviz(
|
|||
else:
|
||||
pygraphviz_graph.add_edge(str(source), str(target), color=color)
|
||||
|
||||
pygraphviz_graph.layout(prog="dot")
|
||||
pygraphviz_graph.layout(prog=layout_prog)
|
||||
if filename is not None:
|
||||
filename, file_extension = os.path.splitext(filename)
|
||||
if file_extension == "":
|
||||
|
@ -72,25 +76,6 @@ def _plot_causal_graph_graphviz(
|
|||
pygraphviz_graph.draw(filename + file_extension)
|
||||
|
||||
if display_plot:
|
||||
from dowhy.utils.plotting import _plot_as_pyplot_figure
|
||||
|
||||
_plot_as_pyplot_figure(pygraphviz_graph, figure_size)
|
||||
|
||||
|
||||
def _calc_arrow_width(strength: float, max_strength: float):
|
||||
return 0.1 + 4.0 * float(abs(strength)) / float(max_strength)
|
||||
|
||||
|
||||
def _plot_as_pyplot_figure(pygraphviz_graph: pygraphviz.AGraph, figure_size: Optional[Tuple[int, int]] = None) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pygraphviz_graph.draw(tmp_dir_name + os.sep + "Graph.png")
|
||||
img = image.imread(tmp_dir_name + os.sep + "Graph.png")
|
||||
|
||||
if figure_size is not None:
|
||||
org_fig_size = pyplot.rcParams["figure.figsize"]
|
||||
pyplot.rcParams["figure.figsize"] = figure_size
|
||||
|
||||
pyplot.imshow(img)
|
||||
pyplot.axis("off")
|
||||
pyplot.show()
|
||||
|
||||
if figure_size is not None:
|
||||
pyplot.rcParams["figure.figsize"] = org_fig_size
|
|
@ -0,0 +1,272 @@
|
|||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from networkx.drawing import nx_pydot
|
||||
|
||||
|
||||
def plot_causal_graph_networkx(
    causal_graph: nx.Graph,
    layout_prog: Optional[str] = None,
    causal_strengths: Optional[Dict[Tuple[Any, Any], float]] = None,
    colors: Optional[Dict[Union[Any, Tuple[Any, Any]], str]] = None,
    filename: Optional[str] = None,
    display_plot: bool = True,
    label_wrap_length: int = 3,
    figure_size: Optional[Tuple[int, int]] = None,
) -> None:
    """Plots the given graph with networkx and matplotlib.

    :param causal_graph: The graph to be plotted.
    :param layout_prog: If given, node positions are computed by ``nx_pydot.pydot_layout`` with this program and
                        the graph is drawn with ``nx.draw``. If None, the custom layered layout is used.
    :param causal_strengths: An optional dictionary with Edge -> float entries. Edges without an entry fall back
                             to their 'CAUSAL_STRENGTH' edge attribute (or None if there is none). The edge width
                             scales with the strength relative to the largest absolute strength.
    :param colors: An optional dictionary with color specifications for edges or nodes. Missing edges default to
                   'gray' and missing nodes to 'skyblue'.
    :param filename: An optional filename if the output should be plotted into a file.
    :param display_plot: Optionally specify if the plot should be displayed or not (default to True).
    :param label_wrap_length: A space-separated word of a node name that follows a word longer than this many
                              characters is moved to a new line. The wrapped labels are only used when
                              layout_prog is given.
    :param figure_size: Optional size of the pyplot figure.
    :raises ValueError: If the graph contains a self-cycle.
    """
    # Work on copies, the dictionaries of the caller must not be modified.
    causal_strengths = {} if causal_strengths is None else deepcopy(causal_strengths)
    colors = {} if colors is None else deepcopy(colors)

    max_strength = 0.0
    for source, target, strength in causal_graph.edges(data="CAUSAL_STRENGTH", default=None):
        # An explicitly given strength takes precedence over the edge attribute.
        if (source, target) not in causal_strengths:
            causal_strengths[(source, target)] = strength

        # The maximum has to be based on the strength that is actually used for the edge. Checking only the edge
        # attribute would ignore explicitly given strengths and leave max_strength at 0.
        if causal_strengths[(source, target)] is not None:
            max_strength = max(max_strength, abs(causal_strengths[(source, target)]))

        if (source, target) not in colors:
            colors[(source, target)] = "gray"

    for edge in causal_graph.edges:
        if edge[0] == edge[1]:
            raise ValueError(
                "Node %s has a self-cycle, i.e. a node pointing to itself. Plotting self-cycles is "
                "currently only supported for plots using Graphviz! Consider installing the corresponding "
                "requirements." % edge[0]
            )

    # Wrapping labels if they are too long
    labels = {}
    for node in causal_graph.nodes:
        if node not in colors:
            colors[node] = "skyblue"

        node_name_splits = str(node).split(" ")
        for i in range(1, len(node_name_splits)):
            if len(node_name_splits[i - 1]) > label_wrap_length:
                node_name_splits[i] = "\n" + node_name_splits[i]
            else:
                node_name_splits[i] = " " + node_name_splits[i]

        labels[node] = "".join(node_name_splits)

    # NOTE(review): imported inside the function, presumably to avoid a circular import - confirm.
    from dowhy.utils.plotting import _calc_arrow_width

    # If all given strengths are zero, there is nothing to scale against. Normalizing with 1 then gives every such
    # edge the minimal width instead of dividing by zero.
    normalizer = max_strength if max_strength > 0 else 1.0
    edge_widths = {
        (s, t): 2 if causal_strengths[(s, t)] is None else _calc_arrow_width(causal_strengths[(s, t)], normalizer)
        for (s, t) in causal_graph.edges()
    }

    if layout_prog is not None:
        layout = nx_pydot.pydot_layout(causal_graph, prog=layout_prog)
        # figsize=None makes pyplot use its default figure size.
        figure = plt.figure(figsize=figure_size)

        nx.draw(
            causal_graph,
            pos=layout,
            node_color=[colors[node] for node in causal_graph.nodes()],
            edge_color=[colors[(s, t)] for (s, t) in causal_graph.edges()],
            labels=labels,
            font_weight="bold",
            node_size=2000,
            arrowsize=20,
            alpha=0.8,
            width=[edge_widths[(s, t)] for (s, t) in causal_graph.edges()],
        )
    else:
        figure = _draw_graph_with_custom_layout(causal_graph, colors, edge_widths, figure_size)

    plt.gca().set_axis_off()
    if display_plot:
        plt.show()

    if filename is not None:
        figure.savefig(filename)
|
||||
|
||||
|
||||
def _draw_graph_with_custom_layout(
    graph: nx.Graph,
    colors: Dict[Any, str],
    edge_widths: Dict[Tuple[Any, Any], float],
    figure_size: Optional[List[int]] = None,
):
    """Draws the graph in a layered left-to-right layout and returns the created pyplot figure.

    This layout tries to mimic the graphviz layout in a simpler form. The depth grows horizontally here instead of
    vertically.

    :param graph: The graph to draw. It is copied first, so the 'layer' node attribute is only set on the copy.
    :param colors: Colors for all nodes and all edges of the graph.
    :param edge_widths: Line widths for all edges of the graph.
    :param figure_size: Optional figure size. If None, the size is derived from the number of layers and the size
                        of the largest layer.
    :return: The pyplot figure the graph was drawn into.
    """
    graph = nx.DiGraph(graph) if isinstance(graph, nx.DiGraph) else nx.Graph(graph)

    node_layers = _custom_assign_layers(graph)
    nx.set_node_attributes(graph, node_layers, "layer")
    positions = nx.multipartite_layout(graph, subset_key="layer")

    if figure_size is None:
        # Set the figure size based on the number of nodes
        figure_width = max(np.max(list(node_layers.values())) * 2, 5)
        figure_height = max(_custom_count_nodes_vertically(graph) * 1.25, 5)
        figure_size = (figure_width, figure_height)

    figure = plt.figure(figsize=figure_size)

    nx.draw_networkx_nodes(
        graph,
        positions,
        node_size=2000,
        node_color=[colors[node] for node in graph.nodes()],
        alpha=0.8,
    )

    neighbor_indicator = _custom_create_vertical_neighbor_indicator(graph, positions)

    # Nodes that are in the same layer, but not direct vertical neighbors, are connected via a curved edge. All
    # other edges are drawn as straight lines.
    curved_edges = []
    straight_edges = []
    for source, target in graph.edges():
        in_same_layer = graph.nodes[source]["layer"] == graph.nodes[target]["layer"]
        if in_same_layer and not neighbor_indicator.loc[str(source), str(target)]:
            curved_edges.append((source, target))
        else:
            straight_edges.append((source, target))

    shared_edge_style = dict(arrowsize=20, alpha=0.8, min_source_margin=25, min_target_margin=21)

    nx.draw_networkx_edges(
        graph,
        positions,
        edgelist=curved_edges,
        width=[edge_widths[edge] for edge in curved_edges],
        edge_color=[colors[edge] for edge in curved_edges],
        connectionstyle="arc3,rad=0.5",
        **shared_edge_style,
    )
    nx.draw_networkx_edges(
        graph,
        positions,
        edgelist=straight_edges,
        width=[edge_widths[edge] for edge in straight_edges],
        edge_color=[colors[edge] for edge in straight_edges],
        **shared_edge_style,
    )

    # Draw the node labels centered on the node positions.
    for node, coordinates in positions.items():
        plt.text(coordinates[0], coordinates[1], node, ha="center", va="center", color="black", fontweight="bold")

    return figure
|
||||
|
||||
|
||||
def _custom_assign_layers(graph):
    """Assigns a depth (layer) to each node, based on the distance to the closest root node.

    For directed graphs, the root nodes are all nodes without incoming edges. For undirected graphs, the first node
    of each connected component is taken as root node. A node that cannot be reached from any root node gets the
    layer ``float("inf")``.

    :param graph: The graph whose nodes get a layer assigned.
    :return: A dictionary mapping each node to its layer.
    """
    if isinstance(graph, nx.DiGraph):
        components = [graph]
        roots = [node for node, in_degree in graph.in_degree() if in_degree == 0]
    else:
        components = [graph.subgraph(c) for c in nx.connected_components(graph)]
        # In case of undirected graphs, we just take any node as root node.
        roots = [list(component.nodes)[0] for component in components]

    def distance_from(root, node):
        try:
            return nx.shortest_path_length(graph, root, node)
        except nx.NetworkXNoPath:
            # No path to root node, this root then does not contribute to the minimum.
            return float("inf")

    return {
        node: min((distance_from(root, node) for root in roots), default=float("inf"))
        for component in components
        for node in component.nodes
    }
|
||||
|
||||
|
||||
def _custom_count_nodes_vertically(graph):
    """Returns the number of nodes in the most populated layer, i.e. the maximum number of vertically stacked nodes.

    :param graph: A graph whose nodes have a 'layer' attribute.
    :return: The size of the largest layer, or 0 if the graph has no nodes.
    """
    layer_count = {}
    for node in graph.nodes:
        layer = graph.nodes[node]["layer"]
        layer_count[layer] = layer_count.get(layer, 0) + 1

    # Only the counts are relevant here. Taking the maximum over the (layer, count) pairs would mix the layer
    # indices into the result.
    return max(layer_count.values(), default=0)
|
||||
|
||||
|
||||
def _custom_create_vertical_neighbor_indicator(graph, pos):
    """Creates a boolean matrix indicating whether two nodes are vertical neighbors.

    Two different nodes are vertical neighbors if they are in the same layer and directly follow each other when
    the nodes of that layer are sorted by their y-coordinate. Each node is marked as a neighbor of itself.

    :param graph: A graph whose nodes have a 'layer' attribute.
    :param pos: A dictionary mapping each node to its position, where index 1 is the y-coordinate.
    :return: A DataFrame with the string representations of the nodes as index and columns.
    """
    node_list = list(graph.nodes)
    node_names = [str(node) for node in node_list]
    indicator = pd.DataFrame(
        np.zeros((len(node_list), len(node_list))).astype(bool),
        index=node_names,
        columns=node_names,
    )

    # Collect the nodes with their y-coordinates per layer.
    entries_per_layer = {}
    for node in graph.nodes:
        entries_per_layer.setdefault(graph.nodes[node]["layer"], []).append((node, pos[node][1]))

    # The rank of a node is its index within its layer after sorting by the y-coordinate.
    rank_in_layer = {}
    for entries in entries_per_layer.values():
        entries.sort(key=lambda entry: entry[1])
        for rank, entry in enumerate(entries):
            if entry[0] in rank_in_layer:
                raise RuntimeError("Something went wrong when creating the layer y-coordinate map.")
            rank_in_layer[entry[0]] = rank

    for first in node_list:
        for second in node_list:
            if first == second:
                are_neighbors = True
            elif graph.nodes[first]["layer"] != graph.nodes[second]["layer"]:
                are_neighbors = False
            else:
                # Direct neighbors have ranks that differ by exactly one.
                are_neighbors = abs(rank_in_layer[first] - rank_in_layer[second]) == 1

            indicator.loc[str(first), str(second)] = are_neighbors

    return indicator
|
|
@ -0,0 +1,189 @@
|
|||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pygraphviz
|
||||
from matplotlib import image
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def plot(
    causal_graph: nx.Graph,
    layout_prog: Optional[str] = None,
    causal_strengths: Optional[Dict[Tuple[Any, Any], float]] = None,
    colors: Optional[Dict[Union[Any, Tuple[Any, Any]], str]] = None,
    filename: Optional[str] = None,
    display_plot: bool = True,
    figure_size: Optional[Tuple[int, int]] = None,
    **kwargs,
) -> None:
    """Convenience function to plot causal graphs. This function uses different backends based on what's
    available on the system. The best result is achieved when using Graphviz as the backend. This requires both
    the shared system library (e.g. ``brew install graphviz`` or ``apt-get install graphviz``) and the Python pygraphviz
    package (``pip install pygraphviz``). When graphviz is not available, it will fall back to the networkx backend.
    The networkx backend is also used if plotting with graphviz raises any exception.

    :param causal_graph: The graph to be plotted
    :param layout_prog: Defines the layout type. If None is given, the 'dot' layout is used for graphviz plots and a
                        customized layout for networkx plots.
    :param causal_strengths: An optional dictionary with Edge -> float entries.
    :param colors: An optional dictionary with color specifications for edges or nodes.
    :param filename: An optional filename if the output should be plotted into a file.
    :param display_plot: Optionally specify if the plot should be displayed or not (default to True).
    :param figure_size: A tuple to define the width and height of the pyplot figure. If None is given, the
                        current/default value is used.
    :param kwargs: Remaining parameters will be passed through to the backend verbatim.

    **Example usage**::

    >>> plot(nx.DiGraph([('X', 'Y')])) # plots X -> Y
    >>> plot(nx.DiGraph([('X', 'Y')]), causal_strengths={('X', 'Y'): 0.43}) # annotates arrow with 0.43
    >>> plot(nx.DiGraph([('X', 'Y')]), colors={('X', 'Y'): 'red', 'X': 'green'}) # colors X -> Y red and X green
    """
    backend_arguments = dict(
        layout_prog=layout_prog,
        causal_strengths=causal_strengths,
        colors=colors,
        filename=filename,
        display_plot=display_plot,
        figure_size=figure_size,
        **kwargs,
    )

    # NOTE(review): this module imports pygraphviz at module level, so importing this module already fails without
    # pygraphviz and the ImportError fallback below may never be reached - confirm.
    try:
        from dowhy.utils.graphviz_plotting import plot_causal_graph_graphviz
    except ImportError:
        _logger.info(
            "Pygraphviz installation not found, falling back to networkx plotting. "
            "For better looking plots, consider installing pygraphviz. Note: This requires both the Python "
            "pygraphviz package (``pip install pygraphviz``) and the shared system library (e.g. "
            "``brew install graphviz`` or ``apt-get install graphviz``)"
        )
        _plot_with_networkx_backend(causal_graph, backend_arguments)
        return

    try:
        plot_causal_graph_graphviz(causal_graph, **backend_arguments)
    except Exception as error:
        # Deliberately broad: whatever goes wrong in the graphviz backend, we still want to produce a plot.
        _logger.info(
            "There was an error when trying to plot the graph via graphviz, falling back to networkx "
            "plotting. If graphviz is not installed, consider installing it for better looking plots. The"
            " error is: %s",
            error,
        )
        _plot_with_networkx_backend(causal_graph, backend_arguments)


def _plot_with_networkx_backend(causal_graph: nx.Graph, backend_arguments: Dict[str, Any]) -> None:
    """Plots the graph with the networkx backend, forwarding the given keyword arguments verbatim."""
    from dowhy.utils.networkx_plotting import plot_causal_graph_networkx

    plot_causal_graph_networkx(causal_graph, **backend_arguments)
|
||||
|
||||
|
||||
def plot_adjacency_matrix(
    adjacency_matrix: pd.DataFrame, is_directed: bool, filename: Optional[str] = None, display_plot: bool = True
) -> None:
    """Converts the given adjacency matrix into a networkx graph and plots it via :func:`plot`.

    :param adjacency_matrix: The adjacency matrix of the graph as a pandas DataFrame.
    :param is_directed: If True, the matrix is converted into a directed graph, otherwise into an undirected one.
    :param filename: An optional filename if the output should be plotted into a file.
    :param display_plot: Optionally specify if the plot should be displayed or not (default to True).
    """
    empty_graph = nx.DiGraph() if is_directed else nx.Graph()
    causal_graph = nx.from_pandas_adjacency(adjacency_matrix, empty_graph)

    plot(causal_graph, display_plot=display_plot, filename=filename)
|
||||
|
||||
|
||||
def bar_plot(
    values: Dict[str, float],
    uncertainties: Optional[Dict[str, Tuple[float, float]]] = None,
    ylabel: str = "",
    filename: Optional[str] = None,
    display_plot: bool = True,
    figure_size: Optional[List[int]] = None,
    bar_width: float = 0.8,
    xticks: Optional[List[str]] = None,
    xticks_rotation: int = 90,
    sort_names: bool = False,
) -> None:
    """Convenience function to make a bar plot of the given values with uncertainty bars, if provided. Useful for all
    kinds of attribution results (including confidence intervals).

    :param values: A dictionary where the keys are the labels and the values are the values to be plotted.
    :param uncertainties: A dictionary mapping each label in values to a (lower, upper) pair. The error bar of a
                          label spans from lower to upper. If None, no error bars are visible.
    :param ylabel: The label for the y-axis.
    :param filename: An optional filename if the output should be plotted into a file.
    :param display_plot: Optionally specify if the plot should be displayed or not (default to True).
    :param figure_size: The size of the figure to be plotted.
    :param bar_width: The width of the bars.
    :param xticks: Explicitly specify the labels for the bars on the x-axis, in the order of the bars.
    :param xticks_rotation: Specify the rotation of the labels on the x-axis.
    :param sort_names: If True, the names in the plot are sorted alphabetically (the given xticks are then sorted as
                       well). If False, the order as given in values are used.
    """
    if sort_names:
        values = {k: values[k] for k in sorted(values)}

        if xticks is not None:
            xticks = sorted(xticks)

    if uncertainties is None:
        # Lower and upper bound equal to the value itself result in error bars of length zero.
        uncertainties = {node: [values[node], values[node]] for node in values}

    names = list(values)
    heights = [values[node] for node in names]

    figure, ax = plt.subplots(figsize=figure_size)
    # pyplot expects the error bars as distances below and above the bar height, not as absolute bounds.
    ci_plus = [uncertainties[node][1] - values[node] for node in names]
    ci_minus = [values[node] - uncertainties[node][0] for node in names]
    yerr = np.array([ci_minus, ci_plus])
    yerr[abs(yerr) < 10**-7] = 0
    plt.bar(names, heights, yerr=yerr, ecolor="#1E88E5", color="#ff0d57", width=bar_width)
    plt.ylabel(ylabel)
    plt.xticks(rotation=xticks_rotation)

    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    if xticks:
        # The custom labels belong to the bars, i.e. to the keys of values. The keys of uncertainties can be in a
        # different order, which would attach the labels to the wrong bars.
        plt.xticks(names, xticks)

    if display_plot:
        plt.show()

    if filename is not None:
        figure.savefig(filename)
|
||||
|
||||
|
||||
def _calc_arrow_width(strength: float, max_strength: float):
    """Calculates the width of an arrow based on its strength relative to the maximum strength.

    :param strength: The strength of the edge. Only the absolute value is used.
    :param max_strength: The strength to normalize with, typically the largest absolute strength in the graph.
    :return: A width between 0.1 (zero strength) and 4.1 (strength equal to max_strength). If max_strength is 0,
             the minimal width 0.1 is returned.
    """
    if max_strength == 0:
        # Nothing to scale against (e.g. all strengths are zero), avoid the division by zero.
        return 0.1

    return 0.1 + 4.0 * float(abs(strength)) / float(max_strength)
|
||||
|
||||
|
||||
def _plot_as_pyplot_figure(pygraphviz_graph: pygraphviz.AGraph, figure_size: Optional[Tuple[int, int]] = None) -> None:
    """Renders the given pygraphviz graph into a temporary PNG file and shows that image with pyplot.

    :param pygraphviz_graph: The pygraphviz graph to be displayed.
    :param figure_size: Optional value for pyplot's 'figure.figsize' parameter while the image is plotted. The
                        previous value is restored afterwards.
    """
    with tempfile.TemporaryDirectory() as tmp_dir_name:
        image_path = os.path.join(tmp_dir_name, "Graph.png")
        pygraphviz_graph.draw(image_path)
        img = image.imread(image_path)

    org_fig_size = plt.rcParams["figure.figsize"]
    if figure_size is not None:
        plt.rcParams["figure.figsize"] = figure_size

    try:
        plt.imshow(img)
        plt.axis("off")
        plt.show()
    finally:
        # Restore the global setting even if plotting fails, otherwise the figure size would leak into later plots.
        if figure_size is not None:
            plt.rcParams["figure.figsize"] = org_fig_size
|
|
@ -1,12 +0,0 @@
|
|||
from pytest import approx
|
||||
|
||||
from dowhy.gcm.util.pygraphviz import _calc_arrow_width
|
||||
|
||||
|
||||
def test_calc_arrow_width():
    # (strength, max_strength, expected width)
    expectations = [
        (0.4, 0.5, 3.3),
        (0.2, 0.5, 1.7),
        (-0.2, 0.5, 1.7),
        (0.5, 0.5, 4.1),
        (0.35, 0.5, 2.9),
        (100, 101, 4.06),
    ]

    for strength, max_strength, expected_width in expectations:
        assert _calc_arrow_width(strength, max_strength=max_strength) == approx(expected_width, abs=0.01)
|
|
@ -1,8 +1,10 @@
|
|||
import networkx as nx
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from _pytest.python_api import approx
|
||||
|
||||
from dowhy.gcm.util import plot, plot_adjacency_matrix
|
||||
from dowhy.utils import plot, plot_adjacency_matrix
|
||||
from dowhy.utils.plotting import _calc_arrow_width
|
||||
|
||||
|
||||
def test_when_plot_does_not_raise_exception():
|
||||
|
@ -37,3 +39,12 @@ def test_given_colors_when_plot_graph_then_does_not_modify_input_object():
|
|||
plot(nx.DiGraph([("X", "Y"), ("Y", "Z")]), colors=colors)
|
||||
|
||||
assert colors == {("X", "Y"): "red", "X": "blue"}
|
||||
|
||||
|
||||
def test_calc_arrow_width():
    # Maps (strength, max_strength) to the expected arrow width.
    expected_widths = {
        (0.4, 0.5): 3.3,
        (0.2, 0.5): 1.7,
        (-0.2, 0.5): 1.7,
        (0.5, 0.5): 4.1,
        (0.35, 0.5): 2.9,
        (100, 101): 4.06,
    }

    for (strength, max_strength), expected in expected_widths.items():
        assert _calc_arrow_width(strength, max_strength=max_strength) == approx(expected, abs=0.01)
|
Загрузка…
Ссылка в новой задаче