remove wrap_partial, simplify validate_graph, None for missing p-values
Signed-off-by: eeulig <contact@eeulig.com>
This commit is contained in:
Родитель
70b6b983b0
Коммит
379f188d67
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -4,8 +4,8 @@ Functions in this module should be considered experimental, meaning there might
|
|||
"""
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from functools import partial, update_wrapper
|
||||
from enum import Enum, auto
|
||||
from functools import partial
|
||||
from itertools import permutations
|
||||
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Set, Tuple, Union
|
||||
|
||||
|
@ -24,26 +24,35 @@ from dowhy.gcm.util.general import set_random_seed
|
|||
from dowhy.gcm.validation import _get_non_descendants
|
||||
from dowhy.graph import DirectedGraph, get_ordered_predecessors
|
||||
|
||||
FALSIFY_METHODS = {
|
||||
"validate_lmc": "LMC",
|
||||
"validate_pd": "Faithfulness",
|
||||
"validate_tpa": "tPa",
|
||||
"validate_cm": "Causal Minimality",
|
||||
}
|
||||
VIOLATION_COLOR = "red"
|
||||
COLORS = list(mcolors.TABLEAU_COLORS.values())
|
||||
|
||||
|
||||
class FalsifyConst(Enum):
|
||||
N_VIOLATIONS = 0
|
||||
N_TESTS = 1
|
||||
P_VALUE = 2
|
||||
P_VALUES = 3
|
||||
GIVEN_VIOLATIONS = 4
|
||||
PERM_VIOLATIONS = 5
|
||||
F_GIVEN_VIOLATIONS = 6
|
||||
F_PERM_VIOLATIONS = 7
|
||||
LOCAL_VIOLATION_INSIGHT = 8
|
||||
N_VIOLATIONS = auto()
|
||||
N_TESTS = auto()
|
||||
P_VALUE = auto()
|
||||
P_VALUES = auto()
|
||||
GIVEN_VIOLATIONS = auto()
|
||||
PERM_VIOLATIONS = auto()
|
||||
F_GIVEN_VIOLATIONS = auto()
|
||||
F_PERM_VIOLATIONS = auto()
|
||||
LOCAL_VIOLATION_INSIGHT = auto()
|
||||
METHOD = auto()
|
||||
VALIDATE_LMC = auto()
|
||||
VALIDATE_TPA = auto()
|
||||
VALIDATE_PD = auto()
|
||||
VALIDATE_CM = auto()
|
||||
PERM_GRAPHS = auto()
|
||||
MEC = auto()
|
||||
|
||||
|
||||
FALSIFY_METHODS = {
|
||||
FalsifyConst.VALIDATE_LMC: "LMC",
|
||||
FalsifyConst.VALIDATE_PD: "Faithfulness",
|
||||
FalsifyConst.VALIDATE_TPA: "tPa",
|
||||
FalsifyConst.VALIDATE_CM: "Causal Minimality",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -58,7 +67,7 @@ class _PValuesMemory:
|
|||
|
||||
def add_p_value(
|
||||
self,
|
||||
p_value: float,
|
||||
p_value: Union[None, float],
|
||||
X: Union[Set, List, str],
|
||||
Y: Union[Set, List, str],
|
||||
Z: Optional[Union[Set, List, str]] = None,
|
||||
|
@ -118,7 +127,12 @@ def validate_lmc(
|
|||
if p_values_memory is None:
|
||||
p_values_memory = _PValuesMemory()
|
||||
|
||||
validation_summary = {FalsifyConst.N_VIOLATIONS: 0, FalsifyConst.N_TESTS: 0, FalsifyConst.P_VALUES: dict()}
|
||||
validation_summary = {
|
||||
FalsifyConst.METHOD: FalsifyConst.VALIDATE_LMC,
|
||||
FalsifyConst.N_VIOLATIONS: 0,
|
||||
FalsifyConst.N_TESTS: 0,
|
||||
FalsifyConst.P_VALUES: dict(),
|
||||
}
|
||||
|
||||
# Find out which tests to do
|
||||
triples = _get_parental_triples(causal_graph, include_unconditional)
|
||||
|
@ -126,7 +140,7 @@ def validate_lmc(
|
|||
for node, non_desc, parents in triples:
|
||||
if not (node, non_desc, parents) in p_values_memory:
|
||||
to_test.append((node, non_desc, parents))
|
||||
p_values_memory.add_p_value(-1, node, non_desc, parents) # Placeholder
|
||||
p_values_memory.add_p_value(None, node, non_desc, parents) # Placeholder
|
||||
|
||||
# Parallelize over tests
|
||||
random_seeds = np.random.randint(np.iinfo(np.int32).max, size=len(to_test))
|
||||
|
@ -150,7 +164,7 @@ def validate_lmc(
|
|||
# Summarize
|
||||
for node, non_desc, parents in triples:
|
||||
lmc_p_value = p_values_memory.get_p_value(node, non_desc, parents)
|
||||
if lmc_p_value != -1:
|
||||
if lmc_p_value is not None:
|
||||
validation_summary[FalsifyConst.N_TESTS] += 1
|
||||
validation_summary[FalsifyConst.P_VALUES][(node, non_desc)] = (
|
||||
lmc_p_value,
|
||||
|
@ -165,21 +179,27 @@ def validate_tpa(
|
|||
causal_graph: DirectedGraph,
|
||||
causal_graph_reference: DirectedGraph,
|
||||
include_unconditional: bool = True,
|
||||
data: Optional[pd.DataFrame] = None,
|
||||
) -> Dict[str, int]:
|
||||
"""
|
||||
Graphical criterion to evaluate which pairwise parental d-separations in `causal_graph` are violated, assuming
|
||||
`causal_graph_reference` is the ground truth graph. If none are violated, then both graphs lie in the same Markov
|
||||
equivalence class.
|
||||
Graphical criterion to evaluate which pairwise parental d-separations (parental triples) in `causal_graph` are
|
||||
violated, assuming `causal_graph_reference` is the ground truth graph. If none are violated, then both graphs lie
|
||||
in the same Markov equivalence class.
|
||||
Specifically we test:
|
||||
X _|_G' Y | Z and X _/|_G Y | Z for Y \in ND{X}^G', Z = PA{X}^G
|
||||
|
||||
:param causal_graph: Causal graph for which to evaluate parental d-separations (G')
|
||||
:param causal_graph_reference: Causal graph where we test if d-separation holds (G)
|
||||
:param include_unconditional: Test also unconditional independencies of root nodes.
|
||||
:param data: IGNORED! No data is needed for this validation method and thus the data argument is ignored!
|
||||
:return: Validation summary with number of d-separations implied by `causal_graph` and number of times these are
|
||||
violated in the graph `causal_graph_reference`.
|
||||
"""
|
||||
validation_summary = {FalsifyConst.N_VIOLATIONS: 0, FalsifyConst.N_TESTS: 0}
|
||||
validation_summary = {
|
||||
FalsifyConst.METHOD: FalsifyConst.VALIDATE_TPA,
|
||||
FalsifyConst.N_VIOLATIONS: 0,
|
||||
FalsifyConst.N_TESTS: 0,
|
||||
}
|
||||
|
||||
triples = _get_parental_triples(causal_graph, include_unconditional)
|
||||
for node, non_desc, parents in triples:
|
||||
|
@ -231,7 +251,12 @@ def validate_pd(
|
|||
if n_pairs > len(pairs):
|
||||
raise ValueError(f"n_pairs ({n_pairs}) > number of pairs in the DAG ({len(pairs)})")
|
||||
|
||||
validation_summary = {FalsifyConst.N_VIOLATIONS: 0, FalsifyConst.N_TESTS: n_pairs, FalsifyConst.P_VALUES: dict()}
|
||||
validation_summary = {
|
||||
FalsifyConst.METHOD: FalsifyConst.VALIDATE_PD,
|
||||
FalsifyConst.N_VIOLATIONS: 0,
|
||||
FalsifyConst.N_TESTS: n_pairs,
|
||||
FalsifyConst.P_VALUES: dict(),
|
||||
}
|
||||
pair_idxs = np.random.choice(len(pairs), size=n_pairs, replace=False)
|
||||
|
||||
# Find out which tests to do
|
||||
|
@ -240,7 +265,7 @@ def validate_pd(
|
|||
ancestor, node = pairs[pair_idx]
|
||||
if not (ancestor, node) in p_values_memory:
|
||||
to_test.append((ancestor, node))
|
||||
p_values_memory.add_p_value(-1, ancestor, node) # Placeholder
|
||||
p_values_memory.add_p_value(None, ancestor, node) # Placeholder
|
||||
|
||||
# Parallelize over tests
|
||||
random_seeds = np.random.randint(np.iinfo(np.int32).max, size=len(to_test))
|
||||
|
@ -265,7 +290,7 @@ def validate_pd(
|
|||
for pair_idx in pair_idxs:
|
||||
ancestor, node = pairs[pair_idx]
|
||||
p_value = p_values_memory.get_p_value(ancestor, node)
|
||||
if p_value != -1:
|
||||
if p_value is not None:
|
||||
validation_summary[FalsifyConst.P_VALUES][(node, ancestor)] = (p_value, p_value > significance_level)
|
||||
if p_value > significance_level:
|
||||
validation_summary[FalsifyConst.N_VIOLATIONS] += 1
|
||||
|
@ -301,7 +326,12 @@ def validate_cm(
|
|||
if p_values_memory is None:
|
||||
p_values_memory = _PValuesMemory()
|
||||
|
||||
validation_summary = {FalsifyConst.N_VIOLATIONS: 0, FalsifyConst.N_TESTS: 0, FalsifyConst.P_VALUES: dict()}
|
||||
validation_summary = {
|
||||
FalsifyConst.METHOD: FalsifyConst.VALIDATE_CM,
|
||||
FalsifyConst.N_VIOLATIONS: 0,
|
||||
FalsifyConst.N_TESTS: 0,
|
||||
FalsifyConst.P_VALUES: dict(),
|
||||
}
|
||||
|
||||
# Find out which tests to do
|
||||
triples = []
|
||||
|
@ -314,7 +344,7 @@ def validate_cm(
|
|||
triples.append((node, p, other_parents))
|
||||
if not (node, p, other_parents) in p_values_memory:
|
||||
to_test.append((node, p, other_parents))
|
||||
p_values_memory.add_p_value(-1, node, p, other_parents) # Placeholder
|
||||
p_values_memory.add_p_value(None, node, p, other_parents) # Placeholder
|
||||
|
||||
# Parallelize over tests
|
||||
random_seeds = np.random.randint(np.iinfo(np.int32).max, size=len(to_test))
|
||||
|
@ -338,7 +368,7 @@ def validate_cm(
|
|||
# Summarize
|
||||
for node, p, other_parents in triples:
|
||||
p_value = p_values_memory.get_p_value(node, p, other_parents)
|
||||
if p_value != -1:
|
||||
if p_value is not None:
|
||||
validation_summary[FalsifyConst.N_TESTS] += 1
|
||||
validation_summary[FalsifyConst.P_VALUES][(node, p, tuple(other_parents))] = (
|
||||
p_value,
|
||||
|
@ -352,27 +382,20 @@ def validate_cm(
|
|||
|
||||
def validate_graph(
|
||||
causal_graph: DirectedGraph,
|
||||
data: Optional[pd.DataFrame] = None,
|
||||
methods: Optional[Union[Callable, Tuple[Callable, ...], List[Callable]]] = None,
|
||||
independence_test: Optional[Callable[[np.ndarray, np.ndarray], float]] = kernel_based,
|
||||
conditional_independence_test: Optional[Callable[[np.ndarray, np.ndarray, np.ndarray], float]] = kernel_based,
|
||||
significance_level: float = 0.05,
|
||||
n_jobs: Optional[int] = None,
|
||||
data: pd.DataFrame,
|
||||
methods: Optional[Union[Callable, Tuple[Callable, ...], List[Callable]]] = partial(
|
||||
validate_lmc, independence_test=kernel_based, conditional_independence_test=kernel_based
|
||||
),
|
||||
) -> Dict[str, Dict]:
|
||||
"""
|
||||
Validate a given causal graph using observational data and some given methods. If methods are provided, they must
|
||||
be wrapped in a wrapped_partial object, with their respective parameters. E.g., if one wants to test the local
|
||||
be wrapped in a partial object, with their respective parameters. E.g., if one wants to test the local
|
||||
Markov conditions and the pairwise dependencies (unconditional faithfulness), then call
|
||||
validate_graph(G, methods=(
|
||||
wrap_partial(validate_lmc, data=data, independence_test=..., conditional_independence_test=...),
|
||||
wrap_partial(validate_pd, data=data, independence_test=...),
|
||||
validate_graph(G, data, methods=(
|
||||
partial(validate_lmc, independence_test=..., conditional_independence_test=...),
|
||||
partial(validate_pd, independence_test=...),
|
||||
)
|
||||
)
|
||||
If called with methods=None, then we only test lmc and expect that the data is provided via data=data. Optionally,
|
||||
one can provide additional arguments to this test via the respective keyword arguments.
|
||||
WARNING: If methods are provided, the optional keywords arguments are ignored and overwritten by the ones set in
|
||||
wrap_partial(validate_lmc, ...), or the default arguments by the respective method if not provided in
|
||||
wrap_partial. E.g. `independence_test` should be provided in wrap_partial(method, independence_test=...).
|
||||
|
||||
:param causal_graph: A directed acyclic graph (DAG).
|
||||
:param data: Observations of variables in the DAG.
|
||||
|
@ -380,33 +403,18 @@ def validate_graph(
|
|||
wrap_partial(validate_lmc, data=data, independence_test=..., conditional_independence_test=...).
|
||||
If no methods are provided we run validate_lmc with optional keyword arguments provided to
|
||||
validate_graph.
|
||||
:param independence_test: Test to use for unconditional independencies (only used if include_unconditional=True)
|
||||
:param conditional_independence_test: Conditional independence test to use for checking local Markov condition.
|
||||
:param significance_level: Significance level for (conditional) independence tests.
|
||||
:param n_jobs: Number of jobs to use for parallel execution of (conditional) independence tests.
|
||||
:return: Validation summary as dict.
|
||||
"""
|
||||
if not methods:
|
||||
assert data is not None, "If (partial) methods are not given, then must provide data!"
|
||||
|
||||
methods = (
|
||||
wrap_partial(
|
||||
validate_lmc,
|
||||
data=data,
|
||||
independence_test=independence_test,
|
||||
conditional_independence_test=conditional_independence_test,
|
||||
significance_level=significance_level,
|
||||
n_jobs=n_jobs,
|
||||
),
|
||||
)
|
||||
elif not isinstance(methods, (tuple, list)):
|
||||
if not isinstance(methods, (tuple, list)):
|
||||
methods = (methods,)
|
||||
|
||||
validation_summary = dict()
|
||||
|
||||
for m in methods:
|
||||
m_summary = m(causal_graph=causal_graph)
|
||||
validation_summary[m.__name__] = m_summary
|
||||
m_summary = m(causal_graph=causal_graph, data=data)
|
||||
m_name = m_summary.pop(FalsifyConst.METHOD)
|
||||
validation_summary[m_name] = m_summary
|
||||
|
||||
return validation_summary
|
||||
|
||||
|
@ -420,8 +428,6 @@ class EvaluationResult:
|
|||
|
||||
Attributes
|
||||
----------
|
||||
methods : tuple
|
||||
Tuple containing the methods used for the node permutation test
|
||||
summary : dict
|
||||
Dictionary containing the summary of the evaluation.
|
||||
significance_level : float
|
||||
|
@ -433,7 +439,6 @@ class EvaluationResult:
|
|||
|
||||
"""
|
||||
|
||||
methods: tuple
|
||||
summary: dict
|
||||
significance_level: float
|
||||
suggestions: Optional[dict] = None
|
||||
|
@ -451,13 +456,13 @@ class EvaluationResult:
|
|||
self.falsified = None
|
||||
self.falsifiable = None
|
||||
elif (
|
||||
self.summary["validate_lmc"][FalsifyConst.P_VALUE]
|
||||
self.summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.P_VALUE]
|
||||
> self.significance_level
|
||||
> self.summary["validate_tpa"][FalsifyConst.P_VALUE]
|
||||
> self.summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.P_VALUE]
|
||||
):
|
||||
self.falsified = True
|
||||
self.falsifiable = True
|
||||
elif self.significance_level < self.summary["validate_tpa"][FalsifyConst.P_VALUE]:
|
||||
elif self.significance_level < self.summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.P_VALUE]:
|
||||
self.falsified = False
|
||||
self.falsifiable = False
|
||||
else:
|
||||
|
@ -469,10 +474,10 @@ class EvaluationResult:
|
|||
if self.can_evaluate:
|
||||
decision = " " if self.falsified else " do not "
|
||||
informative = " " if self.falsifiable else " not "
|
||||
frac_MEC = f"{len(self.summary['MEC'])} / {len(self.summary['validate_tpa'][FalsifyConst.PERM_VIOLATIONS])}"
|
||||
frac_VLMC = f"{self.summary['validate_lmc'][FalsifyConst.GIVEN_VIOLATIONS]}/{self.summary['validate_lmc'][FalsifyConst.N_TESTS]}"
|
||||
p_LMC = self.summary["validate_lmc"][FalsifyConst.P_VALUE]
|
||||
p_dSep = self.summary["validate_tpa"][FalsifyConst.P_VALUE]
|
||||
frac_MEC = f"{len(self.summary[FalsifyConst.MEC])} / {len(self.summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.PERM_VIOLATIONS])}"
|
||||
frac_VLMC = f"{self.summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.GIVEN_VIOLATIONS]}/{self.summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS]}"
|
||||
p_LMC = self.summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.P_VALUE]
|
||||
p_dSep = self.summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.P_VALUE]
|
||||
validation_repr = [
|
||||
f"The given DAG is{informative}informative because {frac_MEC} of the permutations lie in the Markov",
|
||||
f"equivalence class of the given DAG (p-value: {p_dSep:.2f}).",
|
||||
|
@ -495,8 +500,8 @@ class EvaluationResult:
|
|||
|
||||
def _can_evaluate(self):
|
||||
can_evaluate = True
|
||||
for m in (validate_lmc, validate_tpa):
|
||||
if m.__name__ not in [n.__name__ for n in self.methods]:
|
||||
for m in (FalsifyConst.VALIDATE_LMC, FalsifyConst.VALIDATE_TPA):
|
||||
if m not in self.summary:
|
||||
can_evaluate = False
|
||||
|
||||
return can_evaluate
|
||||
|
@ -504,7 +509,7 @@ class EvaluationResult:
|
|||
|
||||
def falsify_graph(
|
||||
causal_graph: DirectedGraph,
|
||||
data: Optional[pd.DataFrame] = None,
|
||||
data: pd.DataFrame,
|
||||
methods: Optional[Union[Callable, Tuple[Callable, ...]]] = None,
|
||||
suggestion_methods: Optional[Union[Callable, Tuple[Callable, ...]]] = None,
|
||||
suggestions: bool = False,
|
||||
|
@ -542,7 +547,7 @@ def falsify_graph(
|
|||
you should set `n_permutations` to some larger value (e.g. 100 or 1000). If `n_permutations=-1` we test on all
|
||||
n_nodes! permutations.
|
||||
|
||||
`methods` and `suggestion_methods` must be wrapped in wrap_partial(method, **kwargs) (c.f. `validate_graph`).
|
||||
`methods` and `suggestion_methods` must be wrapped in partial(method, **kwargs) (c.f. `validate_graph`).
|
||||
|
||||
Related paper:
|
||||
Eulig, E., Mastakouri, A. A., Blöbaum, P., Hardt, M., & Janzing, D. (2023).
|
||||
|
@ -577,18 +582,16 @@ def falsify_graph(
|
|||
|
||||
# If no methods are provided, use default ones: validate_lmc, validate_tpa
|
||||
if not methods:
|
||||
assert data is not None, "If methods=None, must provide data instead!"
|
||||
methods = (
|
||||
wrap_partial(
|
||||
partial(
|
||||
validate_lmc,
|
||||
data=data,
|
||||
independence_test=independence_test,
|
||||
conditional_independence_test=conditional_independence_test,
|
||||
significance_level=significance_ci,
|
||||
p_values_memory=p_values_memory,
|
||||
n_jobs=n_jobs,
|
||||
),
|
||||
wrap_partial(validate_tpa, causal_graph_reference=causal_graph),
|
||||
partial(validate_tpa, causal_graph_reference=causal_graph),
|
||||
)
|
||||
elif isinstance(methods, Callable):
|
||||
methods = (methods,)
|
||||
|
@ -599,11 +602,9 @@ def falsify_graph(
|
|||
elif suggestions and isinstance(suggestion_methods, Callable):
|
||||
suggestion_methods = (suggestion_methods,)
|
||||
elif suggestions:
|
||||
assert data is not None, "If suggestions=True and suggestion_methods=None, must provide data instead!"
|
||||
suggestion_methods = (
|
||||
wrap_partial(
|
||||
partial(
|
||||
validate_cm,
|
||||
data=data,
|
||||
independence_test=independence_test,
|
||||
conditional_independence_test=conditional_independence_test,
|
||||
significance_level=significance_ci,
|
||||
|
@ -614,52 +615,55 @@ def falsify_graph(
|
|||
|
||||
summary_given = validate_graph(
|
||||
causal_graph,
|
||||
data,
|
||||
methods=methods + suggestion_methods,
|
||||
)
|
||||
summary_perm = _permutation_based(
|
||||
causal_graph,
|
||||
data,
|
||||
methods=methods,
|
||||
exclude_original_order=False,
|
||||
n_permutations=n_permutations,
|
||||
show_progress_bar=show_progress_bar,
|
||||
)
|
||||
|
||||
summary = {m.__name__: dict() for m in methods}
|
||||
summary = dict()
|
||||
validation_methods = set(summary_perm.keys()) - {FalsifyConst.PERM_GRAPHS}
|
||||
for m in validation_methods:
|
||||
summary[m] = dict()
|
||||
|
||||
for m, m_summary in summary.items():
|
||||
m_summary[FalsifyConst.PERM_VIOLATIONS] = [perm[FalsifyConst.N_VIOLATIONS] for perm in summary_perm[m]]
|
||||
m_summary[FalsifyConst.GIVEN_VIOLATIONS] = summary_given[m][FalsifyConst.N_VIOLATIONS]
|
||||
m_summary[FalsifyConst.N_TESTS] = summary_given[m][FalsifyConst.N_TESTS]
|
||||
m_summary[FalsifyConst.F_PERM_VIOLATIONS] = [
|
||||
summary[m][FalsifyConst.PERM_VIOLATIONS] = [perm[FalsifyConst.N_VIOLATIONS] for perm in summary_perm[m]]
|
||||
summary[m][FalsifyConst.GIVEN_VIOLATIONS] = summary_given[m][FalsifyConst.N_VIOLATIONS]
|
||||
summary[m][FalsifyConst.N_TESTS] = summary_given[m][FalsifyConst.N_TESTS]
|
||||
summary[m][FalsifyConst.F_PERM_VIOLATIONS] = [
|
||||
perm[FalsifyConst.N_VIOLATIONS] / perm[FalsifyConst.N_TESTS] for perm in summary_perm[m]
|
||||
]
|
||||
m_summary[FalsifyConst.F_GIVEN_VIOLATIONS] = (
|
||||
m_summary[FalsifyConst.GIVEN_VIOLATIONS] / m_summary[FalsifyConst.N_TESTS]
|
||||
summary[m][FalsifyConst.F_GIVEN_VIOLATIONS] = (
|
||||
summary[m][FalsifyConst.GIVEN_VIOLATIONS] / summary[m][FalsifyConst.N_TESTS]
|
||||
)
|
||||
m_summary[FalsifyConst.P_VALUE] = sum(
|
||||
summary[m][FalsifyConst.P_VALUE] = sum(
|
||||
[
|
||||
1
|
||||
for perm in m_summary[FalsifyConst.F_PERM_VIOLATIONS]
|
||||
if perm <= m_summary[FalsifyConst.F_GIVEN_VIOLATIONS]
|
||||
for perm in summary[m][FalsifyConst.F_PERM_VIOLATIONS]
|
||||
if perm <= summary[m][FalsifyConst.F_GIVEN_VIOLATIONS]
|
||||
]
|
||||
) / len(m_summary[FalsifyConst.PERM_VIOLATIONS])
|
||||
) / len(summary[m][FalsifyConst.PERM_VIOLATIONS])
|
||||
|
||||
if m != "validate_tpa":
|
||||
if m != FalsifyConst.VALIDATE_TPA:
|
||||
# Append list of violations (node, non_desc) to get local information
|
||||
m_summary[FalsifyConst.LOCAL_VIOLATION_INSIGHT] = summary_given[m][FalsifyConst.P_VALUES]
|
||||
summary[m][FalsifyConst.LOCAL_VIOLATION_INSIGHT] = summary_given[m][FalsifyConst.P_VALUES]
|
||||
|
||||
if "validate_tpa" in summary:
|
||||
summary["MEC"] = [
|
||||
summary_perm["permuted_graphs"][i]
|
||||
for i, v in enumerate(summary["validate_tpa"][FalsifyConst.PERM_VIOLATIONS])
|
||||
if FalsifyConst.VALIDATE_TPA in summary:
|
||||
summary[FalsifyConst.MEC] = [
|
||||
summary_perm[FalsifyConst.PERM_GRAPHS][i]
|
||||
for i, v in enumerate(summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.PERM_VIOLATIONS])
|
||||
if v == 0
|
||||
]
|
||||
|
||||
result = EvaluationResult(
|
||||
methods=methods,
|
||||
summary=summary,
|
||||
significance_level=significance_level,
|
||||
suggestions={m.__name__: summary_given[m.__name__] for m in suggestion_methods},
|
||||
suggestions={m: summary_given[m] for m in summary_given if m not in validation_methods},
|
||||
)
|
||||
if plot_histogram:
|
||||
plot_evaluation_results(result, **plot_kwargs)
|
||||
|
@ -693,7 +697,7 @@ def plot_evaluation_results(evaluation_result, figsize=(8, 3), bins=None, title=
|
|||
data = []
|
||||
labels = []
|
||||
|
||||
evaluation_summary = {k: v for k, v in evaluation_result.summary.items() if k != "MEC"}
|
||||
evaluation_summary = {k: v for k, v in evaluation_result.summary.items() if k != FalsifyConst.MEC}
|
||||
for i, (m, m_summary) in enumerate(evaluation_summary.items()):
|
||||
data.append(m_summary[FalsifyConst.F_PERM_VIOLATIONS])
|
||||
labels.append(f"Violations of {FALSIFY_METHODS[m]} of permuted DAGs")
|
||||
|
@ -731,7 +735,7 @@ def plot_evaluation_results(evaluation_result, figsize=(8, 3), bins=None, title=
|
|||
def plot_local_insights(
|
||||
causal_graph: DirectedGraph,
|
||||
evaluation_result: Union[EvaluationResult, Dict],
|
||||
method: Optional[str] = "validate_lmc",
|
||||
method: Optional[str] = FalsifyConst.VALIDATE_LMC,
|
||||
):
|
||||
"""
|
||||
For some given graph and evaluation result plot local violations.
|
||||
|
@ -762,13 +766,13 @@ def plot_local_insights(
|
|||
|
||||
for nodes, result in local_insight_dict.items():
|
||||
if result[1]:
|
||||
if method == "validate_lmc":
|
||||
if method == FalsifyConst.VALIDATE_LMC:
|
||||
# For LMC we highlight X for which X _|/|_ Y \in ND_X | Pa_X
|
||||
colors[nodes[0]] = VIOLATION_COLOR
|
||||
elif method == "validate_pd":
|
||||
elif method == FalsifyConst.VALIDATE_PD:
|
||||
# For PD we highlight the edge (if Y\in Anc_X -> X are adjacent)
|
||||
colors[(nodes[1], nodes[0])] = VIOLATION_COLOR
|
||||
elif method == "validate_cm":
|
||||
elif method == FalsifyConst.VALIDATE_CM:
|
||||
# For causal minimality we highlight the edge Y \in Pa_X -> X
|
||||
colors[(nodes[1], nodes[0])] = VIOLATION_COLOR
|
||||
|
||||
|
@ -843,14 +847,14 @@ def _compute_p_value(
|
|||
for node in [X, Y]:
|
||||
if not node in data.columns:
|
||||
warnings.warn(f"WARN: Couldn't find data for node {node}. Skip this test.")
|
||||
return -1
|
||||
return
|
||||
|
||||
if Z:
|
||||
# Test if we have data for Z
|
||||
for node in Z:
|
||||
if not node in data.columns:
|
||||
warnings.warn(f"WARN: Couldn't find data for node {node}. Skip this test.")
|
||||
return -1
|
||||
return
|
||||
p_value = conditional_independence_test(data[X].values, data[Y].values, data[Z].values)
|
||||
else:
|
||||
p_value = independence_test(data[X].values, data[Y].values)
|
||||
|
@ -875,6 +879,7 @@ def _get_parental_triples(causal_graph: DirectedGraph, include_unconditional: bo
|
|||
|
||||
def _permutation_based(
|
||||
causal_graph: DirectedGraph,
|
||||
data: pd.DataFrame,
|
||||
methods: Union[Callable, Tuple[Callable, ...], List[Callable]],
|
||||
exclude_original_order: bool,
|
||||
n_permutations: int,
|
||||
|
@ -895,15 +900,18 @@ def _permutation_based(
|
|||
methods = (methods,)
|
||||
|
||||
perm_gen = _PermuteNodes(causal_graph, n_permutations=n_permutations, exclude_original_order=exclude_original_order)
|
||||
validation_summary = {"permuted_graphs": [], **{m.__name__: [] for m in methods}}
|
||||
validation_summary = {FalsifyConst.PERM_GRAPHS: []}
|
||||
for permuted_graph in tqdm(perm_gen, desc="Test permutations of given graph", disable=not show_progress_bar):
|
||||
res = validate_graph(
|
||||
causal_graph=permuted_graph,
|
||||
data=data,
|
||||
methods=methods,
|
||||
)
|
||||
validation_summary["permuted_graphs"].append(permuted_graph)
|
||||
for m in methods:
|
||||
validation_summary[m.__name__].append(res[m.__name__])
|
||||
validation_summary[FalsifyConst.PERM_GRAPHS].append(permuted_graph)
|
||||
for m_name, summary in res.items():
|
||||
if m_name not in validation_summary:
|
||||
validation_summary[m_name] = []
|
||||
validation_summary[m_name].append(summary)
|
||||
|
||||
return validation_summary
|
||||
|
||||
|
@ -973,9 +981,3 @@ def _to_frozenset(x: Union[Set, List, str]):
|
|||
if isinstance(x, str):
|
||||
return frozenset({x})
|
||||
return frozenset(x)
|
||||
|
||||
|
||||
def wrap_partial(f, *args, **kwargs):
|
||||
partial_f = partial(f, *args, **kwargs)
|
||||
update_wrapper(partial_f, f)
|
||||
return partial_f
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
@ -14,7 +16,6 @@ from dowhy.gcm.falsify import (
|
|||
validate_lmc,
|
||||
validate_pd,
|
||||
validate_tpa,
|
||||
wrap_partial,
|
||||
)
|
||||
from dowhy.gcm.independence_test.generalised_cov_measure import generalised_cov_based
|
||||
from dowhy.gcm.independence_test.kernel import kernel_based
|
||||
|
@ -75,21 +76,20 @@ def test_given_correct_collider_when_validating_graph_then_report_no_violations(
|
|||
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
|
||||
summary = validate_graph(
|
||||
true_dag,
|
||||
data,
|
||||
methods=(
|
||||
wrap_partial(
|
||||
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
|
||||
),
|
||||
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
|
||||
wrap_partial(validate_tpa, causal_graph_reference=true_dag),
|
||||
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
|
||||
partial(validate_pd, independence_test=_gcm_linear),
|
||||
partial(validate_tpa, causal_graph_reference=true_dag),
|
||||
),
|
||||
)
|
||||
|
||||
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 2
|
||||
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 2
|
||||
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 2
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 2
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 2
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 2
|
||||
|
||||
|
||||
@flaky(max_runs=5)
|
||||
|
@ -103,21 +103,20 @@ def test_given_wrong_collider_when_validating_graph_then_report_violations():
|
|||
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
|
||||
summary = validate_graph(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
wrap_partial(
|
||||
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
|
||||
),
|
||||
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
|
||||
wrap_partial(validate_tpa, causal_graph_reference=true_dag),
|
||||
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
|
||||
partial(validate_pd, independence_test=_gcm_linear),
|
||||
partial(validate_tpa, causal_graph_reference=true_dag),
|
||||
),
|
||||
)
|
||||
|
||||
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 2
|
||||
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 2
|
||||
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 1
|
||||
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 2
|
||||
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 2
|
||||
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 2
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 2
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 2
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 1
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 2
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 2
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 2
|
||||
|
||||
|
||||
@flaky(max_runs=5)
|
||||
|
@ -130,21 +129,20 @@ def test_given_correct_chain_when_validating_graph_then_report_no_violations():
|
|||
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
|
||||
summary = validate_graph(
|
||||
true_dag,
|
||||
data,
|
||||
methods=(
|
||||
wrap_partial(
|
||||
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
|
||||
),
|
||||
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
|
||||
wrap_partial(validate_tpa, causal_graph_reference=true_dag),
|
||||
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
|
||||
partial(validate_pd, independence_test=_gcm_linear),
|
||||
partial(validate_tpa, causal_graph_reference=true_dag),
|
||||
),
|
||||
)
|
||||
|
||||
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 1
|
||||
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 3
|
||||
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 1
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 1
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 3
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 1
|
||||
|
||||
|
||||
@flaky(max_runs=5)
|
||||
|
@ -158,21 +156,20 @@ def test_given_wrong_chain_when_validating_graph_then_report_violations():
|
|||
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
|
||||
summary = validate_graph(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
wrap_partial(
|
||||
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
|
||||
),
|
||||
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
|
||||
wrap_partial(validate_tpa, causal_graph_reference=true_dag),
|
||||
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
|
||||
partial(validate_pd, independence_test=_gcm_linear),
|
||||
partial(validate_tpa, causal_graph_reference=true_dag),
|
||||
),
|
||||
)
|
||||
|
||||
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 1
|
||||
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 1
|
||||
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 3
|
||||
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 1
|
||||
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 1
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 1
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 1
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 3
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 1
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 1
|
||||
|
||||
|
||||
@flaky(max_runs=5)
|
||||
|
@ -182,20 +179,19 @@ def test_given_empty_DAG_and_data_when_validating_graph_then_report_no_violation
|
|||
data = pd.DataFrame()
|
||||
summary = validate_graph(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
wrap_partial(
|
||||
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
|
||||
),
|
||||
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
|
||||
wrap_partial(validate_tpa, causal_graph_reference=dag),
|
||||
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
|
||||
partial(validate_pd, independence_test=_gcm_linear),
|
||||
partial(validate_tpa, causal_graph_reference=dag),
|
||||
),
|
||||
)
|
||||
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 0
|
||||
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 0
|
||||
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 0
|
||||
|
||||
|
||||
@flaky(max_runs=5)
|
||||
|
@ -209,21 +205,20 @@ def test_given_correct_full_DAG_when_validating_graph_then_report_no_violations(
|
|||
data = pd.DataFrame(data=dict(X0=X0, X1=X1, X2=X2, X3=X3))
|
||||
summary = validate_graph(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
wrap_partial(
|
||||
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
|
||||
),
|
||||
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
|
||||
wrap_partial(validate_tpa, causal_graph_reference=dag),
|
||||
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
|
||||
partial(validate_pd, independence_test=_gcm_linear),
|
||||
partial(validate_tpa, causal_graph_reference=dag),
|
||||
),
|
||||
)
|
||||
|
||||
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 0
|
||||
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 6
|
||||
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 6
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 0
|
||||
|
||||
|
||||
@flaky(max_runs=5)
|
||||
|
@ -234,21 +229,20 @@ def test_given_correct_single_node_when_validating_graph_then_report_no_violatio
|
|||
data = pd.DataFrame(data=dict(X=np.random.normal(size=500)))
|
||||
summary = validate_graph(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
wrap_partial(
|
||||
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
|
||||
),
|
||||
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
|
||||
wrap_partial(validate_tpa, causal_graph_reference=dag),
|
||||
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
|
||||
partial(validate_pd, independence_test=_gcm_linear),
|
||||
partial(validate_tpa, causal_graph_reference=dag),
|
||||
),
|
||||
)
|
||||
|
||||
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 0
|
||||
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 0
|
||||
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 0
|
||||
|
||||
|
||||
@flaky(max_runs=5)
|
||||
|
@ -260,21 +254,20 @@ def test_given_correct_single_edge_when_validating_graph_then_report_no_violatio
|
|||
data = pd.DataFrame(data=dict(X0=X0, X1=X1))
|
||||
summary = validate_graph(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
wrap_partial(
|
||||
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
|
||||
),
|
||||
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
|
||||
wrap_partial(validate_tpa, causal_graph_reference=dag),
|
||||
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
|
||||
partial(validate_pd, independence_test=_gcm_linear),
|
||||
partial(validate_tpa, causal_graph_reference=dag),
|
||||
),
|
||||
)
|
||||
|
||||
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 0
|
||||
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 1
|
||||
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 1
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 0
|
||||
|
||||
|
||||
@flaky(max_runs=5)
|
||||
|
@ -289,21 +282,20 @@ def test_given_wrong_single_edge_when_validating_graph_then_report_violations():
|
|||
data = pd.DataFrame(data=dict(X0=X0, X1=X1))
|
||||
summary = validate_graph(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
wrap_partial(
|
||||
validate_lmc, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
|
||||
),
|
||||
wrap_partial(validate_pd, data=data, independence_test=_gcm_linear),
|
||||
wrap_partial(validate_tpa, causal_graph_reference=true_dag),
|
||||
partial(validate_lmc, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
|
||||
partial(validate_pd, independence_test=_gcm_linear),
|
||||
partial(validate_tpa, causal_graph_reference=true_dag),
|
||||
),
|
||||
)
|
||||
|
||||
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 2
|
||||
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 2
|
||||
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 0
|
||||
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 2
|
||||
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 2
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 2
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 2
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 2
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 2
|
||||
|
||||
|
||||
@flaky(max_runs=5)
|
||||
|
@ -314,21 +306,20 @@ def test_given_correct_categorical_when_validating_graph_then_report_no_violatio
|
|||
data = pd.DataFrame(data=dict(X=X, Y=Y, Z=Z))
|
||||
summary = validate_graph(
|
||||
dag,
|
||||
data,
|
||||
methods=(
|
||||
wrap_partial(
|
||||
validate_lmc, data=data, independence_test=kernel_based, conditional_independence_test=kernel_based
|
||||
),
|
||||
wrap_partial(validate_pd, data=data, independence_test=kernel_based),
|
||||
wrap_partial(validate_tpa, causal_graph_reference=dag),
|
||||
partial(validate_lmc, independence_test=kernel_based, conditional_independence_test=kernel_based),
|
||||
partial(validate_pd, independence_test=kernel_based),
|
||||
partial(validate_tpa, causal_graph_reference=dag),
|
||||
),
|
||||
)
|
||||
|
||||
assert summary["validate_lmc"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_lmc"][FalsifyConst.N_TESTS] == 1
|
||||
assert summary["validate_pd"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_pd"][FalsifyConst.N_TESTS] == 3
|
||||
assert summary["validate_tpa"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_tpa"][FalsifyConst.N_TESTS] == 1
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_LMC][FalsifyConst.N_TESTS] == 1
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_PD][FalsifyConst.N_TESTS] == 3
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_TPA][FalsifyConst.N_TESTS] == 1
|
||||
|
||||
|
||||
@flaky(max_runs=5)
|
||||
|
@ -345,13 +336,12 @@ def test_given_non_minimal_DAG_when_validating_causal_minimality_then_report_vio
|
|||
data = pd.DataFrame(data=dict(X0=X0, X1=X1, X2=X2, Y=Y))
|
||||
summary = validate_graph(
|
||||
given_dag,
|
||||
methods=wrap_partial(
|
||||
validate_cm, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
|
||||
),
|
||||
data,
|
||||
methods=partial(validate_cm, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
|
||||
)
|
||||
|
||||
assert summary["validate_cm"][FalsifyConst.N_VIOLATIONS] == 1
|
||||
assert summary["validate_cm"][FalsifyConst.N_TESTS] == 3
|
||||
assert summary[FalsifyConst.VALIDATE_CM][FalsifyConst.N_VIOLATIONS] == 1
|
||||
assert summary[FalsifyConst.VALIDATE_CM][FalsifyConst.N_TESTS] == 3
|
||||
|
||||
|
||||
@flaky(max_runs=5)
|
||||
|
@ -364,10 +354,9 @@ def test_given_minimal_DAG_when_validating_causal_minimality_then_report_no_viol
|
|||
data = pd.DataFrame(data=dict(X0=X0, X1=X1, X2=X2, Y=Y))
|
||||
summary = validate_graph(
|
||||
dag,
|
||||
methods=wrap_partial(
|
||||
validate_cm, data=data, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear
|
||||
),
|
||||
data,
|
||||
methods=partial(validate_cm, independence_test=_gcm_linear, conditional_independence_test=_gcm_linear),
|
||||
)
|
||||
|
||||
assert summary["validate_cm"][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary["validate_cm"][FalsifyConst.N_TESTS] == 2
|
||||
assert summary[FalsifyConst.VALIDATE_CM][FalsifyConst.N_VIOLATIONS] == 0
|
||||
assert summary[FalsifyConst.VALIDATE_CM][FalsifyConst.N_TESTS] == 2
|
||||
|
|
Загрузка…
Ссылка в новой задаче