Refactor gcm divergence unit tests

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
This commit is contained in:
Patrick Bloebaum 2022-09-19 10:18:36 -07:00 коммит произвёл Patrick Blöbaum
Родитель 93f1852b32
Коммит 3e42bacb43
1 изменённых файлов: 6 добавлений и 6 удалений

Просмотреть файл

@ -11,7 +11,7 @@ from dowhy.gcm.divergence import (
@flaky(max_runs=5)
def test_estimate_kl_divergence_continuous():
def test_given_simple_gaussian_data_when_estimate_kl_divergence_continuous_then_returns_expected_result():
X = np.random.normal(0, 1, 10000)
Y = np.random.normal(1, 1, 10000)
@ -20,7 +20,7 @@ def test_estimate_kl_divergence_continuous():
@flaky(max_runs=5)
def test_estimate_kl_divergence_categorical():
def test_given_simple_categorical_data_estimate_kl_divergence_categorical_then_returns_expected_result():
X = np.random.choice(4, 1000, replace=True, p=[0.25, 0.5, 0.125, 0.125]).astype(str)
Y = np.random.choice(4, 1000, replace=True, p=[0.5, 0.25, 0.125, 0.125]).astype(str)
@ -30,7 +30,7 @@ def test_estimate_kl_divergence_categorical():
)
def test_estimate_kl_divergence_of_probabilities():
def test_given_probability_vectors_when_estimate_kl_divergence_of_probabilities_then_returns_expected_result():
assert estimate_kl_divergence_of_probabilities(
np.array([[0.25, 0.5, 0.125, 0.125], [0.5, 0.25, 0.125, 0.125]]),
np.array([[0.5, 0.25, 0.125, 0.125], [0.25, 0.5, 0.125, 0.125]]),
@ -38,7 +38,7 @@ def test_estimate_kl_divergence_of_probabilities():
@flaky(max_runs=5)
def test_auto_estimate_kl_divergence_continuous():
def test_given_simple_gaussian_data_when_auto_estimate_kl_divergence_then_correctly_selects_continuous_version():
X = np.random.normal(0, 1, 10000)
Y = np.random.normal(1, 1, 10000)
@ -47,7 +47,7 @@ def test_auto_estimate_kl_divergence_continuous():
@flaky(max_runs=5)
def test_auto_estimate_kl_divergence_categorical():
def test_given_categorical_data_when_auto_estimate_kl_divergence_then_correctly_selects_categorical_version():
X = np.random.choice(4, 1000, replace=True, p=[0.25, 0.5, 0.125, 0.125]).astype(str)
Y = np.random.choice(4, 1000, replace=True, p=[0.5, 0.25, 0.125, 0.125]).astype(str)
@ -55,7 +55,7 @@ def test_auto_estimate_kl_divergence_categorical():
assert auto_estimate_kl_divergence(X, Y) == approx(0.25 * np.log(0.25 / 0.5) + 0.5 * np.log(0.5 / 0.25), abs=0.1)
def test_auto_estimate_kl_divergence_probabilities():
def test_given_probability_vectors_when_auto_estimate_kl_divergence_then_correctly_selects_probability_version():
assert auto_estimate_kl_divergence(
np.array([[0.25, 0.5, 0.125, 0.125], [0.5, 0.25, 0.125, 0.125]]),
np.array([[0.5, 0.25, 0.125, 0.125], [0.25, 0.5, 0.125, 0.125]]),