Backdoor identification tests (#247)
* add first tests for backdoor identification * added new backdoor tests and refactored test code * fix: remove exclusion of instrumental variables from backdoor set that creates identification errors in cases like graphs with m-bias * fix: added `include_unobserved` parameter to `identify_backdoor` to be explicit about unobserved variables. * feat: added minimum-sufficient and maximum-possible methods to identify-backdoor fix: default backdoor set now contains minimum possible number of instrumental variables; mod: temporarily made include_unobserved True by default, as most tests require it. * mod: changed backdoor method names: `minimum-sufficient` to `minimal-adjustment` and `maximum-possible` to `maximal-adjustment` * mod: adding tests for minimal and maximal adjustment, whether adjustment is necessary and small refactoring * fix: fix a bug with some maximal adjustment sets not existing because they contained unobserved variables. This also brings some performance improvements when there are many unobserved variables.
This commit is contained in:
Родитель
5ff3ff23be
Коммит
a8ba8932e5
|
@ -1,8 +1,10 @@
|
|||
import itertools
|
||||
import logging
|
||||
import re
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from dowhy.utils.api import parse_state
|
||||
import itertools
|
||||
|
||||
|
||||
class CausalGraph:
|
||||
|
@ -344,7 +346,11 @@ class CausalGraph:
|
|||
return True
|
||||
|
||||
def get_all_nodes(self, include_unobserved=True):
|
||||
return self._graph.nodes
|
||||
nodes = self._graph.nodes
|
||||
if not include_unobserved:
|
||||
nodes = set(self.filter_unobserved_variables(nodes))
|
||||
|
||||
return nodes
|
||||
|
||||
def filter_unobserved_variables(self, node_names):
|
||||
observed_node_names = list()
|
||||
|
|
|
@ -23,7 +23,14 @@ class CausalIdentifier:
|
|||
NONPARAMETRIC_NDE="nonparametric-nde"
|
||||
NONPARAMETRIC_NIE="nonparametric-nie"
|
||||
MAX_BACKDOOR_ITERATIONS = 100000
|
||||
VALID_METHOD_NAMES = {"default", "exhaustive-search"}
|
||||
|
||||
# Backdoor method names
|
||||
BACKDOOR_DEFAULT="default"
|
||||
BACKDOOR_EXHAUSTIVE="exhaustive-search"
|
||||
BACKDOOR_MIN="minimal-adjustment"
|
||||
BACKDOOR_MAX="maximal-adjustment"
|
||||
METHOD_NAMES = {BACKDOOR_DEFAULT, BACKDOOR_EXHAUSTIVE, BACKDOOR_MIN, BACKDOOR_MAX}
|
||||
DEFAULT_BACKDOOR_METHOD = BACKDOOR_MAX
|
||||
|
||||
def __init__(self, graph, estimand_type,
|
||||
method_name = "default",
|
||||
|
@ -233,9 +240,10 @@ class CausalIdentifier:
|
|||
|
||||
|
||||
|
||||
def identify_backdoor(self, treatment_name, outcome_name):
|
||||
def identify_backdoor(self, treatment_name, outcome_name, include_unobserved=True):
|
||||
backdoor_sets = []
|
||||
backdoor_paths = self._graph.get_backdoor_paths(treatment_name, outcome_name)
|
||||
method_name = self.method_name if self.method_name != CausalIdentifier.BACKDOOR_DEFAULT else CausalIdentifier.DEFAULT_BACKDOOR_METHOD
|
||||
# First, checking if empty set is a valid backdoor set
|
||||
empty_set = set()
|
||||
check = self._graph.check_valid_backdoor_set(treatment_name, outcome_name, empty_set,
|
||||
|
@ -244,17 +252,22 @@ class CausalIdentifier:
|
|||
backdoor_sets.append({
|
||||
'backdoor_set':empty_set,
|
||||
'num_paths_blocked_by_observed_nodes': check["num_paths_blocked_by_observed_nodes"]})
|
||||
# Second, checking for all other sets of variables
|
||||
eligible_variables = self._graph.get_all_nodes() \
|
||||
- set(treatment_name) \
|
||||
- set(outcome_name) \
|
||||
- set(self._graph.get_instruments(treatment_name, outcome_name))
|
||||
eligible_variables -= self._graph.get_descendants(treatment_name)
|
||||
# If the method is `minimal-adjustment`, return the empty set right away.
|
||||
if method_name == CausalIdentifier.BACKDOOR_MIN:
|
||||
return backdoor_sets
|
||||
|
||||
# Second, checking for all other sets of variables. If include_unobserved is false, then only observed variables are eligible.
|
||||
eligible_variables = self._graph.get_all_nodes(include_unobserved=include_unobserved) \
|
||||
- set(treatment_name) \
|
||||
- set(outcome_name)
|
||||
eligible_variables -= self._graph.get_descendants(treatment_name)
|
||||
|
||||
num_iterations = 0
|
||||
found_valid_adjustment_set = False
|
||||
if self.method_name in CausalIdentifier.VALID_METHOD_NAMES:
|
||||
for size_candidate_set in range(len(eligible_variables), 0, -1):
|
||||
if method_name in CausalIdentifier.METHOD_NAMES:
|
||||
# If `minimal-adjustment` method is specified, start the search from the set with minimum size. Otherwise, start from the largest.
|
||||
set_sizes = range(1, len(eligible_variables) + 1, 1) if method_name == CausalIdentifier.BACKDOOR_MIN else range(len(eligible_variables), 0, -1)
|
||||
for size_candidate_set in set_sizes:
|
||||
for candidate_set in itertools.combinations(eligible_variables, size_candidate_set):
|
||||
check = self._graph.check_valid_backdoor_set(treatment_name,
|
||||
outcome_name, candidate_set, backdoor_paths=backdoor_paths)
|
||||
|
@ -263,33 +276,34 @@ class CausalIdentifier:
|
|||
backdoor_sets.append({
|
||||
'backdoor_set': candidate_set,
|
||||
'num_paths_blocked_by_observed_nodes': check["num_paths_blocked_by_observed_nodes"]})
|
||||
if self._graph.all_observed(candidate_set):
|
||||
found_valid_adjustment_set = True
|
||||
found_valid_adjustment_set = True
|
||||
num_iterations += 1
|
||||
if self.method_name == "default" and num_iterations > CausalIdentifier.MAX_BACKDOOR_ITERATIONS:
|
||||
if method_name == CausalIdentifier.BACKDOOR_EXHAUSTIVE and num_iterations > CausalIdentifier.MAX_BACKDOOR_ITERATIONS:
|
||||
break
|
||||
if self.method_name == "default" and found_valid_adjustment_set:
|
||||
# If the backdoor method is `maximal-adjustment` or `minimal-adjustment`, return the first found adjustment set.
|
||||
if method_name in {CausalIdentifier.BACKDOOR_MAX, CausalIdentifier.BACKDOOR_MIN} and found_valid_adjustment_set:
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Identifier method {self.method_name} not supported. Try one of the following: {CausalIdentifier.VALID_METHOD_NAMES}")
|
||||
#causes_t = self._graph.get_causes(self.treatment_name)
|
||||
#causes_y = self._graph.get_causes(self.outcome_name, remove_edges={'sources':self.treatment_name, 'targets':self.outcome_name})
|
||||
#common_causes = list(causes_t.intersection(causes_y))
|
||||
#self.logger.info("Common causes of treatment and outcome:" + str(common_causes))
|
||||
observed_backdoor_sets = [ bset for bset in backdoor_sets if self._graph.all_observed(bset["backdoor_set"])]
|
||||
if len(observed_backdoor_sets)==0:
|
||||
return backdoor_sets
|
||||
else:
|
||||
return observed_backdoor_sets
|
||||
raise ValueError(f"Identifier method {method_name} not supported. Try one of the following: {CausalIdentifier.METHOD_NAMES}")
|
||||
|
||||
return backdoor_sets
|
||||
|
||||
def get_default_backdoor_set_id(self, backdoor_sets_dict):
|
||||
# Adding a None estimand if no backdoor set found
|
||||
if len(backdoor_sets_dict) == 0:
|
||||
return None
|
||||
|
||||
# Default set contains minimum possible number of instrumental variables, to prevent lowering variance in the treatment variable.
|
||||
instrument_names = set(self._graph.get_instruments(self.treatment_name, self.outcome_name))
|
||||
iv_count_dict = {key: len(set(bdoor_set).intersection(instrument_names)) for key, bdoor_set in backdoor_sets_dict.items()}
|
||||
min_iv_count = min(iv_count_dict.values())
|
||||
min_iv_keys = {key for key, iv_count in iv_count_dict.items() if iv_count == min_iv_count}
|
||||
min_iv_backdoor_sets_dict = {key: backdoor_sets_dict[key] for key in min_iv_keys}
|
||||
|
||||
# Default set is the one with the most number of adjustment variables (optimizing for minimum (unknown) bias not for efficiency)
|
||||
max_set_length = -1
|
||||
default_key = None
|
||||
# Default set is the one with the most number of adjustment variables (optimizing for minimum (unknown) bias not for efficiency)
|
||||
for key, bdoor_set in backdoor_sets_dict.items():
|
||||
for key, bdoor_set in min_iv_backdoor_sets_dict.items():
|
||||
if len(bdoor_set) > max_set_length:
|
||||
max_set_length = len(bdoor_set)
|
||||
default_key = key
|
||||
|
@ -413,19 +427,6 @@ class CausalIdentifier:
|
|||
backdoor_variables_dict["backdoor"] = backdoor_variables_dict.get(str(default_backdoor_id), None)
|
||||
return backdoor_variables_dict
|
||||
|
||||
def get_default_backdoor_set_id(self, backdoor_sets_dict):
|
||||
# Adding a None estimand if no backdoor set found
|
||||
if len(backdoor_sets_dict) == 0:
|
||||
return None
|
||||
max_set_length = -1
|
||||
default_key = None
|
||||
# Default set is the one with the most number of adjustment variables (optimizing for minimum (unknown) bias not for efficiency)
|
||||
for key, bdoor_set in backdoor_sets_dict.items():
|
||||
if len(bdoor_set) > max_set_length:
|
||||
max_set_length = len(bdoor_set)
|
||||
default_key = key
|
||||
return default_key
|
||||
|
||||
def construct_backdoor_estimand(self, estimand_type, treatment_name,
|
||||
outcome_name, common_causes):
|
||||
# TODO: outputs string for now, but ideally should do symbolic
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
import pytest
|
||||
from dowhy.causal_graph import CausalGraph
|
||||
|
||||
from .example_graphs import TEST_GRAPH_SOLUTIONS
|
||||
|
||||
|
||||
class IdentificationTestGraphSolution(object):
    """Pair an example causal graph with the identification results expected for it.

    The graph is always built with "X" as the treatment and "Y" as the outcome.

    :param graph_str: The graph string (GML format in the bundled examples).
    :param observed_variables: Names of the observed variables; forwarded to
        ``CausalGraph`` as ``observed_node_names``.
    :param biased_sets: Sets that must not be returned as backdoor sets.
    :param minimal_adjustment_sets: Sets expected from 'minimal-adjustment'.
    :param maximal_adjustment_sets: Sets expected from 'maximal-adjustment'.
    """

    def __init__(
        self,
        graph_str,
        observed_variables,
        biased_sets,
        minimal_adjustment_sets,
        maximal_adjustment_sets,
    ):
        # Expected results are stored verbatim for the tests to compare against.
        self.observed_variables = observed_variables
        self.biased_sets = biased_sets
        self.minimal_adjustment_sets = minimal_adjustment_sets
        self.maximal_adjustment_sets = maximal_adjustment_sets
        self.graph = CausalGraph(
            "X", "Y", graph_str, observed_node_names=observed_variables
        )
|
||||
|
||||
|
||||
@pytest.fixture(params=list(TEST_GRAPH_SOLUTIONS))
def example_graph_solution(request):
    """Build the ``IdentificationTestGraphSolution`` for one ``TEST_GRAPH_SOLUTIONS`` entry.

    The fixture is parametrized over the dictionary keys, so every test that
    requests it runs once per example graph.
    """
    solution_kwargs = TEST_GRAPH_SOLUTIONS[request.param]
    return IdentificationTestGraphSolution(**solution_kwargs)
|
|
@ -0,0 +1,278 @@
|
|||
"""
|
||||
The examples below illustrate some of the common causal graphs found in books.
|
||||
This file is meant to group all of the graph definitions as well as the expected results of identification algorithms in one place.
|
||||
Each example graph consists of the following values:
|
||||
|
||||
* graph_str - The graph string in GML format.
|
||||
* observed_variables - A list of observed variables in the graph. This is used to test that no unobserved variables are offered in the solution.
|
||||
* biased_sets - The sets that we shouldn't get in the output as they incur biased estimates of the causal effect.
|
||||
* minimal_adjustment_sets - Sets of observed variables that should be returned when 'minimal-adjustment' is specified as the backdoor method.
|
||||
If no adjustment is necessary given the graph, minimal adjustment set should be the empty set.
|
||||
* maximal_adjustment_sets - Sets of observed variables that should be returned when 'maximal-adjustment' is specified as the backdoor method.
|
||||
"""
|
||||
|
||||
TEST_GRAPH_SOLUTIONS = {
|
||||
# Example is selected from Pearl J. "Causality" 2nd Edition, from chapter 3.3.1 on backdoor criterion.
|
||||
"pearl_backdoor_example_graph": dict(
|
||||
graph_str = """graph[directed 1 node[id "Z1" label "Z1"]
|
||||
node[id "Z2" label "Z2"]
|
||||
node[id "Z3" label "Z3"]
|
||||
node[id "Z4" label "Z4"]
|
||||
node[id "Z5" label "Z5"]
|
||||
node[id "Z6" label "Z6"]
|
||||
node[id "X" label "X"]
|
||||
node[id "Y" label "Y"]
|
||||
edge[source "Z1" target "Z3"]
|
||||
edge[source "Z1" target "Z4"]
|
||||
edge[source "Z2" target "Z4"]
|
||||
edge[source "Z2" target "Z5"]
|
||||
edge[source "Z3" target "X"]
|
||||
edge[source "Z4" target "X"]
|
||||
edge[source "Z4" target "Y"]
|
||||
edge[source "Z5" target "Y"]
|
||||
edge[source "Z6" target "Y"]
|
||||
edge[source "X" target "Z6"]]
|
||||
""",
|
||||
observed_variables = ["Z1", "Z2", "Z3", "Z4", "Z5", "Z6", "X", "Y"],
|
||||
biased_sets = [{"Z4"}, {"Z6"}, {"Z5"}, {"Z2"}, {"Z1"}, {"Z3"}, {"Z1", "Z3"}, {"Z2", "Z5"}, {"Z1", "Z2"}],
|
||||
minimal_adjustment_sets = [{"Z1", "Z4"}, {"Z2", "Z4"}, {"Z3", "Z4"}, {"Z5", "Z4"}],
|
||||
maximal_adjustment_sets = [{"Z1", "Z2", "Z3", "Z4", "Z5"}]
|
||||
),
|
||||
"simple_selection_bias_graph": dict(
|
||||
graph_str = """graph[directed 1 node[id "Z1" label "Z1"]
|
||||
node[id "X" label "X"]
|
||||
node[id "Y" label "Y"]
|
||||
edge[source "X" target "Y"]
|
||||
edge[source "X" target "Z1"]
|
||||
edge[source "Y" target "Z1"]]
|
||||
""",
|
||||
observed_variables = ["Z1", "X", "Y"],
|
||||
biased_sets = [{"Z1",}],
|
||||
minimal_adjustment_sets = [{}],
|
||||
maximal_adjustment_sets = [{}]
|
||||
),
|
||||
"simple_no_confounder_graph": dict(
|
||||
graph_str = """graph[directed 1 node[id "Z1" label "Z1"]
|
||||
node[id "X" label "X"]
|
||||
node[id "Y" label "Y"]
|
||||
edge[source "X" target "Y"]
|
||||
edge[source "Z1" target "X"]]
|
||||
""",
|
||||
observed_variables=["Z1", "X", "Y"],
|
||||
biased_sets = [],
|
||||
minimal_adjustment_sets = [{}],
|
||||
maximal_adjustment_sets = [{"Z1",}]
|
||||
),
|
||||
# The following Simpson's paradox examples are taken from Pearl, J. (2013). "Understanding Simpson’s Paradox" - http://ftp.cs.ucla.edu/pub/stat_ser/r414.pdf
|
||||
"pearl_simpsons_paradox_1c": dict(
|
||||
graph_str = """graph[directed 1 node[id "Z" label "Z"]
|
||||
node[id "X" label "X"]
|
||||
node[id "Y" label "Y"]
|
||||
node[id "L1" label "L1"]
|
||||
node[id "L2" label "L2"]
|
||||
edge[source "X" target "Y"]
|
||||
edge[source "L1" target "X"]
|
||||
edge[source "L1" target "Z"]
|
||||
edge[source "L2" target "Z"]
|
||||
edge[source "L2" target "Y"]]
|
||||
""",
|
||||
observed_variables=["Z", "X", "Y"],
|
||||
biased_sets = [{"Z",}],
|
||||
minimal_adjustment_sets = [{}],
|
||||
maximal_adjustment_sets = [{}]
|
||||
),
|
||||
"pearl_simpsons_paradox_1d": dict(
|
||||
graph_str = """graph[directed 1 node[id "Z" label "Z"]
|
||||
node[id "X" label "X"]
|
||||
node[id "Y" label "Y"]
|
||||
node[id "L1" label "L1"]
|
||||
edge[source "X" target "Y"]
|
||||
edge[source "L1" target "X"]
|
||||
edge[source "L1" target "Z"]
|
||||
edge[source "Z" target "Y"]]
|
||||
""",
|
||||
observed_variables = ["Z", "X", "Y"],
|
||||
biased_sets = [],
|
||||
minimal_adjustment_sets = [{"Z",}],
|
||||
maximal_adjustment_sets = [{"Z",}]
|
||||
),
|
||||
"pearl_simpsons_paradox_2a": dict(
|
||||
graph_str = """graph[directed 1 node[id "Z" label "Z"]
|
||||
node[id "X" label "X"]
|
||||
node[id "Y" label "Y"]
|
||||
node[id "L" label "L"]
|
||||
edge[source "X" target "Y"]
|
||||
edge[source "X" target "Z"]
|
||||
edge[source "L" target "Z"]
|
||||
edge[source "L" target "Y"]]
|
||||
""",
|
||||
observed_variables = ["Z", "X", "Y"],
|
||||
biased_sets = [{"Z", }],
|
||||
minimal_adjustment_sets = [{}],
|
||||
maximal_adjustment_sets = [{}]
|
||||
),
|
||||
"pearl_simpsons_paradox_2b": dict(
|
||||
graph_str = """graph[directed 1 node[id "Z" label "Z"]
|
||||
node[id "X" label "X"]
|
||||
node[id "Y" label "Y"]
|
||||
node[id "L" label "L"]
|
||||
edge[source "X" target "Y"]
|
||||
edge[source "Z" target "X"]
|
||||
edge[source "L" target "X"]
|
||||
edge[source "L" target "Y"]]""",
|
||||
observed_variables = ["Z", "X", "Y"],
|
||||
biased_sets = [],
|
||||
minimal_adjustment_sets = [],
|
||||
maximal_adjustment_sets = [] # Should this be {"Z"}?
|
||||
),
|
||||
"pearl_simpsons_paradox_2b_L_observed": dict(
|
||||
graph_str = """graph[directed 1 node[id "Z" label "Z"]
|
||||
node[id "X" label "X"]
|
||||
node[id "Y" label "Y"]
|
||||
node[id "L" label "L"]
|
||||
edge[source "X" target "Y"]
|
||||
edge[source "Z" target "X"]
|
||||
edge[source "L" target "X"]
|
||||
edge[source "L" target "Y"]]""",
|
||||
observed_variables = ["Z", "X", "Y", "L"],
|
||||
biased_sets = [],
|
||||
minimal_adjustment_sets = [{"L"}],
|
||||
maximal_adjustment_sets = [{"L", "Z"}]
|
||||
),
|
||||
"pearl_simpsons_machine_lvl1": dict(
|
||||
graph_str = """graph[directed 1 node[id "Z1" label "Z1"]
|
||||
node[id "Z2" label "Z2"]
|
||||
node[id "Z3" label "Z3"]
|
||||
node[id "L" label "L"]
|
||||
node[id "X" label "X"]
|
||||
node[id "Y" label "Y"]
|
||||
edge[source "X" target "Y"]
|
||||
edge[source "Z1" target "L"]
|
||||
edge[source "L" target "Z2"]
|
||||
edge[source "Z3" target "Z2"]
|
||||
edge[source "L" target "X"]
|
||||
edge[source "Z3" target "Y"]]
|
||||
""",
|
||||
observed_variables=["Z1", "Z2", "Z3", "X", "Y"],
|
||||
biased_sets = [{"Z2",}, {"Z1", "Z2"}],
|
||||
minimal_adjustment_sets = [{}],
|
||||
maximal_adjustment_sets = [{"Z1", "Z2", "Z3"}]
|
||||
),
|
||||
# The following are examples given in the "Book of Why" by Judea Pearl, chapter "The Do-operator and the Back-Door Criterion"
|
||||
"book_of_why_game2": dict(
|
||||
graph_str = """graph[directed 1 node[id "A" label "A"]
|
||||
node[id "B" label "B"]
|
||||
node[id "C" label "C"]
|
||||
node[id "D" label "D"]
|
||||
node[id "E" label "E"]
|
||||
node[id "X" label "X"]
|
||||
node[id "Y" label "Y"]
|
||||
edge[source "A" target "X"]
|
||||
edge[source "A" target "B"]
|
||||
edge[source "B" target "C"]
|
||||
edge[source "D" target "B"]
|
||||
edge[source "D" target "E"]
|
||||
edge[source "X" target "E"]
|
||||
edge[source "E" target "Y"]]
|
||||
""",
|
||||
observed_variables = ["A", "B", "C", "D", "E", "X", "Y"],
|
||||
biased_sets = [{"B",}, {"C",}, {"B", "C"}],
|
||||
minimal_adjustment_sets = [{}],
|
||||
maximal_adjustment_sets = [{"A", "B", "C", "D"}]
|
||||
),
|
||||
"book_of_why_game5": dict(
|
||||
graph_str = """graph[directed 1 node[id "A" label "A"]
|
||||
node[id "B" label "B"]
|
||||
node[id "C" label "C"]
|
||||
node[id "X" label "X"]
|
||||
node[id "Y" label "Y"]
|
||||
edge[source "A" target "X"]
|
||||
edge[source "A" target "B"]
|
||||
edge[source "B" target "X"]
|
||||
edge[source "C" target "B"]
|
||||
edge[source "C" target "Y"]
|
||||
edge[source "X" target "Y"]]
|
||||
""",
|
||||
observed_variables = ["A", "B", "C", "X", "Y"],
|
||||
biased_sets = [{"B",}],
|
||||
minimal_adjustment_sets = [{"C"}],
|
||||
maximal_adjustment_sets = [{"A", "B", "C"}]
|
||||
),
|
||||
"book_of_why_game5_C_is_unobserved": dict(
|
||||
graph_str = """graph[directed 1 node[id "A" label "A"]
|
||||
node[id "B" label "B"]
|
||||
node[id "C" label "C"]
|
||||
node[id "X" label "X"]
|
||||
node[id "Y" label "Y"]
|
||||
edge[source "A" target "X"]
|
||||
edge[source "A" target "B"]
|
||||
edge[source "B" target "X"]
|
||||
edge[source "C" target "B"]
|
||||
edge[source "C" target "Y"]
|
||||
edge[source "X" target "Y"]]
|
||||
""",
|
||||
observed_variables = ["A", "B", "X", "Y"],
|
||||
biased_sets = [{"B",}],
|
||||
minimal_adjustment_sets = [{"A", "B"}],
|
||||
maximal_adjustment_sets = [{"A", "B"}]
|
||||
),
|
||||
"no_treatment_but_valid_maximal_set": dict(
|
||||
graph_str = """graph[directed 1 node[id "X" label "X"]
|
||||
node[id "Y" label "Y"]
|
||||
node[id "Z1" label "Z1"]
|
||||
node[id "Z2" label "Z2"]
|
||||
edge[source "X" target "Y"]
|
||||
edge[source "X" target "Z1"]
|
||||
edge[source "Z1" target "Y"]
|
||||
edge[source "Z2" target "Z1"]]
|
||||
""",
|
||||
observed_variables = ["Z1", "Z2", "X", "Y"],
|
||||
biased_sets = [],
|
||||
minimal_adjustment_sets = [{}],
|
||||
maximal_adjustment_sets = [{"Z2"}]
|
||||
),
|
||||
"common_cause_of_mediator1": dict(
|
||||
graph_str = """graph[directed 1 node[id "X" label "X"]
|
||||
node[id "Y" label "Y"]
|
||||
node[id "U" label "U"]
|
||||
node[id "Z" label "Z"]
|
||||
node[id "M" label "M"]
|
||||
edge[source "X" target "M"]
|
||||
edge[source "M" target "Y"]
|
||||
edge[source "U" target "X"]
|
||||
edge[source "U" target "Z"]
|
||||
edge[source "Z" target "M"]]
|
||||
""",
|
||||
observed_variables = ["X", "Y", "Z", "M"],
|
||||
biased_sets = [],
|
||||
minimal_adjustment_sets = [{"Z"}],
|
||||
maximal_adjustment_sets = [{"Z"}]
|
||||
),
|
||||
"common_cause_of_mediator2": dict(
|
||||
graph_str = """graph[directed 1 node[id "X" label "X"]
|
||||
node[id "Y" label "Y"]
|
||||
node[id "U" label "U"]
|
||||
node[id "Z" label "Z"]
|
||||
node[id "M" label "M"]
|
||||
edge[source "X" target "M"]
|
||||
edge[source "M" target "Y"]
|
||||
edge[source "U" target "Z"]
|
||||
edge[source "U" target "M"]
|
||||
edge[source "Z" target "X"]]
|
||||
""",
|
||||
observed_variables = ["X", "Y", "Z", "M"],
|
||||
biased_sets = [],
|
||||
minimal_adjustment_sets = [{"Z"}],
|
||||
maximal_adjustment_sets = [{"Z"}]
|
||||
)
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
import pytest
|
||||
from dowhy.causal_graph import CausalGraph
|
||||
from dowhy.causal_identifier import CausalIdentifier
|
||||
|
||||
from .base import IdentificationTestGraphSolution, example_graph_solution
|
||||
|
||||
|
||||
class TestBackdoorIdentification(object):
    """Backdoor identification checks, run once per example graph solution.

    Every test identifies backdoor sets for treatment "X" and outcome "Y" with
    ``include_unobserved=False`` and compares them with the expectations
    stored on the fixture.
    """

    @staticmethod
    def _as_sets(backdoor_results, drop_empty=False):
        """Turn each result's "backdoor_set" into a ``set``, optionally skipping empty ones."""
        collected = []
        for result in backdoor_results:
            members = result["backdoor_set"]
            if drop_empty and len(members) == 0:
                continue
            collected.append(set(members))
        return collected

    @staticmethod
    def _assert_contains_all(found_sets, expected_sets):
        """Pass when nothing was found and nothing expected, or every expected set was found."""
        if len(found_sets) == 0 and len(expected_sets) == 0:
            # No adjustments exist and that's expected.
            return
        missing = [expected for expected in expected_sets if set(expected) not in found_sets]
        assert not missing

    def test_identify_backdoor_no_biased_sets(self, example_graph_solution: IdentificationTestGraphSolution):
        identifier = CausalIdentifier(
            example_graph_solution.graph, "nonparametric-ate", method_name="exhaustive-search"
        )
        found_sets = self._as_sets(
            identifier.identify_backdoor("X", "Y", include_unobserved=False),
            drop_empty=True,
        )
        biased_sets = example_graph_solution.biased_sets

        # Either no non-empty set was returned and no biased set is listed ...
        nothing_found_nothing_biased = len(found_sets) == 0 and len(biased_sets) == 0
        # ... or none of the sets that would induce biased results is in the solution.
        assert nothing_found_nothing_biased or all(
            set(biased) not in found_sets for biased in biased_sets
        )

    def test_identify_backdoor_unobserved_not_in_backdoor_set(self, example_graph_solution: IdentificationTestGraphSolution):
        identifier = CausalIdentifier(
            example_graph_solution.graph, "nonparametric-ate", method_name="exhaustive-search"
        )
        found_sets = self._as_sets(
            identifier.identify_backdoor("X", "Y", include_unobserved=False),
            drop_empty=True,
        )
        observed_variables = example_graph_solution.observed_variables

        # All variables used in the backdoor sets must be observed.
        assert all(found.issubset(observed_variables) for found in found_sets)

    def test_identify_backdoor_minimal_adjustment(self, example_graph_solution: IdentificationTestGraphSolution):
        identifier = CausalIdentifier(
            example_graph_solution.graph,
            "nonparametric-ate",
            method_name="minimal-adjustment",
            proceed_when_unidentifiable=False,
        )
        found_sets = self._as_sets(
            identifier.identify_backdoor("X", "Y", include_unobserved=False)
        )

        self._assert_contains_all(found_sets, example_graph_solution.minimal_adjustment_sets)

    def test_identify_backdoor_maximal_adjustment(self, example_graph_solution: IdentificationTestGraphSolution):
        identifier = CausalIdentifier(
            example_graph_solution.graph,
            "nonparametric-ate",
            method_name="maximal-adjustment",
            proceed_when_unidentifiable=False,
        )
        found_sets = self._as_sets(
            identifier.identify_backdoor("X", "Y", include_unobserved=False)
        )

        self._assert_contains_all(found_sets, example_graph_solution.maximal_adjustment_sets)
|
|
@ -1,6 +1,8 @@
|
|||
import logging
|
||||
|
||||
import dowhy.datasets
|
||||
from dowhy import CausalModel
|
||||
import logging
|
||||
|
||||
|
||||
class TestRefuter(object):
|
||||
def __init__(self, error_tolerance, estimator_method, refuter_method,
|
||||
|
@ -42,7 +44,7 @@ class TestRefuter(object):
|
|||
proceed_when_unidentifiable=True,
|
||||
test_significance=None
|
||||
)
|
||||
target_estimand = model.identify_effect()
|
||||
target_estimand = model.identify_effect(method_name="exhaustive-search")
|
||||
target_estimand.set_identifier_method(self.identifier_method)
|
||||
ate_estimate = model.estimate_effect(
|
||||
identified_estimand=target_estimand,
|
||||
|
|
Загрузка…
Ссылка в новой задаче