validate_graph -> run_validations, use inspect for `data` arg in validate_tpa. Restrict `falsify_graph` to those tests presented in the paper

Signed-off-by: eeulig <contact@eeulig.com>
This commit is contained in:
eeulig 2023-07-18 13:06:46 +02:00 коммит произвёл Patrick Blöbaum
Родитель 379f188d67
Коммит 4e30c63fe1
3 изменённых файлов: 56 добавлений и 72 удалений

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -6,6 +6,7 @@ import warnings
from dataclasses import dataclass, field
from enum import Enum, auto
from functools import partial
from inspect import getfullargspec
from itertools import permutations
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Set, Tuple, Union
@ -50,7 +51,7 @@ class FalsifyConst(Enum):
FALSIFY_METHODS = {
FalsifyConst.VALIDATE_LMC: "LMC",
FalsifyConst.VALIDATE_PD: "Faithfulness",
FalsifyConst.VALIDATE_TPA: "tPa",
FalsifyConst.VALIDATE_TPA: "TPa",
FalsifyConst.VALIDATE_CM: "Causal Minimality",
}
@ -179,7 +180,6 @@ def validate_tpa(
causal_graph: DirectedGraph,
causal_graph_reference: DirectedGraph,
include_unconditional: bool = True,
data: Optional[pd.DataFrame] = None,
) -> Dict[str, int]:
"""
Graphical criterion to evaluate which pairwise parental d-separations (parental triples) in `causal_graph` are
@ -191,7 +191,6 @@ def validate_tpa(
:param causal_graph: Causal graph for which to evaluate parental d-separations (G')
:param causal_graph_reference: Causal graph where we test if d-separation holds (G)
:param include_unconditional: Test also unconditional independencies of root nodes.
:param data: IGNORED! No data is needed for this validation method and thus the data argument is ignored!
:return: Validation summary with number of d-separations implied by `causal_graph` and number of times these are
violated in the graph `causal_graph_reference`.
"""
@ -380,7 +379,7 @@ def validate_cm(
return validation_summary
def validate_graph(
def run_validations(
causal_graph: DirectedGraph,
data: pd.DataFrame,
methods: Optional[Union[Callable, Tuple[Callable, ...], List[Callable]]] = partial(
@ -391,7 +390,7 @@ def validate_graph(
Validate a given causal graph using observational data and some given methods. If methods are provided, they must
be wrapped in a partial object, with their respective parameters. E.g., if one wants to test the local
Markov conditions and the pairwise dependencies (unconditional faithfulness), then call
validate_graph(G, data, methods=(
run_validations(G, data, methods=(
partial(validate_lmc, independence_test=..., conditional_independence_test=...),
partial(validate_pd, independence_test=...),
)
@ -402,7 +401,7 @@ def validate_graph(
:param methods: Method functions wrapped in functools.partial. E.g.
partial(validate_lmc, independence_test=..., conditional_independence_test=...).
If no methods are provided we run validate_lmc with optional keyword arguments provided to
validate_graph.
run_validations.
:return: Validation summary as dict.
"""
@ -412,7 +411,9 @@ def validate_graph(
validation_summary = dict()
for m in methods:
m_summary = m(causal_graph=causal_graph, data=data)
if "data" in getfullargspec(m).args:
m = partial(m, data=data)
m_summary = m(causal_graph=causal_graph)
m_name = m_summary.pop(FalsifyConst.METHOD)
validation_summary[m_name] = m_summary
@ -510,8 +511,6 @@ class EvaluationResult:
def falsify_graph(
causal_graph: DirectedGraph,
data: pd.DataFrame,
methods: Optional[Union[Callable, Tuple[Callable, ...]]] = None,
suggestion_methods: Optional[Union[Callable, Tuple[Callable, ...]]] = None,
suggestions: bool = False,
independence_test: Callable[[np.ndarray, np.ndarray], float] = kernel_based,
conditional_independence_test: Callable[[np.ndarray, np.ndarray, np.ndarray], float] = kernel_based,
@ -543,22 +542,22 @@ def falsify_graph(
By default, we only run 1 / `significance_level` permutations, as those are enough to falsify a graph with a type I
error probability of `significance_level`. If you are interested in a more exact
estimate of the p-value of whish to plot a histogram to see how the given DAG compares to random node permutations,
estimate of the p-value or wish to plot a histogram to see how the given DAG compares to random node permutations,
you should set `n_permutations` to some larger value (e.g. 100 or 1000). If `n_permutations=-1` we test on all
n_nodes! permutations.
n_nodes! permutations (the default if plot_histogram=True).
`methods` and `suggestion_methods` must be wrapped in partial(method, **kwargs) (c.f. `validate_graph`).
Additionally, this method allows to return suggestions to the user (suggestions=True). This is done by testing for
violations of causal minimality via `validate_cm`.
Related paper:
Eulig, E., Mastakouri, A. A., Blöbaum, P., Hardt, M., & Janzing, D. (2023).
Toward Falsifying Causal Graphs Using a Permutation-Based Test.
https://arxiv.org/abs/2305.09565
Eulig, E., Mastakouri, A. A., Blöbaum, P., Hardt, M., & Janzing, D. (2023).
Toward Falsifying Causal Graphs Using a Permutation-Based Test.
https://arxiv.org/abs/2305.09565
:param causal_graph: A directed acyclic graph (DAG).
:param data: Observations of variables in the DAG.
:param methods: Validation methods to perform.
:param suggestion_methods: Methods to run on the given graph to provide additional suggestions.
:param suggestions: Provide suggestions generated using the `suggestion_methods`.
:param suggestions: Provide suggestions to the user. At the moment the only source of suggestions comes from
validating causal minimality (using validate_cm).
:param independence_test: Independence test to use for checking pairwise independencies.
:param conditional_independence_test: Conditional independence test to use.
:param significance_level: Significance level for the permutation test.
@ -580,28 +579,19 @@ def falsify_graph(
if not plot_kwargs:
plot_kwargs = {}
# If no methods are provided, use default ones: validate_lmc, validate_tpa
if not methods:
methods = (
partial(
validate_lmc,
independence_test=independence_test,
conditional_independence_test=conditional_independence_test,
significance_level=significance_ci,
p_values_memory=p_values_memory,
n_jobs=n_jobs,
),
partial(validate_tpa, causal_graph_reference=causal_graph),
)
elif isinstance(methods, Callable):
methods = (methods,)
methods = (
partial(
validate_lmc,
independence_test=independence_test,
conditional_independence_test=conditional_independence_test,
significance_level=significance_ci,
p_values_memory=p_values_memory,
n_jobs=n_jobs,
),
partial(validate_tpa, causal_graph_reference=causal_graph),
)
# If no suggestion methods are provided, but suggestions=True, use default ones: validate_cm
if not suggestions:
suggestion_methods = tuple()
elif suggestions and isinstance(suggestion_methods, Callable):
suggestion_methods = (suggestion_methods,)
elif suggestions:
if suggestions:
suggestion_methods = (
partial(
validate_cm,
@ -612,8 +602,10 @@ def falsify_graph(
n_jobs=n_jobs,
),
)
else:
suggestion_methods = tuple()
summary_given = validate_graph(
summary_given = run_validations(
causal_graph,
data,
methods=methods + suggestion_methods,
@ -902,7 +894,7 @@ def _permutation_based(
perm_gen = _PermuteNodes(causal_graph, n_permutations=n_permutations, exclude_original_order=exclude_original_order)
validation_summary = {FalsifyConst.PERM_GRAPHS: []}
for permuted_graph in tqdm(perm_gen, desc="Test permutations of given graph", disable=not show_progress_bar):
res = validate_graph(
res = run_validations(
causal_graph=permuted_graph,
data=data,
methods=methods,

Просмотреть файл

@ -11,8 +11,8 @@ from dowhy.datasets import generate_random_graph
from dowhy.gcm.falsify import (
FalsifyConst,
_PermuteNodes,
run_validations,
validate_cm,
validate_graph,
validate_lmc,
validate_pd,
validate_tpa,
@ -39,7 +39,7 @@ def _generate_simple_non_linear_data() -> pd.DataFrame:
@flaky(max_runs=1)
def test_given_exclude_original_order_when_generating_permutations_then_return_correct_permutations():
num_nodes = np.random.randint(1, 10)
num_nodes = np.random.randint(2, 10)
perms = set()
G = generate_random_graph(n=num_nodes)
perm_gen = _PermuteNodes(G, exclude_original_order=True, n_permutations=-1)
@ -52,7 +52,7 @@ def test_given_exclude_original_order_when_generating_permutations_then_return_c
@flaky(max_runs=1)
def test_given_not_exclude_original_order_when_generating_permutations_then_return_correct_permutations():
num_nodes = np.random.randint(1, 10)
num_nodes = np.random.randint(2, 10)
found_orig_perm = False
perms = set()
G = generate_random_graph(n=num_nodes)
@ -74,7 +74,7 @@ def test_given_correct_collider_when_validating_graph_then_report_no_violations(
Y = np.random.normal(size=500)
Z = 2 * X + 3 * Y + np.random.normal(size=500)
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
summary = validate_graph(
summary = run_validations(
true_dag,
data,
methods=(
@ -101,7 +101,7 @@ def test_given_wrong_collider_when_validating_graph_then_report_violations():
Y = np.random.normal(size=500)
Z = 2 * X + 3 * Y + np.random.normal(size=500)
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
summary = validate_graph(
summary = run_validations(
dag,
data,
methods=(
@ -127,7 +127,7 @@ def test_given_correct_chain_when_validating_graph_then_report_no_violations():
Y = 0.6 * X + np.random.normal(size=500)
Z = Y + np.random.normal(size=500)
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
summary = validate_graph(
summary = run_validations(
true_dag,
data,
methods=(
@ -154,7 +154,7 @@ def test_given_wrong_chain_when_validating_graph_then_report_violations():
Y = 0.6 * X + np.random.normal(size=500)
Z = Y + np.random.normal(size=500)
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
summary = validate_graph(
summary = run_validations(
dag,
data,
methods=(
@ -177,7 +177,7 @@ def test_given_empty_DAG_and_data_when_validating_graph_then_report_no_violation
# Empty graph
dag = nx.DiGraph()
data = pd.DataFrame()
summary = validate_graph(
summary = run_validations(
dag,
data,
methods=(
@ -203,7 +203,7 @@ def test_given_correct_full_DAG_when_validating_graph_then_report_no_violations(
X2 = 1.2 * X0 + 0.7 * X1 + np.random.normal(size=500)
X3 = 0.6 * X0 + 0.8 * X1 + 1.3 * X2 + np.random.normal(size=500)
data = pd.DataFrame(data=dict(X0=X0, X1=X1, X2=X2, X3=X3))
summary = validate_graph(
summary = run_validations(
dag,
data,
methods=(
@ -227,7 +227,7 @@ def test_given_correct_single_node_when_validating_graph_then_report_no_violatio
dag = nx.DiGraph()
dag.add_node("X")
data = pd.DataFrame(data=dict(X=np.random.normal(size=500)))
summary = validate_graph(
summary = run_validations(
dag,
data,
methods=(
@ -252,7 +252,7 @@ def test_given_correct_single_edge_when_validating_graph_then_report_no_violatio
X0 = np.random.normal(size=500)
X1 = 2 * X0 + np.random.normal(size=500)
data = pd.DataFrame(data=dict(X0=X0, X1=X1))
summary = validate_graph(
summary = run_validations(
dag,
data,
methods=(
@ -280,7 +280,7 @@ def test_given_wrong_single_edge_when_validating_graph_then_report_violations():
X0 = np.random.normal(size=500)
X1 = 2 * X0 + np.random.normal(size=500)
data = pd.DataFrame(data=dict(X0=X0, X1=X1))
summary = validate_graph(
summary = run_validations(
dag,
data,
methods=(
@ -304,7 +304,7 @@ def test_given_correct_categorical_when_validating_graph_then_report_no_violatio
dag = nx.DiGraph([("X", "Z"), ("Z", "Y")])
X, Y, Z = _generate_categorical_data()
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
summary = validate_graph(
summary = run_validations(
dag,
data,
methods=(
@ -334,7 +334,7 @@ def test_given_non_minimal_DAG_when_validating_causal_minimality_then_report_vio
X2 = np.random.normal(size=500)
Y = 2 * X0 + 3 * X1 + np.random.normal(size=500)
data = pd.DataFrame(data=dict(X0=X0, X1=X1, X2=X2, Y=Y))
summary = validate_graph(
summary = run_validations(
given_dag,
data,
methods=partial(validate_cm, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
@ -352,7 +352,7 @@ def test_given_minimal_DAG_when_validating_causal_minimality_then_report_no_viol
X2 = np.random.normal(size=500)
Y = 2 * X0 + 3 * X1 + np.random.normal(size=500)
data = pd.DataFrame(data=dict(X0=X0, X1=X1, X2=X2, Y=Y))
summary = validate_graph(
summary = run_validations(
dag,
data,
methods=partial(validate_cm, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),