validate_graph -> run_validations, use inspect for `data` arg in validate_tpa. Restrict `falsify_graph` to those tests presented in the paper
Signed-off-by: eeulig <contact@eeulig.com>
This commit is contained in:
Родитель
379f188d67
Коммит
4e30c63fe1
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -6,6 +6,7 @@ import warnings
|
|||
from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
from functools import partial
|
||||
from inspect import getfullargspec
|
||||
from itertools import permutations
|
||||
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Set, Tuple, Union
|
||||
|
||||
|
@ -50,7 +51,7 @@ class FalsifyConst(Enum):
|
|||
FALSIFY_METHODS = {
|
||||
FalsifyConst.VALIDATE_LMC: "LMC",
|
||||
FalsifyConst.VALIDATE_PD: "Faithfulness",
|
||||
FalsifyConst.VALIDATE_TPA: "tPa",
|
||||
FalsifyConst.VALIDATE_TPA: "TPa",
|
||||
FalsifyConst.VALIDATE_CM: "Causal Minimality",
|
||||
}
|
||||
|
||||
|
@ -179,7 +180,6 @@ def validate_tpa(
|
|||
causal_graph: DirectedGraph,
|
||||
causal_graph_reference: DirectedGraph,
|
||||
include_unconditional: bool = True,
|
||||
data: Optional[pd.DataFrame] = None,
|
||||
) -> Dict[str, int]:
|
||||
"""
|
||||
Graphical criterion to evaluate which pairwise parental d-separations (parental triples) in `causal_graph` are
|
||||
|
@ -191,7 +191,6 @@ def validate_tpa(
|
|||
:param causal_graph: Causal graph for which to evaluate parental d-separations (G')
|
||||
:param causal_graph_reference: Causal graph where we test if d-separation holds (G)
|
||||
:param include_unconditional: Test also unconditional independencies of root nodes.
|
||||
:param data: IGNORED! No data is needed for this validation method and thus the data argument is ignored!
|
||||
:return: Validation summary with number of d-separations implied by `causal_graph` and number of times these are
|
||||
violated in the graph `causal_graph_reference`.
|
||||
"""
|
||||
|
@ -380,7 +379,7 @@ def validate_cm(
|
|||
return validation_summary
|
||||
|
||||
|
||||
def validate_graph(
|
||||
def run_validations(
|
||||
causal_graph: DirectedGraph,
|
||||
data: pd.DataFrame,
|
||||
methods: Optional[Union[Callable, Tuple[Callable, ...], List[Callable]]] = partial(
|
||||
|
@ -391,7 +390,7 @@ def validate_graph(
|
|||
Validate a given causal graph using observational data and some given methods. If methods are provided, they must
|
||||
be wrapped in a partial object, with their respective parameters. E.g., if one wants to test the local
|
||||
Markov conditions and the pairwise dependencies (unconditional faithfulness), then call
|
||||
validate_graph(G, data, methods=(
|
||||
run_validations(G, data, methods=(
|
||||
partial(validate_lmc, independence_test=..., conditional_independence_test=...),
|
||||
partial(validate_pd, independence_test=...),
|
||||
)
|
||||
|
@ -402,7 +401,7 @@ def validate_graph(
|
|||
:param methods: Method functions wrapped in wrap_partial. E.g.
|
||||
wrap_partial(validate_lmc, data=data, independence_test=..., conditional_independence_test=...).
|
||||
If no methods are provided we run validate_lmc with optional keyword arguments provided to
|
||||
validate_graph.
|
||||
run_validations.
|
||||
:return: Validation summary as dict.
|
||||
"""
|
||||
|
||||
|
@ -412,7 +411,9 @@ def validate_graph(
|
|||
validation_summary = dict()
|
||||
|
||||
for m in methods:
|
||||
m_summary = m(causal_graph=causal_graph, data=data)
|
||||
if "data" in getfullargspec(m).args:
|
||||
m = partial(m, data=data)
|
||||
m_summary = m(causal_graph=causal_graph)
|
||||
m_name = m_summary.pop(FalsifyConst.METHOD)
|
||||
validation_summary[m_name] = m_summary
|
||||
|
||||
|
@ -510,8 +511,6 @@ class EvaluationResult:
|
|||
def falsify_graph(
|
||||
causal_graph: DirectedGraph,
|
||||
data: pd.DataFrame,
|
||||
methods: Optional[Union[Callable, Tuple[Callable, ...]]] = None,
|
||||
suggestion_methods: Optional[Union[Callable, Tuple[Callable, ...]]] = None,
|
||||
suggestions: bool = False,
|
||||
independence_test: Callable[[np.ndarray, np.ndarray], float] = kernel_based,
|
||||
conditional_independence_test: Callable[[np.ndarray, np.ndarray, np.ndarray], float] = kernel_based,
|
||||
|
@ -543,22 +542,22 @@ def falsify_graph(
|
|||
|
||||
By default, we only run 1 / `significance_level` permutations as those are enough to falsify a graph with type I
|
||||
error probability `significance_level` at some given `significance_level`. If you are interested in a more exact
|
||||
estimate of the p-value of whish to plot a histogram to see how the given DAG compares to random node permutations,
|
||||
estimate of the p-value or wish to plot a histogram to see how the given DAG compares to random node permutations,
|
||||
you should set `n_permutations` to some larger value (e.g. 100 or 1000). If `n_permutations=-1` we test on all
|
||||
n_nodes! permutations.
|
||||
n_nodes! permutations (the default if plot_histogram=True).
|
||||
|
||||
`methods` and `suggestion_methods` must be wrapped in partial(method, **kwargs) (c.f. `validate_graph`).
|
||||
Additionally, this method allows to return suggestions to the user (suggestions=True). This is done by testing for
|
||||
violations of causal minimality via `validate_cm`.
|
||||
|
||||
Related paper:
|
||||
Eulig, E., Mastakouri, A. A., Blöbaum, P., Hardt, M., & Janzing, D. (2023).
|
||||
Toward Falsifying Causal Graphs Using a Permutation-Based Test.
|
||||
https://arxiv.org/abs/2305.09565
|
||||
Eulig, E., Mastakouri, A. A., Blöbaum, P., Hardt, M., & Janzing, D. (2023).
|
||||
Toward Falsifying Causal Graphs Using a Permutation-Based Test.
|
||||
https://arxiv.org/abs/2305.09565
|
||||
|
||||
:param causal_graph: A directed acyclic graph (DAG).
|
||||
:param data: Observations of variables in the DAG.
|
||||
:param methods: Validation methods to perform.
|
||||
:param suggestion_methods: Methods to run on the given graph to provide additional suggestions.
|
||||
:param suggestions: Provide suggestions generated using the `suggestion_methods`.
|
||||
:param suggestions: Provide suggestions to the user. At the moment the only source of suggestions comes from
|
||||
validating causal minimality (using validate_cm).
|
||||
:param independence_test: Independence test to use for checking pairwise independencies.
|
||||
:param conditional_independence_test: Conditional independence test to use.
|
||||
:param significance_level: Significance level for the permutation test.
|
||||
|
@ -580,28 +579,19 @@ def falsify_graph(
|
|||
if not plot_kwargs:
|
||||
plot_kwargs = {}
|
||||
|
||||
# If no methods are provided, use default ones: validate_lmc, validate_tpa
|
||||
if not methods:
|
||||
methods = (
|
||||
partial(
|
||||
validate_lmc,
|
||||
independence_test=independence_test,
|
||||
conditional_independence_test=conditional_independence_test,
|
||||
significance_level=significance_ci,
|
||||
p_values_memory=p_values_memory,
|
||||
n_jobs=n_jobs,
|
||||
),
|
||||
partial(validate_tpa, causal_graph_reference=causal_graph),
|
||||
)
|
||||
elif isinstance(methods, Callable):
|
||||
methods = (methods,)
|
||||
methods = (
|
||||
partial(
|
||||
validate_lmc,
|
||||
independence_test=independence_test,
|
||||
conditional_independence_test=conditional_independence_test,
|
||||
significance_level=significance_ci,
|
||||
p_values_memory=p_values_memory,
|
||||
n_jobs=n_jobs,
|
||||
),
|
||||
partial(validate_tpa, causal_graph_reference=causal_graph),
|
||||
)
|
||||
|
||||
# If no suggestion methods are provided, but suggestions=True, use default ones: validate_cm
|
||||
if not suggestions:
|
||||
suggestion_methods = tuple()
|
||||
elif suggestions and isinstance(suggestion_methods, Callable):
|
||||
suggestion_methods = (suggestion_methods,)
|
||||
elif suggestions:
|
||||
if suggestions:
|
||||
suggestion_methods = (
|
||||
partial(
|
||||
validate_cm,
|
||||
|
@ -612,8 +602,10 @@ def falsify_graph(
|
|||
n_jobs=n_jobs,
|
||||
),
|
||||
)
|
||||
else:
|
||||
suggestion_methods = tuple()
|
||||
|
||||
summary_given = validate_graph(
|
||||
summary_given = run_validations(
|
||||
causal_graph,
|
||||
data,
|
||||
methods=methods + suggestion_methods,
|
||||
|
@ -902,7 +894,7 @@ def _permutation_based(
|
|||
perm_gen = _PermuteNodes(causal_graph, n_permutations=n_permutations, exclude_original_order=exclude_original_order)
|
||||
validation_summary = {FalsifyConst.PERM_GRAPHS: []}
|
||||
for permuted_graph in tqdm(perm_gen, desc="Test permutations of given graph", disable=not show_progress_bar):
|
||||
res = validate_graph(
|
||||
res = run_validations(
|
||||
causal_graph=permuted_graph,
|
||||
data=data,
|
||||
methods=methods,
|
||||
|
|
|
@ -11,8 +11,8 @@ from dowhy.datasets import generate_random_graph
|
|||
from dowhy.gcm.falsify import (
|
||||
FalsifyConst,
|
||||
_PermuteNodes,
|
||||
run_validations,
|
||||
validate_cm,
|
||||
validate_graph,
|
||||
validate_lmc,
|
||||
validate_pd,
|
||||
validate_tpa,
|
||||
|
@ -39,7 +39,7 @@ def _generate_simple_non_linear_data() -> pd.DataFrame:
|
|||
|
||||
@flaky(max_runs=1)
|
||||
def test_given_exclude_original_order_when_generating_permutations_then_return_correct_permutations():
|
||||
num_nodes = np.random.randint(1, 10)
|
||||
num_nodes = np.random.randint(2, 10)
|
||||
perms = set()
|
||||
G = generate_random_graph(n=num_nodes)
|
||||
perm_gen = _PermuteNodes(G, exclude_original_order=True, n_permutations=-1)
|
||||
|
@ -52,7 +52,7 @@ def test_given_exclude_original_order_when_generating_permutations_then_return_c
|
|||
|
||||
@flaky(max_runs=1)
|
||||
def test_given_not_exclude_original_order_when_generating_permutations_then_return_correct_permutations():
|
||||
num_nodes = np.random.randint(1, 10)
|
||||
num_nodes = np.random.randint(2, 10)
|
||||
found_orig_perm = False
|
||||
perms = set()
|
||||
G = generate_random_graph(n=num_nodes)
|
||||
|
@ -74,7 +74,7 @@ def test_given_correct_collider_when_validating_graph_then_report_no_violations(
|
|||
Y = np.random.normal(size=500)
|
||||
Z = 2 * X + 3 * Y + np.random.normal(size=500)
|
||||
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
|
||||
summary = validate_graph(
|
||||
summary = run_validations(
|
||||
true_dag,
|
||||
data,
|
||||
methods=(
|
||||
|
@ -101,7 +101,7 @@ def test_given_wrong_collider_when_validating_graph_then_report_violations():
|
|||
Y = np.random.normal(size=500)
|
||||
Z = 2 * X + 3 * Y + np.random.normal(size=500)
|
||||
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
|
||||
summary = validate_graph(
|
||||
summary = run_validations(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
|
@ -127,7 +127,7 @@ def test_given_correct_chain_when_validating_graph_then_report_no_violations():
|
|||
Y = 0.6 * X + np.random.normal(size=500)
|
||||
Z = Y + np.random.normal(size=500)
|
||||
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
|
||||
summary = validate_graph(
|
||||
summary = run_validations(
|
||||
true_dag,
|
||||
data,
|
||||
methods=(
|
||||
|
@ -154,7 +154,7 @@ def test_given_wrong_chain_when_validating_graph_then_report_violations():
|
|||
Y = 0.6 * X + np.random.normal(size=500)
|
||||
Z = Y + np.random.normal(size=500)
|
||||
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
|
||||
summary = validate_graph(
|
||||
summary = run_validations(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
|
@ -177,7 +177,7 @@ def test_given_empty_DAG_and_data_when_validating_graph_then_report_no_violation
|
|||
# Empty graph
|
||||
dag = nx.DiGraph()
|
||||
data = pd.DataFrame()
|
||||
summary = validate_graph(
|
||||
summary = run_validations(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
|
@ -203,7 +203,7 @@ def test_given_correct_full_DAG_when_validating_graph_then_report_no_violations(
|
|||
X2 = 1.2 * X0 + 0.7 * X1 + np.random.normal(size=500)
|
||||
X3 = 0.6 * X0 + 0.8 * X1 + 1.3 * X2 + np.random.normal(size=500)
|
||||
data = pd.DataFrame(data=dict(X0=X0, X1=X1, X2=X2, X3=X3))
|
||||
summary = validate_graph(
|
||||
summary = run_validations(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
|
@ -227,7 +227,7 @@ def test_given_correct_single_node_when_validating_graph_then_report_no_violatio
|
|||
dag = nx.DiGraph()
|
||||
dag.add_node("X")
|
||||
data = pd.DataFrame(data=dict(X=np.random.normal(size=500)))
|
||||
summary = validate_graph(
|
||||
summary = run_validations(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
|
@ -252,7 +252,7 @@ def test_given_correct_single_edge_when_validating_graph_then_report_no_violatio
|
|||
X0 = np.random.normal(size=500)
|
||||
X1 = 2 * X0 + np.random.normal(size=500)
|
||||
data = pd.DataFrame(data=dict(X0=X0, X1=X1))
|
||||
summary = validate_graph(
|
||||
summary = run_validations(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
|
@ -280,7 +280,7 @@ def test_given_wrong_single_edge_when_validating_graph_then_report_violations():
|
|||
X0 = np.random.normal(size=500)
|
||||
X1 = 2 * X0 + np.random.normal(size=500)
|
||||
data = pd.DataFrame(data=dict(X0=X0, X1=X1))
|
||||
summary = validate_graph(
|
||||
summary = run_validations(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
|
@ -304,7 +304,7 @@ def test_given_correct_categorical_when_validating_graph_then_report_no_violatio
|
|||
dag = nx.DiGraph([("X", "Z"), ("Z", "Y")])
|
||||
X, Y, Z = _generate_categorical_data()
|
||||
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
|
||||
summary = validate_graph(
|
||||
summary = run_validations(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
|
@ -334,7 +334,7 @@ def test_given_non_minimal_DAG_when_validating_causal_minimality_then_report_vio
|
|||
X2 = np.random.normal(size=500)
|
||||
Y = 2 * X0 + 3 * X1 + np.random.normal(size=500)
|
||||
data = pd.DataFrame(data=dict(X0=X0, X1=X1, X2=X2, Y=Y))
|
||||
summary = validate_graph(
|
||||
summary = run_validations(
|
||||
given_dag,
|
||||
data,
|
||||
methods=partial(validate_cm, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
|
||||
|
@ -352,7 +352,7 @@ def test_given_minimal_DAG_when_validating_causal_minimality_then_report_no_viol
|
|||
X2 = np.random.normal(size=500)
|
||||
Y = 2 * X0 + 3 * X1 + np.random.normal(size=500)
|
||||
data = pd.DataFrame(data=dict(X0=X0, X1=X1, X2=X2, Y=Y))
|
||||
summary = validate_graph(
|
||||
summary = run_validations(
|
||||
dag,
|
||||
data,
|
||||
methods=partial(validate_cm, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
|
||||
|
|
Загрузка…
Ссылка в новой задаче