remove wrap_partial, simplify validate_graph, None for missing p-values

Signed-off-by: eeulig <contact@eeulig.com>
This commit is contained in:
eeulig 2023-07-13 18:29:36 +02:00 коммит произвёл Patrick Blöbaum
Родитель 70b6b983b0
Коммит 379f188d67
3 изменённых файлов: 258 добавлений и 257 удалений

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -4,8 +4,8 @@ Functions in this module should be considered experimental, meaning there might
"""
import warnings
from dataclasses import dataclass, field
from enum import Enum
from functools import partial, update_wrapper
from enum import Enum, auto
from functools import partial
from itertools import permutations
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Set, Tuple, Union
@ -24,26 +24,35 @@ from dowhy.gcm.util.general import set_random_seed
from dowhy.gcm.validation import _get_non_descendants
from dowhy.graph import DirectedGraph, get_ordered_predecessors
FALSIFY_METHODS = {
"validate_lmc": "LMC",
"validate_pd": "Faithfulness",
"validate_tpa": "tPa",
"validate_cm": "Causal Minimality",
}
VIOLATION_COLOR = "red"
COLORS = list(mcolors.TABLEAU_COLORS.values())
class FalsifyConst(Enum):
N_VIOLATIONS = 0
N_TESTS = 1
P_VALUE = 2
P_VALUES = 3
GIVEN_VIOLATIONS = 4
PERM_VIOLATIONS = 5
F_GIVEN_VIOLATIONS = 6
F_PERM_VIOLATIONS = 7
LOCAL_VIOLATION_INSIGHT = 8
N_VIOLATIONS = auto()
N_TESTS = auto()
P_VALUE = auto()
P_VALUES = auto()
GIVEN_VIOLATIONS = auto()
PERM_VIOLATIONS = auto()
F_GIVEN_VIOLATIONS = auto()
F_PERM_VIOLATIONS = auto()
LOCAL_VIOLATION_INSIGHT = auto()
METHOD = auto()
VALIDATE_LMC = auto()
VALIDATE_TPA = auto()
VALIDATE_PD = auto()
VALIDATE_CM = auto()
PERM_GRAPHS = auto()
MEC = auto()
FALSIFY_METHODS = {
FalsifyConst.VALIDATE_LMC: "LMC",
FalsifyConst.VALIDATE_PD: "Faithfulness",
FalsifyConst.VALIDATE_TPA: "tPa",
FalsifyConst.VALIDATE_CM: "Causal Minimality",
}
@dataclass
@ -58,7 +67,7 @@ class _PValuesMemory:
def add_p_value(
self,
p_value: float,
p_value: Union[None, float],
X: Union[Set, List, str],
Y: Union[Set, List, str],
Z: Optional[Union[Set, List, str]] = None,
@ -118,7 +127,12 @@ def validate_lmc(
if p_values_memory is None:
p_values_memory = _PValuesMemory()
validation_summary = {FalsifyConst.N_VIOLATIONS: 0, FalsifyConst.N_TESTS: 0, FalsifyConst.P_VALUES: dict()}
validation_summary = {
FalsifyConst.METHOD: FalsifyConst.VALIDATE_LMC,
FalsifyConst.N_VIOLATIONS: 0,
FalsifyConst.N_TESTS: 0,
FalsifyConst.P_VALUES: dict(),
}
# Find out which tests to do
triples = _get_parental_triples(causal_graph, include_unconditional)
@ -126,7 +140,7 @@ def validate_lmc(
for node, non_desc, parents in triples:
if not (node, non_desc, parents) in p_values_memory:
to_test.append((node, non_desc, parents))
p_values_memory.add_p_value(-1, node, non_desc, parents) # Placeholder
p_values_memory.add_p_value(None, node, non_desc, parents) # Placeholder
# Parallelize over tests
random_seeds = np.random.randint(np.iinfo(np.int32).max, size=len(to_test))
@ -150,7 +164,7 @@ def validate_lmc(
# Summarize
for node, non_desc, parents in triples:
lmc_p_value = p_values_memory.get_p_value(node, non_desc, parents)
if lmc_p_value != -1:
if lmc_p_value is not None:
validation_summary[FalsifyConst.N_TESTS] += 1
validation_summary[FalsifyConst.P_VALUES][(node, non_desc)] = (
lmc_p_value,
@ -165,21 +179,27 @@ def validate_tpa(
causal_graph: DirectedGraph,
causal_graph_reference: DirectedGraph,
include_unconditional: bool = True,
data: Optional[pd.DataFrame] = None,
) -> Dict[str, int]:
"""
Graphical criterion to evaluate which pairwise parental d-separations in `causal_graph` are violated, assuming
`causal_graph_reference` is the ground truth graph. If none are violated, then both graphs lie in the same Markov
equivalence class.
Graphical criterion to evaluate which pairwise parental d-separations (parental triples) in `causal_graph` are
violated, assuming `causal_graph_reference` is the ground truth graph. If none are violated, then both graphs lie
in the same Markov equivalence class.
Specifically we test:
X _|_G' Y | Z and X _/|_G Y | Z for Y \in ND{X}^G', Z = PA{X}^G
:param causal_graph: Causal graph for which to evaluate parental d-separations (G')
:param causal_graph_reference: Causal graph where we test if d-separation holds (G)
:param include_unconditional: Test also unconditional independencies of root nodes.
:param data: Ignored. This graphical criterion needs no data; the argument is accepted only so that `validate_graph` can call every validation method with the same `causal_graph` and `data` arguments.
:return: Validation summary with number of d-separations implied by `causal_graph` and number of times these are
violated in the graph `causal_graph_reference`.
"""
validation_summary = {FalsifyConst.N_VIOLATIONS: 0, FalsifyConst.N_TESTS: 0}
validation_summary = {
FalsifyConst.METHOD: FalsifyConst.VALIDATE_TPA,
FalsifyConst.N_VIOLATIONS: 0,
FalsifyConst.N_TESTS: 0,
}
triples = _get_parental_triples(causal_graph, include_unconditional)
for node, non_desc, parents in triples:
@ -231,7 +251,12 @@ def validate_pd(
if n_pairs > len(pairs):
raise ValueError(f"n_pairs ({n_pairs}) > number of pairs in the DAG ({len(pairs)})")
validation_summary = {FalsifyConst.N_VIOLATIONS: 0, FalsifyConst.N_TESTS: n_pairs, FalsifyConst.P_VALUES: dict()}
validation_summary = {
FalsifyConst.METHOD: FalsifyConst.VALIDATE_PD,
FalsifyConst.N_VIOLATIONS: 0,
FalsifyConst.N_TESTS: n_pairs,
FalsifyConst.P_VALUES: dict(),
}
pair_idxs = np.random.choice(len(pairs), size=n_pairs, replace=False)
# Find out which tests to do
@ -240,7 +265,7 @@ def validate_pd(
ancestor, node = pairs[pair_idx]
if not (ancestor, node) in p_values_memory:
to_test.append((ancestor, node))
p_values_memory.add_p_value(-1, ancestor, node) # Placeholder
p_values_memory.add_p_value(None, ancestor, node) # Placeholder
# Parallelize over tests
random_seeds = np.random.randint(np.iinfo(np.int32).max, size=len(to_test))
@ -265,7 +290,7 @@ def validate_pd(
for pair_idx in pair_idxs:
ancestor, node = pairs[pair_idx]
p_value = p_values_memory.get_p_value(ancestor, node)
if p_value != -1:
if p_value is not None:
validation_summary[FalsifyConst.P_VALUES][(node, ancestor)] = (p_value, p_value > significance_level)
if p_value > significance_level:
validation_summary[FalsifyConst.N_VIOLATIONS] += 1
@ -301,7 +326,12 @@ def validate_cm(
if p_values_memory is None:
p_values_memory = _PValuesMemory()
validation_summary = {FalsifyConst.N_VIOLATIONS: 0, FalsifyConst.N_TESTS: 0, FalsifyConst.P_VALUES: dict()}
validation_summary = {
FalsifyConst.METHOD: FalsifyConst.VALIDATE_CM,
FalsifyConst.N_VIOLATIONS: 0,
FalsifyConst.N_TESTS: 0,
FalsifyConst.P_VALUES: dict(),
}
# Find out which tests to do
triples = []
@ -314,7 +344,7 @@ def validate_cm(
triples.append((node, p, other_parents))
if not (node, p, other_parents) in p_values_memory:
to_test.append((node, p, other_parents))
p_values_memory.add_p_value(-1, node, p, other_parents) # Placeholder
p_values_memory.add_p_value(None, node, p, other_parents) # Placeholder
# Parallelize over tests
random_seeds = np.random.randint(np.iinfo(np.int32).max, size=len(to_test))
@ -338,7 +368,7 @@ def validate_cm(
# Summarize
for node, p, other_parents in triples:
p_value = p_values_memory.get_p_value(node, p, other_parents)
if p_value != -1:
if p_value is not None:
validation_summary[FalsifyConst.N_TESTS] += 1
validation_summary[FalsifyConst.P_VALUES][(node, p, tuple(other_parents))] = (
p_value,
@ -352,27 +382,20 @@ def validate_cm(
def validate_graph(
causal_graph: DirectedGraph,
data: Optional[pd.DataFrame] = None,
methods: Optional[Union[Callable, Tuple[Callable, ...], List[Callable]]] = None,
independence_test: Optional[Callable[[np.ndarray, np.ndarray], float]] = kernel_based,
conditional_independence_test: Optional[Callable[[np.ndarray, np.ndarray, np.ndarray], float]] = kernel_based,
significance_level: float = 0.05,
n_jobs: Optional[int] = None,
data: pd.DataFrame,
methods: Optional[Union[Callable, Tuple[Callable, ...], List[Callable]]] = partial(
validate_lmc, independence_test=kernel_based, conditional_independence_test=kernel_based
),
) -> Dict[str, Dict]:
"""
Validate a given causal graph using observational data and some given methods. If methods are provided, they must
be wrapped in a wrapped_partial object, with their respective parameters. E.g., if one wants to test the local
be wrapped in a partial object, with their respective parameters. E.g., if one wants to test the local
Markov conditions and the pairwise dependencies (unconditional faithfulness), then call
validate_graph(G, methods=(
wrap_partial(validate_lmc, data=data, independence_test=..., conditional_independence_test=...),
wrap_partial(validate_pd, data=data, independence_test=...),
validate_graph(G, data, methods=(
partial(validate_lmc, independence_test=..., conditional_independence_test=...),
partial(validate_pd, independence_test=...),
)
)
If called with methods=None, then we only test lmc and expect that the data is provided via data=data. Optionally,
one can provide additional arguments to this test via the respective keyword arguments.
WARNING: If methods are provided, the optional keywords arguments are ignored and overwritten by the ones set in
wrap_partial(validate_lmc, ...), or the default arguments by the respective method if not provided in
wrap_partial. E.g. `independence_test` should be provided in wrap_partial(method, independence_test=...).
:param causal_graph: A directed acyclic graph (DAG).
:param data: Observations of variables in the DAG.
@ -380,33 +403,18 @@ def validate_graph(
partial(validate_lmc, independence_test=..., conditional_independence_test=...).
If no methods are provided we run validate_lmc with `kernel_based` as both the unconditional and the
conditional independence test; `data` is always passed to each method by validate_graph itself.
:param independence_test: Test to use for unconditional independencies (only used if include_unconditional=True)
:param conditional_independence_test: Conditional independence test to use for checking local Markov condition.
:param significance_level: Significance level for (conditional) independence tests.
:param n_jobs: Number of jobs to use for parallel execution of (conditional) independence tests.
:return: Validation summary as dict.
"""
if not methods:
assert data is not None, "If (partial) methods are not given, then must provide data!"
methods = (
wrap_partial(
validate_lmc,
data=data,
independence_test=independence_test,
conditional_independence_test=conditional_independence_test,
significance_level=significance_level,
n_jobs=n_jobs,
),
)
elif not isinstance(methods, (tuple, list)):
if not isinstance(methods, (tuple, list)):
methods = (methods,)
validation_summary = dict()
for m in methods:
m_summary = m(causal_graph=causal_graph)
validation_summary[m.__name__] = m_summary
m_summary = m(causal_graph=causal_graph, data=data)
m_name = m_summary.pop(FalsifyConst.METHOD)
validation_summary[m_name] = m_summary
return validation_summary
@ -420,8 +428,6 @@ class EvaluationResult:
Attributes
----------
methods : tuple
Tuple containing the methods used for the node permutation test
summary : dict
Dictionary containing the summary of the evaluation.
significance_level : float
@ -433,7 +439,6 @@ class EvaluationResult:
"""
methods: tuple
summary: dict
significance_level: float
suggestions: Optional[dict] = None
@ -451,13 +456,13 @@ class EvaluationResult:
self.falsified = None
self.falsifiable = None
elif (
self.summary["validate_lmc"][FalsifyConst.P_VALUE]
self.summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.P_VALUE]
> self.significance_level
> self.summary["validate_tpa"][FalsifyConst.P_VALUE]
> self.summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.P_VALUE]
):
self.falsified = True
self.falsifiable = True
elif self.significance_level < self.summary["validate_tpa"][FalsifyConst.P_VALUE]:
elif self.significance_level < self.summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.P_VALUE]:
self.falsified = False
self.falsifiable = False
else:
@ -469,10 +474,10 @@ class EvaluationResult:
if self.can_evaluate:
decision = " " if self.falsified else " do not "
informative = " " if self.falsifiable else " not "
frac_MEC = f"{len(self.summary['MEC'])} / {len(self.summary['validate_tpa'][FalsifyConst.PERM_VIOLATIONS])}"
frac_VLMC = f"{self.summary['validate_lmc'][FalsifyConst.GIVEN_VIOLATIONS]}/{self.summary['validate_lmc'][FalsifyConst.N_TESTS]}"
p_LMC = self.summary["validate_lmc"][FalsifyConst.P_VALUE]
p_dSep = self.summary["validate_tpa"][FalsifyConst.P_VALUE]
frac_MEC = f"{len(self.summary[FalsifyConst.MEC])} / {len(self.summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.PERM_VIOLATIONS])}"
frac_VLMC = f"{self.summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.GIVEN_VIOLATIONS]}/{self.summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS]}"
p_LMC = self.summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.P_VALUE]
p_dSep = self.summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.P_VALUE]
validation_repr = [
f"The given DAG is{informative}informative because {frac_MEC} of the permutations lie in the Markov",
f"equivalence class of the given DAG (p-value: {p_dSep:.2f}).",
@ -495,8 +500,8 @@ class EvaluationResult:
def _can_evaluate(self):
can_evaluate = True
for m in (validate_lmc, validate_tpa):
if m.__name__ not in [n.__name__ for n in self.methods]:
for m in (FalsifyConst.VALIDATE_LMC, FalsifyConst.VALIDATE_TPA):
if m not in self.summary:
can_evaluate = False
return can_evaluate
@ -504,7 +509,7 @@ class EvaluationResult:
def falsify_graph(
causal_graph: DirectedGraph,
data: Optional[pd.DataFrame] = None,
data: pd.DataFrame,
methods: Optional[Union[Callable, Tuple[Callable, ...]]] = None,
suggestion_methods: Optional[Union[Callable, Tuple[Callable, ...]]] = None,
suggestions: bool = False,
@ -542,7 +547,7 @@ def falsify_graph(
you should set `n_permutations` to some larger value (e.g. 100 or 1000). If `n_permutations=-1` we test on all
n_nodes! permutations.
`methods` and `suggestion_methods` must be wrapped in wrap_partial(method, **kwargs) (c.f. `validate_graph`).
`methods` and `suggestion_methods` must be wrapped in partial(method, **kwargs) (c.f. `validate_graph`).
Related paper:
Eulig, E., Mastakouri, A. A., Blöbaum, P., Hardt, M., & Janzing, D. (2023).
@ -577,18 +582,16 @@ def falsify_graph(
# If no methods are provided, use default ones: validate_lmc, validate_tpa
if not methods:
assert data is not None, "If methods=None, must provide data instead!"
methods = (
wrap_partial(
partial(
validate_lmc,
data=data,
independence_test=independence_test,
conditional_independence_test=conditional_independence_test,
significance_level=significance_ci,
p_values_memory=p_values_memory,
n_jobs=n_jobs,
),
wrap_partial(validate_tpa, causal_graph_reference=causal_graph),
partial(validate_tpa, causal_graph_reference=causal_graph),
)
elif isinstance(methods, Callable):
methods = (methods,)
@ -599,11 +602,9 @@ def falsify_graph(
elif suggestions and isinstance(suggestion_methods, Callable):
suggestion_methods = (suggestion_methods,)
elif suggestions:
assert data is not None, "If suggestions=True and suggestion_methods=None, must provide data instead!"
suggestion_methods = (
wrap_partial(
partial(
validate_cm,
data=data,
independence_test=independence_test,
conditional_independence_test=conditional_independence_test,
significance_level=significance_ci,
@ -614,52 +615,55 @@ def falsify_graph(
summary_given = validate_graph(
causal_graph,
data,
methods=methods + suggestion_methods,
)
summary_perm = _permutation_based(
causal_graph,
data,
methods=methods,
exclude_original_order=False,
n_permutations=n_permutations,
show_progress_bar=show_progress_bar,
)
summary = {m.__name__: dict() for m in methods}
summary = dict()
validation_methods = set(summary_perm.keys()) - {FalsifyConst.PERM_GRAPHS}
for m in validation_methods:
summary[m] = dict()
for m, m_summary in summary.items():
m_summary[FalsifyConst.PERM_VIOLATIONS] = [perm[FalsifyConst.N_VIOLATIONS] for perm in summary_perm[m]]
m_summary[FalsifyConst.GIVEN_VIOLATIONS] = summary_given[m][FalsifyConst.N_VIOLATIONS]
m_summary[FalsifyConst.N_TESTS] = summary_given[m][FalsifyConst.N_TESTS]
m_summary[FalsifyConst.F_PERM_VIOLATIONS] = [
summary[m][FalsifyConst.PERM_VIOLATIONS] = [perm[FalsifyConst.N_VIOLATIONS] for perm in summary_perm[m]]
summary[m][FalsifyConst.GIVEN_VIOLATIONS] = summary_given[m][FalsifyConst.N_VIOLATIONS]
summary[m][FalsifyConst.N_TESTS] = summary_given[m][FalsifyConst.N_TESTS]
summary[m][FalsifyConst.F_PERM_VIOLATIONS] = [
perm[FalsifyConst.N_VIOLATIONS] / perm[FalsifyConst.N_TESTS] for perm in summary_perm[m]
]
m_summary[FalsifyConst.F_GIVEN_VIOLATIONS] = (
m_summary[FalsifyConst.GIVEN_VIOLATIONS] / m_summary[FalsifyConst.N_TESTS]
summary[m][FalsifyConst.F_GIVEN_VIOLATIONS] = (
summary[m][FalsifyConst.GIVEN_VIOLATIONS] / summary[m][FalsifyConst.N_TESTS]
)
m_summary[FalsifyConst.P_VALUE] = sum(
summary[m][FalsifyConst.P_VALUE] = sum(
[
1
for perm in m_summary[FalsifyConst.F_PERM_VIOLATIONS]
if perm <= m_summary[FalsifyConst.F_GIVEN_VIOLATIONS]
for perm in summary[m][FalsifyConst.F_PERM_VIOLATIONS]
if perm <= summary[m][FalsifyConst.F_GIVEN_VIOLATIONS]
]
) / len(m_summary[FalsifyConst.PERM_VIOLATIONS])
) / len(summary[m][FalsifyConst.PERM_VIOLATIONS])
if m != "validate_tpa":
if m != FalsifyConst.VALIDATE_TPA:
# Append list of violations (node, non_desc) to get local information
m_summary[FalsifyConst.LOCAL_VIOLATION_INSIGHT] = summary_given[m][FalsifyConst.P_VALUES]
summary[m][FalsifyConst.LOCAL_VIOLATION_INSIGHT] = summary_given[m][FalsifyConst.P_VALUES]
if "validate_tpa" in summary:
summary["MEC"] = [
summary_perm["permuted_graphs"][i]
for i, v in enumerate(summary["validate_tpa"][FalsifyConst.PERM_VIOLATIONS])
if FalsifyConst.VALIDATE_TPA in summary:
summary[FalsifyConst.MEC] = [
summary_perm[FalsifyConst.PERM_GRAPHS][i]
for i, v in enumerate(summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.PERM_VIOLATIONS])
if v == 0
]
result = EvaluationResult(
methods=methods,
summary=summary,
significance_level=significance_level,
suggestions={m.__name__: summary_given[m.__name__] for m in suggestion_methods},
suggestions={m: summary_given[m] for m in summary_given if m not in validation_methods},
)
if plot_histogram:
plot_evaluation_results(result, **plot_kwargs)
@ -693,7 +697,7 @@ def plot_evaluation_results(evaluation_result, figsize=(8, 3), bins=None, title=
data = []
labels = []
evaluation_summary = {k: v for k, v in evaluation_result.summary.items() if k != "MEC"}
evaluation_summary = {k: v for k, v in evaluation_result.summary.items() if k != FalsifyConst.MEC}
for i, (m, m_summary) in enumerate(evaluation_summary.items()):
data.append(m_summary[FalsifyConst.F_PERM_VIOLATIONS])
labels.append(f"Violations of {FALSIFY_METHODS[m]} of permuted DAGs")
@ -731,7 +735,7 @@ def plot_evaluation_results(evaluation_result, figsize=(8, 3), bins=None, title=
def plot_local_insights(
causal_graph: DirectedGraph,
evaluation_result: Union[EvaluationResult, Dict],
method: Optional[str] = "validate_lmc",
method: Optional[str] = FalsifyConst.VALIDATE_LMC,
):
"""
For some given graph and evaluation result plot local violations.
@ -762,13 +766,13 @@ def plot_local_insights(
for nodes, result in local_insight_dict.items():
if result[1]:
if method == "validate_lmc":
if method == FalsifyConst.VALIDATE_LMC:
# For LMC we highlight X for which X _|/|_ Y \in ND_X | Pa_X
colors[nodes[0]] = VIOLATION_COLOR
elif method == "validate_pd":
elif method == FalsifyConst.VALIDATE_PD:
# For PD we highlight the edge (if Y\in Anc_X -> X are adjacent)
colors[(nodes[1], nodes[0])] = VIOLATION_COLOR
elif method == "validate_cm":
elif method == FalsifyConst.VALIDATE_CM:
# For causal minimality we highlight the edge Y \in Pa_X -> X
colors[(nodes[1], nodes[0])] = VIOLATION_COLOR
@ -843,14 +847,14 @@ def _compute_p_value(
for node in [X, Y]:
if not node in data.columns:
warnings.warn(f"WARN: Couldn't find data for node {node}. Skip this test.")
return -1
return
if Z:
# Test if we have data for Z
for node in Z:
if not node in data.columns:
warnings.warn(f"WARN: Couldn't find data for node {node}. Skip this test.")
return -1
return
p_value = conditional_independence_test(data[X].values, data[Y].values, data[Z].values)
else:
p_value = independence_test(data[X].values, data[Y].values)
@ -875,6 +879,7 @@ def _get_parental_triples(causal_graph: DirectedGraph, include_unconditional: bo
def _permutation_based(
causal_graph: DirectedGraph,
data: pd.DataFrame,
methods: Union[Callable, Tuple[Callable, ...], List[Callable]],
exclude_original_order: bool,
n_permutations: int,
@ -895,15 +900,18 @@ def _permutation_based(
methods = (methods,)
perm_gen = _PermuteNodes(causal_graph, n_permutations=n_permutations, exclude_original_order=exclude_original_order)
validation_summary = {"permuted_graphs": [], **{m.__name__: [] for m in methods}}
validation_summary = {FalsifyConst.PERM_GRAPHS: []}
for permuted_graph in tqdm(perm_gen, desc="Test permutations of given graph", disable=not show_progress_bar):
res = validate_graph(
causal_graph=permuted_graph,
data=data,
methods=methods,
)
validation_summary["permuted_graphs"].append(permuted_graph)
for m in methods:
validation_summary[m.__name__].append(res[m.__name__])
validation_summary[FalsifyConst.PERM_GRAPHS].append(permuted_graph)
for m_name, summary in res.items():
if m_name not in validation_summary:
validation_summary[m_name] = []
validation_summary[m_name].append(summary)
return validation_summary
@ -973,9 +981,3 @@ def _to_frozenset(x: Union[Set, List, str]):
if isinstance(x, str):
return frozenset({x})
return frozenset(x)
def wrap_partial(f, *args, **kwargs):
partial_f = partial(f, *args, **kwargs)
update_wrapper(partial_f, f)
return partial_f

Просмотреть файл

@ -1,5 +1,7 @@
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
from functools import partial
import networkx as nx
import numpy as np
import pandas as pd
@ -14,7 +16,6 @@ from dowhy.gcm.falsify import (
validate_lmc,
validate_pd,
validate_tpa,
wrap_partial,
)
from dowhy.gcm.independence_test.generalised_cov_measure import generalised_cov_based
from dowhy.gcm.independence_test.kernel import kernel_based
@ -75,21 +76,20 @@ def test_given_correct_collider_when_validating_graph_then_report_no_violations(
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
summary = validate_graph(
true_dag,
data,
methods=(
wrap_partial(
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
),
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
wrap_partial(validate_tpa, causal_graph_reference=true_dag),
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
partial(validate_pd, independence_test=_gcm_linear),
partial(validate_tpa, causal_graph_reference=true_dag),
),
)
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 2
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 2
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 2
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 2
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 2
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 2
@flaky(max_runs=5)
@ -103,21 +103,20 @@ def test_given_wrong_collider_when_validating_graph_then_report_violations():
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
summary = validate_graph(
dag,
data,
methods=(
wrap_partial(
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
),
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
wrap_partial(validate_tpa, causal_graph_reference=true_dag),
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
partial(validate_pd, independence_test=_gcm_linear),
partial(validate_tpa, causal_graph_reference=true_dag),
),
)
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 2
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 2
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 1
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 2
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 2
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 2
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 2
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 2
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 1
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 2
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 2
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 2
@flaky(max_runs=5)
@ -130,21 +129,20 @@ def test_given_correct_chain_when_validating_graph_then_report_no_violations():
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
summary = validate_graph(
true_dag,
data,
methods=(
wrap_partial(
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
),
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
wrap_partial(validate_tpa, causal_graph_reference=true_dag),
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
partial(validate_pd, independence_test=_gcm_linear),
partial(validate_tpa, causal_graph_reference=true_dag),
),
)
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 1
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 3
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 1
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 1
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 3
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 1
@flaky(max_runs=5)
@ -158,21 +156,20 @@ def test_given_wrong_chain_when_validating_graph_then_report_violations():
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
summary = validate_graph(
dag,
data,
methods=(
wrap_partial(
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
),
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
wrap_partial(validate_tpa, causal_graph_reference=true_dag),
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
partial(validate_pd, independence_test=_gcm_linear),
partial(validate_tpa, causal_graph_reference=true_dag),
),
)
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 1
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 1
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 3
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 1
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 1
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 1
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 1
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 3
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 1
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 1
@flaky(max_runs=5)
@ -182,20 +179,19 @@ def test_given_empty_DAG_and_data_when_validating_graph_then_report_no_violation
data = pd.DataFrame()
summary = validate_graph(
dag,
data,
methods=(
wrap_partial(
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
),
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
wrap_partial(validate_tpa, causal_graph_reference=dag),
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
partial(validate_pd, independence_test=_gcm_linear),
partial(validate_tpa, causal_graph_reference=dag),
),
)
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 0
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 0
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 0
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 0
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 0
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 0
@flaky(max_runs=5)
@ -209,21 +205,20 @@ def test_given_correct_full_DAG_when_validating_graph_then_report_no_violations(
data = pd.DataFrame(data=dict(X0=X0, X1=X1, X2=X2, X3=X3))
summary = validate_graph(
dag,
data,
methods=(
wrap_partial(
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
),
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
wrap_partial(validate_tpa, causal_graph_reference=dag),
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
partial(validate_pd, independence_test=_gcm_linear),
partial(validate_tpa, causal_graph_reference=dag),
),
)
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 0
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 6
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 0
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 0
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 6
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 0
@flaky(max_runs=5)
@ -234,21 +229,20 @@ def test_given_correct_single_node_when_validating_graph_then_report_no_violatio
data = pd.DataFrame(data=dict(X=np.random.normal(size=500)))
summary = validate_graph(
dag,
data,
methods=(
wrap_partial(
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
),
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
wrap_partial(validate_tpa, causal_graph_reference=dag),
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
partial(validate_pd, independence_test=_gcm_linear),
partial(validate_tpa, causal_graph_reference=dag),
),
)
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 0
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 0
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 0
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 0
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 0
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 0
@flaky(max_runs=5)
@ -260,21 +254,20 @@ def test_given_correct_single_edge_when_validating_graph_then_report_no_violatio
data = pd.DataFrame(data=dict(X0=X0, X1=X1))
summary = validate_graph(
dag,
data,
methods=(
wrap_partial(
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
),
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
wrap_partial(validate_tpa, causal_graph_reference=dag),
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
partial(validate_pd, independence_test=_gcm_linear),
partial(validate_tpa, causal_graph_reference=dag),
),
)
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 0
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 1
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 0
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 0
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 1
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 0
@flaky(max_runs=5)
@ -289,21 +282,20 @@ def test_given_wrong_single_edge_when_validating_graph_then_report_violations():
data = pd.DataFrame(data=dict(X0=X0, X1=X1))
summary = validate_graph(
dag,
data,
methods=(
wrap_partial(
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
),
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
wrap_partial(validate_tpa, causal_graph_reference=true_dag),
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
partial(validate_pd, independence_test=_gcm_linear),
partial(validate_tpa, causal_graph_reference=true_dag),
),
)
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 2
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 2
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 0
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 2
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 2
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 2
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 2
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 0
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 2
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 2
@flaky(max_runs=5)
@ -314,21 +306,20 @@ def test_given_correct_categorical_when_validating_graph_then_report_no_violatio
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
summary = validate_graph(
dag,
data,
methods=(
wrap_partial(
validate_lmc, data=data, independence_test=kernel_based, conditional_independence_test=kernel_based
),
wrap_partial(validate_pd, data=data, independence_test=kernel_based),
wrap_partial(validate_tpa, causal_graph_reference=dag),
partial(validate_lmc, independence_test=kernel_based, conditional_independence_test=kernel_based),
partial(validate_pd, independence_test=kernel_based),
partial(validate_tpa, causal_graph_reference=dag),
),
)
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 1
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 3
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 1
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 1
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 3
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 1
@flaky(max_runs=5)
@ -345,13 +336,12 @@ def test_given_non_minimal_DAG_when_validating_causal_minimality_then_report_vio
data = pd.DataFrame(data=dict(X0=X0, X1=X1, X2=X2, Y=Y))
summary = validate_graph(
given_dag,
methods=wrap_partial(
validate_cm, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
),
data,
methods=partial(validate_cm, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
)
assert summary["validate_cm"][FalsifyConst.N_VIOLATIONS] == 1
assert summary["validate_cm"][FalsifyConst.N_TESTS] == 3
assert summary[FalsifyConst.VALIDATE_CM][FalsifyConst.N_VIOLATIONS] == 1
assert summary[FalsifyConst.VALIDATE_CM][FalsifyConst.N_TESTS] == 3
@flaky(max_runs=5)
@ -364,10 +354,9 @@ def test_given_minimal_DAG_when_validating_causal_minimality_then_report_no_viol
data = pd.DataFrame(data=dict(X0=X0, X1=X1, X2=X2, Y=Y))
summary = validate_graph(
dag,
methods=wrap_partial(
validate_cm, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
),
data,
methods=partial(validate_cm, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
)
assert summary["validate_cm"][FalsifyConst.N_VIOLATIONS] == 0
assert summary["validate_cm"][FalsifyConst.N_TESTS] == 2
assert summary[FalsifyConst.VALIDATE_CM][FalsifyConst.N_VIOLATIONS] == 0
assert summary[FalsifyConst.VALIDATE_CM][FalsifyConst.N_TESTS] == 2