Fix bug in networkx plot function with 0 error strenghts

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
This commit is contained in:
Patrick Bloebaum 2023-12-08 12:00:35 -08:00 коммит произвёл Patrick Blöbaum
Родитель 4f317345dc
Коммит 72986a859d
3 изменённых файлов: 12 добавлений и 1 удалений

Просмотреть файл

@ -32,7 +32,7 @@ def plot_causal_graph_networkx(
if (source, target) not in causal_strengths: if (source, target) not in causal_strengths:
causal_strengths[(source, target)] = strength causal_strengths[(source, target)] = strength
if strength is not None: if causal_strengths[(source, target)] is not None:
max_strength = max(max_strength, abs(causal_strengths[(source, target)])) max_strength = max(max_strength, abs(causal_strengths[(source, target)]))
if (source, target) not in colors: if (source, target) not in colors:

Просмотреть файл

@ -176,6 +176,11 @@ def bar_plot(
def _calc_arrow_width(strength: float, max_strength: float): def _calc_arrow_width(strength: float, max_strength: float):
if max_strength == 0:
return 4.1
elif max_strength < 0:
raise ValueError("Got a negative strength! The strength needs to be positive.")
return 0.1 + 4.0 * float(abs(strength)) / float(max_strength) return 0.1 + 4.0 * float(abs(strength)) / float(max_strength)

Просмотреть файл

@ -1,9 +1,11 @@
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import pytest
from _pytest.python_api import approx from _pytest.python_api import approx
from dowhy.utils import plot, plot_adjacency_matrix from dowhy.utils import plot, plot_adjacency_matrix
from dowhy.utils.networkx_plotting import plot_causal_graph_networkx
from dowhy.utils.plotting import _calc_arrow_width, bar_plot from dowhy.utils.plotting import _calc_arrow_width, bar_plot
@ -48,6 +50,10 @@ def test_calc_arrow_width():
assert _calc_arrow_width(0.5, max_strength=0.5) == approx(4.1, abs=0.01) assert _calc_arrow_width(0.5, max_strength=0.5) == approx(4.1, abs=0.01)
assert _calc_arrow_width(0.35, max_strength=0.5) == approx(2.9, abs=0.01) assert _calc_arrow_width(0.35, max_strength=0.5) == approx(2.9, abs=0.01)
assert _calc_arrow_width(100, max_strength=101) == approx(4.06, abs=0.01) assert _calc_arrow_width(100, max_strength=101) == approx(4.06, abs=0.01)
assert _calc_arrow_width(100, max_strength=0) == 4.1
with pytest.raises(ValueError):
_calc_arrow_width(100, max_strength=-1)
def test_given_misspecified_uncertainties_when_bar_plot_then_does_not_raise_error(): def test_given_misspecified_uncertainties_when_bar_plot_then_does_not_raise_error():