Fix bug in networkx plot function with 0 error strenghts
Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
This commit is contained in:
Родитель
4f317345dc
Коммит
72986a859d
|
@ -32,7 +32,7 @@ def plot_causal_graph_networkx(
|
|||
if (source, target) not in causal_strengths:
|
||||
causal_strengths[(source, target)] = strength
|
||||
|
||||
if strength is not None:
|
||||
if causal_strengths[(source, target)] is not None:
|
||||
max_strength = max(max_strength, abs(causal_strengths[(source, target)]))
|
||||
|
||||
if (source, target) not in colors:
|
||||
|
|
|
@ -176,6 +176,11 @@ def bar_plot(
|
|||
|
||||
|
||||
def _calc_arrow_width(strength: float, max_strength: float):
|
||||
if max_strength == 0:
|
||||
return 4.1
|
||||
elif max_strength < 0:
|
||||
raise ValueError("Got a negative strength! The strength needs to be positive.")
|
||||
|
||||
return 0.1 + 4.0 * float(abs(strength)) / float(max_strength)
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
import networkx as nx
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from _pytest.python_api import approx
|
||||
|
||||
from dowhy.utils import plot, plot_adjacency_matrix
|
||||
from dowhy.utils.networkx_plotting import plot_causal_graph_networkx
|
||||
from dowhy.utils.plotting import _calc_arrow_width, bar_plot
|
||||
|
||||
|
||||
|
@ -48,6 +50,10 @@ def test_calc_arrow_width():
|
|||
assert _calc_arrow_width(0.5, max_strength=0.5) == approx(4.1, abs=0.01)
|
||||
assert _calc_arrow_width(0.35, max_strength=0.5) == approx(2.9, abs=0.01)
|
||||
assert _calc_arrow_width(100, max_strength=101) == approx(4.06, abs=0.01)
|
||||
assert _calc_arrow_width(100, max_strength=0) == 4.1
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_calc_arrow_width(100, max_strength=-1)
|
||||
|
||||
|
||||
def test_given_misspecified_uncertainties_when_bar_plot_then_does_not_raise_error():
|
||||
|
|
Загрузка…
Ссылка в новой задаче