Fix bug in networkx plot function with 0 error strenghts
Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
This commit is contained in:
Родитель
4f317345dc
Коммит
72986a859d
|
@ -32,7 +32,7 @@ def plot_causal_graph_networkx(
|
||||||
if (source, target) not in causal_strengths:
|
if (source, target) not in causal_strengths:
|
||||||
causal_strengths[(source, target)] = strength
|
causal_strengths[(source, target)] = strength
|
||||||
|
|
||||||
if strength is not None:
|
if causal_strengths[(source, target)] is not None:
|
||||||
max_strength = max(max_strength, abs(causal_strengths[(source, target)]))
|
max_strength = max(max_strength, abs(causal_strengths[(source, target)]))
|
||||||
|
|
||||||
if (source, target) not in colors:
|
if (source, target) not in colors:
|
||||||
|
|
|
@ -176,6 +176,11 @@ def bar_plot(
|
||||||
|
|
||||||
|
|
||||||
def _calc_arrow_width(strength: float, max_strength: float):
|
def _calc_arrow_width(strength: float, max_strength: float):
|
||||||
|
if max_strength == 0:
|
||||||
|
return 4.1
|
||||||
|
elif max_strength < 0:
|
||||||
|
raise ValueError("Got a negative strength! The strength needs to be positive.")
|
||||||
|
|
||||||
return 0.1 + 4.0 * float(abs(strength)) / float(max_strength)
|
return 0.1 + 4.0 * float(abs(strength)) / float(max_strength)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
from _pytest.python_api import approx
|
from _pytest.python_api import approx
|
||||||
|
|
||||||
from dowhy.utils import plot, plot_adjacency_matrix
|
from dowhy.utils import plot, plot_adjacency_matrix
|
||||||
|
from dowhy.utils.networkx_plotting import plot_causal_graph_networkx
|
||||||
from dowhy.utils.plotting import _calc_arrow_width, bar_plot
|
from dowhy.utils.plotting import _calc_arrow_width, bar_plot
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,6 +50,10 @@ def test_calc_arrow_width():
|
||||||
assert _calc_arrow_width(0.5, max_strength=0.5) == approx(4.1, abs=0.01)
|
assert _calc_arrow_width(0.5, max_strength=0.5) == approx(4.1, abs=0.01)
|
||||||
assert _calc_arrow_width(0.35, max_strength=0.5) == approx(2.9, abs=0.01)
|
assert _calc_arrow_width(0.35, max_strength=0.5) == approx(2.9, abs=0.01)
|
||||||
assert _calc_arrow_width(100, max_strength=101) == approx(4.06, abs=0.01)
|
assert _calc_arrow_width(100, max_strength=101) == approx(4.06, abs=0.01)
|
||||||
|
assert _calc_arrow_width(100, max_strength=0) == 4.1
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_calc_arrow_width(100, max_strength=-1)
|
||||||
|
|
||||||
|
|
||||||
def test_given_misspecified_uncertainties_when_bar_plot_then_does_not_raise_error():
|
def test_given_misspecified_uncertainties_when_bar_plot_then_does_not_raise_error():
|
||||||
|
|
Загрузка…
Ссылка в новой задаче