2022-06-01 21:46:37 +03:00
|
|
|
import networkx as nx
|
|
|
|
import numpy as np
|
|
|
|
import pandas as pd
|
|
|
|
from flaky import flaky
|
|
|
|
|
2022-08-20 01:51:25 +03:00
|
|
|
from dowhy.gcm import (
|
|
|
|
InvertibleStructuralCausalModel,
|
|
|
|
RejectionResult,
|
|
|
|
auto,
|
|
|
|
fit,
|
|
|
|
kernel_based,
|
|
|
|
refute_causal_structure,
|
|
|
|
refute_invertible_model,
|
|
|
|
)
|
2022-06-01 21:46:37 +03:00
|
|
|
|
|
|
|
|
|
|
|
def _generate_simple_non_linear_data() -> pd.DataFrame:
    """Sample 5000 observations from the non-linear chain X -> Y -> Z.

    X is drawn from a standard normal distribution, Y = X^2 + noise and
    Z = exp(-Y) + noise, where every noise term is an independent standard
    normal draw.

    Returns:
        A data frame with the columns "X", "Y" and "Z".
    """
    num_samples = 5000

    # Random draws happen in the order: X, noise of Y, noise of Z.
    x_values = np.random.normal(size=num_samples)
    y_noise = np.random.normal(size=num_samples)
    y_values = x_values**2 + y_noise
    z_noise = np.random.normal(size=num_samples)
    z_values = np.exp(-y_values) + z_noise

    return pd.DataFrame({"X": x_values, "Y": y_values, "Z": z_values})
|
|
|
|
|
|
|
|
|
|
|
|
@flaky(max_runs=5)
def test_given_collider_when_refuting_causal_structure_then_not_rejected():
    """Check that a correctly specified collider graph is not rejected.

    Draws 500 samples of independent Gaussian X and Y, builds
    Z = 2 * X + 3 * Y + noise, and refutes the graph X -> Z <- Y on that data.
    """
    # collider: X->Z<-Y
    collider_dag = nx.DiGraph([("X", "Z"), ("Y", "Z")])
    X = np.random.normal(size=500)
    Y = np.random.normal(size=500)
    Z = 2 * X + 3 * Y + np.random.normal(size=500)
    data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
    rejection_result, rejection_summary = refute_causal_structure(collider_dag, data)

    assert rejection_result == RejectionResult.NOT_REJECTED

    # X and Y have no incoming edges in the graph; empty summaries are expected for them.
    assert rejection_summary["X"]["local_markov_test"] == {}
    assert rejection_summary["X"]["edge_dependence_test"] == {}
    assert rejection_summary["Y"]["local_markov_test"] == {}
    assert rejection_summary["Y"]["edge_dependence_test"] == {}
    assert rejection_summary["Z"]["local_markov_test"] == {}

    # Plain truthiness checks instead of `== True` comparisons (PEP 8).
    assert rejection_summary["Z"]["edge_dependence_test"]["X"]["success"]
    assert rejection_summary["Z"]["edge_dependence_test"]["Y"]["success"]
|
2022-06-01 21:46:37 +03:00
|
|
|
|
|
|
|
|
|
|
|
@flaky(max_runs=5)
def test_given_chain_when_refuting_causal_structure_then_not_rejected():
    """Check that a correctly specified chain graph is not rejected.

    Draws 500 samples of Gaussian X, builds Z = 2 * X + noise and
    Y = 3 * Z + noise, and refutes the graph X -> Z -> Y on that data.
    """
    # chain: X->Z->Y
    chain_dag = nx.DiGraph([("X", "Z"), ("Z", "Y")])
    X = np.random.normal(size=500)
    Z = 2 * X + np.random.normal(size=500)
    Y = 3 * Z + np.random.normal(size=500)
    data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
    rejection_result, rejection_summary = refute_causal_structure(chain_dag, data)

    assert rejection_result == RejectionResult.NOT_REJECTED

    # X has no incoming edges in the graph; empty summaries are expected for it.
    assert rejection_summary["X"]["local_markov_test"] == {}
    assert rejection_summary["X"]["edge_dependence_test"] == {}
    assert rejection_summary["Z"]["local_markov_test"] == {}

    # Plain truthiness checks instead of `== True` comparisons (PEP 8).
    assert rejection_summary["Z"]["edge_dependence_test"]["X"]["success"]
    assert rejection_summary["Y"]["local_markov_test"]["success"]
    assert rejection_summary["Y"]["edge_dependence_test"]["Z"]["success"]
|
2022-06-01 21:46:37 +03:00
|
|
|
|
|
|
|
|
|
|
|
@flaky(max_runs=5)
def test_given_fork_when_refuting_causal_structure_then_not_rejected():
    """Check that a correctly specified fork graph is not rejected.

    Draws 500 samples of Gaussian Z, builds X = 2 * Z + noise and
    Y = 3 * Z + noise, and refutes the graph X <- Z -> Y on that data.
    """
    # fork: X<-Z->Y
    fork_dag = nx.DiGraph([("Z", "X"), ("Z", "Y")])
    Z = np.random.normal(size=500)
    X = 2 * Z + np.random.normal(size=500)
    Y = 3 * Z + np.random.normal(size=500)
    data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
    rejection_result, rejection_summary = refute_causal_structure(fork_dag, data)

    assert rejection_result == RejectionResult.NOT_REJECTED

    # Z has no incoming edges in the graph; empty summaries are expected for it.
    assert rejection_summary["Z"]["local_markov_test"] == {}
    assert rejection_summary["Z"]["edge_dependence_test"] == {}

    # Plain truthiness checks instead of `== True` comparisons (PEP 8).
    assert rejection_summary["X"]["local_markov_test"]["success"]
    assert rejection_summary["X"]["edge_dependence_test"]["Z"]["success"]
    assert rejection_summary["Y"]["local_markov_test"]["success"]
    assert rejection_summary["Y"]["edge_dependence_test"]["Z"]["success"]
|
2022-06-01 21:46:37 +03:00
|
|
|
|
|
|
|
|
|
|
|
@flaky(max_runs=5)
def test_given_general_dag_when_refuting_causal_structure_then_not_rejected():
    """Check that a correctly specified three-node DAG is not rejected.

    Draws 500 samples of Gaussian Z, builds X = 2 * Z + noise and
    Y = 2 * Z + 3 * X + noise, and refutes the graph X <- Z -> Y, X -> Y on
    that data.
    """
    # general DAG: X<-Z->Y, X->Y
    general_dag = nx.DiGraph([("Z", "X"), ("Z", "Y"), ("X", "Y")])
    Z = np.random.normal(size=500)
    X = 2 * Z + np.random.normal(size=500)
    Y = 2 * Z + 3 * X + np.random.normal(size=500)
    data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
    rejection_result, rejection_summary = refute_causal_structure(general_dag, data)

    assert rejection_result == RejectionResult.NOT_REJECTED

    # Z has no incoming edges in the graph; empty summaries are expected for it.
    assert rejection_summary["Z"]["local_markov_test"] == {}
    assert rejection_summary["Z"]["edge_dependence_test"] == {}

    # Plain truthiness checks instead of `== True` comparisons (PEP 8).
    assert rejection_summary["X"]["local_markov_test"] == {}
    assert rejection_summary["X"]["edge_dependence_test"]["Z"]["success"]
    assert rejection_summary["Y"]["local_markov_test"] == {}
    assert rejection_summary["Y"]["edge_dependence_test"]["Z"]["success"]
    assert rejection_summary["Y"]["edge_dependence_test"]["X"]["success"]
|
2022-06-01 21:46:37 +03:00
|
|
|
|
|
|
|
|
|
|
|
@flaky(max_runs=5)
def test_given_fdr_bh_when_refuting_causal_structure_then_return_correct_adjusted_p_values():
    """Check that "fdr_bh"-adjusted p-values are never below the raw p-values.

    The data follows the fork X <- Z -> Y with X = 2 * Z + noise and
    Y = 3 * Z + noise (500 Gaussian samples each).
    """
    num_samples = 500

    # fork: X<-Z->Y
    fork_dag = nx.DiGraph([("Z", "X"), ("Z", "Y")])
    z_values = np.random.normal(size=num_samples)
    x_values = 2 * z_values + np.random.normal(size=num_samples)
    y_values = 3 * z_values + np.random.normal(size=num_samples)
    samples = pd.DataFrame({"X": x_values, "Y": y_values, "Z": z_values})

    _, summary = refute_causal_structure(fork_dag, samples, fdr_control_method="fdr_bh")

    # Same comparison for both children of Z: local Markov test first, then the edge test.
    for node in ("X", "Y"):
        markov_result = summary[node]["local_markov_test"]
        assert markov_result["fdr_adjusted_p_value"] >= markov_result["p_value"]

        edge_result = summary[node]["edge_dependence_test"]["Z"]
        assert edge_result["fdr_adjusted_p_value"] >= edge_result["p_value"]
|
2022-06-01 21:46:37 +03:00
|
|
|
|
|
|
|
|
|
|
|
def test_when_using_refute_causal_structure_without_fdrc_then_nans_for_adjusted_p_values_are_returned():
    """Check the p-value fields when ``fdr_control_method=None`` is passed.

    For the fork X <- Z -> Y, every "fdr_adjusted_p_value" of X and Y is
    expected to be NaN while every raw "p_value" is expected to be a number.
    """
    num_samples = 500

    fork_dag = nx.DiGraph([("Z", "X"), ("Z", "Y")])
    z_values = np.random.normal(size=num_samples)
    x_values = 2 * z_values + np.random.normal(size=num_samples)
    y_values = 3 * z_values + np.random.normal(size=num_samples)
    samples = pd.DataFrame({"X": x_values, "Y": y_values, "Z": z_values})

    _, summary = refute_causal_structure(fork_dag, samples, fdr_control_method=None)

    def iter_results():
        # Lazily yields the results in the order: X local, X edge, Y local, Y edge.
        for node in ("X", "Y"):
            yield summary[node]["local_markov_test"]
            yield summary[node]["edge_dependence_test"]["Z"]

    for result in iter_results():
        assert np.isnan(result["fdr_adjusted_p_value"])

    for result in iter_results():
        assert not np.isnan(result["p_value"])
|
2022-06-01 21:46:37 +03:00
|
|
|
|
|
|
|
|
2023-07-18 17:06:56 +03:00
|
|
|
@flaky(max_runs=5)
def test_given_non_linear_data_and_correct_dag_when_refute_invertible_model_then_not_reject_model():
    """Check that a model fitted on the data-generating graph X -> Y -> Z is not rejected.

    The refutation runs twice, without and with the "fdr_bh" control method,
    each time using the kernel-based independence test with bootstrapping
    turned off.
    """

    def independence_test(x, y):
        return kernel_based(x, y, use_bootstrap=False)

    samples = _generate_simple_non_linear_data()

    scm = InvertibleStructuralCausalModel(nx.DiGraph([("X", "Y"), ("Y", "Z")]))  # X->Y->Z
    auto.assign_causal_mechanisms(scm, samples, auto.AssignmentQuality.GOOD)
    fit(scm, samples)

    plain_result = refute_invertible_model(scm, samples, independence_test=independence_test)
    assert plain_result == RejectionResult.NOT_REJECTED

    fdr_result = refute_invertible_model(
        scm,
        samples,
        independence_test=independence_test,
        fdr_control_method="fdr_bh",
    )
    assert fdr_result == RejectionResult.NOT_REJECTED
|
2022-06-01 21:46:37 +03:00
|
|
|
|
|
|
|
|
|
|
|
@flaky(max_runs=2)
def test_given_non_linear_data_and_incorrect_dag_when_refute_invertible_model_then_reject_model():
    """Check that a model fitted on the reversed chain X <- Y <- Z is rejected.

    The refutation runs twice, without and with the "fdr_bh" control method,
    each time using the kernel-based independence test with
    ``bootstrap_num_runs=5``.
    """

    def independence_test(x, y):
        return kernel_based(x, y, bootstrap_num_runs=5)

    samples = _generate_simple_non_linear_data()

    scm = InvertibleStructuralCausalModel(nx.DiGraph([("Z", "Y"), ("Y", "X")]))  # X<-Y<-Z
    auto.assign_causal_mechanisms(scm, samples, auto.AssignmentQuality.GOOD)
    fit(scm, samples)

    plain_result = refute_invertible_model(scm, samples, independence_test=independence_test)
    assert plain_result == RejectionResult.REJECTED

    fdr_result = refute_invertible_model(
        scm,
        samples,
        independence_test=independence_test,
        fdr_control_method="fdr_bh",
    )
    assert fdr_result == RejectionResult.REJECTED
|
2022-06-01 21:46:37 +03:00
|
|
|
|
|
|
|
|
|
|
|
@flaky(max_runs=3)
def test_given_non_linear_data_and_incorrect_dag_with_collider_when_refute_invertible_model_then_reject_model():
    """Check that a model fitted on the collider graph X -> Y <- Z is rejected.

    The refutation runs twice, without and with the "fdr_bh" control method,
    each time using the kernel-based independence test with
    ``bootstrap_num_runs=10``.
    """

    def independence_test(x, y):
        return kernel_based(x, y, bootstrap_num_runs=10)

    samples = _generate_simple_non_linear_data()

    scm = InvertibleStructuralCausalModel(nx.DiGraph([("X", "Y"), ("Z", "Y")]))  # X->Y<-Z
    auto.assign_causal_mechanisms(scm, samples, auto.AssignmentQuality.GOOD)
    fit(scm, samples)

    plain_result = refute_invertible_model(scm, samples, independence_test=independence_test)
    assert plain_result == RejectionResult.REJECTED

    fdr_result = refute_invertible_model(
        scm,
        samples,
        independence_test=independence_test,
        fdr_control_method="fdr_bh",
    )
    assert fdr_result == RejectionResult.REJECTED
|