ID Algorithm for Causal Identification (#280)
* Added code for ID Algorithm * Replaced set and list by OrderedSet * Need to prepare recursive datastructure for results * Added function to utils/graph_operations.py * Added print in a readable fashion * Updated ID notebook * IDIdentifier class inherited from CausalIdentifier
This commit is contained in:
Родитель
c2c3da5fef
Коммит
56c0b34729
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -48,7 +48,7 @@ class CausalIdentifier:
|
|||
|
||||
If estimand_type is non-parametric ATE, then uses backdoor, instrumental variable and frontdoor identification methods, to check if an identified estimand exists, based on the causal graph.
|
||||
|
||||
:param self: instance of the CausalEstimator class (or its subclass)
|
||||
:param self: instance of the CausalIdentifier class (or its subclass)
|
||||
:returns: target estimand, an instance of the IdentifiedEstimand class
|
||||
"""
|
||||
if self.estimand_type == CausalIdentifier.NONPARAMETRIC_ATE:
|
||||
|
@ -79,7 +79,7 @@ class CausalIdentifier:
|
|||
default_backdoor_id = self.get_default_backdoor_set_id(backdoor_variables_dict)
|
||||
estimands_dict["backdoor"] = estimands_dict.get(str(default_backdoor_id), None)
|
||||
backdoor_variables_dict["backdoor"] = backdoor_variables_dict.get(str(default_backdoor_id), None)
|
||||
|
||||
|
||||
### 2. INSTRUMENTAL VARIABLE IDENTIFICATION
|
||||
# Now checking if there is also a valid iv estimand
|
||||
instrument_names = self._graph.get_instruments(self.treatment_name,
|
||||
|
@ -97,7 +97,7 @@ class CausalIdentifier:
|
|||
estimands_dict["iv"] = iv_estimand_expr
|
||||
else:
|
||||
estimands_dict["iv"] = None
|
||||
|
||||
|
||||
### 3. FRONTDOOR IDENTIFICATION
|
||||
# Now checking if there is a valid frontdoor variable
|
||||
frontdoor_variables_names = self.identify_frontdoor()
|
||||
|
@ -116,7 +116,7 @@ class CausalIdentifier:
|
|||
mediation_second_stage_confounders = self.identify_mediation_second_stage_confounders(frontdoor_variables_names, self.outcome_name)
|
||||
else:
|
||||
estimands_dict["frontdoor"] = None
|
||||
|
||||
|
||||
# Finally returning the estimand object
|
||||
estimand = IdentifiedEstimand(
|
||||
self,
|
||||
|
@ -237,13 +237,11 @@ class CausalIdentifier:
|
|||
)
|
||||
return estimand
|
||||
|
||||
|
||||
|
||||
|
||||
def identify_backdoor(self, treatment_name, outcome_name, include_unobserved=True):
|
||||
backdoor_sets = []
|
||||
backdoor_paths = self._graph.get_backdoor_paths(treatment_name, outcome_name)
|
||||
method_name = self.method_name if self.method_name != CausalIdentifier.BACKDOOR_DEFAULT else CausalIdentifier.DEFAULT_BACKDOOR_METHOD
|
||||
|
||||
# First, checking if empty set is a valid backdoor set
|
||||
empty_set = set()
|
||||
check = self._graph.check_valid_backdoor_set(treatment_name, outcome_name, empty_set,
|
||||
|
|
|
@ -0,0 +1,235 @@
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
import networkx as nx
|
||||
from dowhy.utils.ordered_set import OrderedSet
|
||||
from dowhy.utils.graph_operations import find_c_components, induced_graph, find_ancestor
|
||||
from dowhy.causal_identifier import CausalIdentifier
|
||||
from dowhy.utils.api import parse_state
|
||||
|
||||
class IDExpression:
|
||||
"""
|
||||
Class for storing a causal estimand, as a result of the identification step using the ID algorithm.
|
||||
The object stores a list of estimators(self._product) whose porduct must be obtained and a list of variables (self._sum) over which the product must be marginalized.
|
||||
"""
|
||||
def __init__(self):
|
||||
self._product = []
|
||||
self._sum = []
|
||||
|
||||
def add_product(self, element):
|
||||
'''
|
||||
Add an estimator to the list of product.
|
||||
|
||||
:param element: Estimator to append to the product list.
|
||||
'''
|
||||
self._product.append(element)
|
||||
|
||||
def add_sum(self, element):
|
||||
'''
|
||||
Add variables to the list.
|
||||
|
||||
:param element: Set of variables to append to the list self._sum.
|
||||
'''
|
||||
for el in element:
|
||||
self._sum.append(el)
|
||||
|
||||
def get_val(self, return_type):
|
||||
"""
|
||||
Get either the list of estimators (for product) or list of variables (for the marginalization).
|
||||
|
||||
:param return_type: "prod" to return the list of estimators or "sum" to return the list of variables.
|
||||
"""
|
||||
if return_type=="prod":
|
||||
return self._product
|
||||
elif return_type=="sum":
|
||||
return self._sum
|
||||
else:
|
||||
raise Exception("Provide correct return type.")
|
||||
|
||||
def _print_estimator(self, prefix, estimator=None, start=False):
|
||||
'''
|
||||
Print the IDExpression object.
|
||||
'''
|
||||
if estimator is None:
|
||||
return None
|
||||
|
||||
string = ""
|
||||
if isinstance(estimator, IDExpression):
|
||||
s = True if len(estimator.get_val(return_type="sum"))>0 else False
|
||||
if s:
|
||||
sum_vars = "{" + ",".join(estimator.get_val(return_type="sum")) + "}"
|
||||
string += prefix + "Sum over " + sum_vars + ":\n"
|
||||
prefix += "\t"
|
||||
for expression in estimator.get_val(return_type='prod'):
|
||||
add_string = self._print_estimator(prefix, expression)
|
||||
if add_string is None:
|
||||
return None
|
||||
else:
|
||||
string += add_string
|
||||
else:
|
||||
outcome_vars = list(estimator['outcome_vars'])
|
||||
condition_vars = list(estimator['condition_vars'])
|
||||
string += prefix + "Predictor: P(" + ",".join(outcome_vars)
|
||||
if len(condition_vars)>0:
|
||||
string += "|" + ",".join(condition_vars)
|
||||
string += ")\n"
|
||||
if start:
|
||||
string = string[:-1]
|
||||
return string
|
||||
|
||||
def __str__(self):
|
||||
string = self._print_estimator(prefix="", estimator=self, start=True)
|
||||
if string is None:
|
||||
return "The graph is not identifiable."
|
||||
else:
|
||||
return string
|
||||
|
||||
class IDIdentifier(CausalIdentifier):
|
||||
|
||||
def __init__(self, graph, estimand_type,
|
||||
method_name = "default",
|
||||
proceed_when_unidentifiable=None):
|
||||
'''
|
||||
Class to perform identification using the ID algorithm.
|
||||
|
||||
:param self: instance of the IDIdentifier class.
|
||||
:param estimand_type: Type of estimand ("nonparametric-ate", "nonparametric-nde" or "nonparametric-nie").
|
||||
:param method_name: Identification method ("id-algorithm" in this case).
|
||||
:param proceed_when_unidentifiable: If True, proceed with identification even in the presence of unobserved/missing variables.
|
||||
'''
|
||||
|
||||
super().__init__(graph, estimand_type, method_name, proceed_when_unidentifiable)
|
||||
|
||||
if self.estimand_type != CausalIdentifier.NONPARAMETRIC_ATE:
|
||||
raise Exception("The estimand type should be 'non-parametric ate' for the ID method type.")
|
||||
|
||||
self._treatment_names = OrderedSet(parse_state(graph.treatment_name))
|
||||
self._outcome_names = OrderedSet(parse_state(graph.outcome_name))
|
||||
self._adjacency_matrix = graph.get_adjacency_matrix()
|
||||
|
||||
try:
|
||||
self._tsort_node_names = OrderedSet(list(nx.topological_sort(graph._graph))) # topological sorting of graph nodes
|
||||
except:
|
||||
raise Exception("The graph must be a directed acyclic graph (DAG).")
|
||||
self._node_names = OrderedSet(graph._graph.nodes)
|
||||
|
||||
def identify_effect(self, treatment_names=None, outcome_names=None, adjacency_matrix=None, node_names=None):
|
||||
'''
|
||||
Implementation of the ID algorithm.
|
||||
Link - https://ftp.cs.ucla.edu/pub/stat_ser/shpitser-thesis.pdf
|
||||
The pseudo code has been provided on Pg 40.
|
||||
|
||||
:param self: instance of the IDIdentifier class.
|
||||
:param treatment_names: OrderedSet comprising names of treatment variables.
|
||||
:param outcome_names:OrderedSet comprising names of outcome variables.
|
||||
:param adjacency_matrix: Graph adjacency matrix.
|
||||
:param node_names: OrderedSet comprising names of all nodes in the graph.
|
||||
:returns: target estimand, an instance of the IDExpression class.
|
||||
'''
|
||||
if adjacency_matrix is None:
|
||||
adjacency_matrix = self._adjacency_matrix
|
||||
if treatment_names is None:
|
||||
treatment_names = self._treatment_names
|
||||
if outcome_names is None:
|
||||
outcome_names = self._outcome_names
|
||||
if node_names is None:
|
||||
node_names = self._node_names
|
||||
node2idx, idx2node = self._idx_node_mapping(node_names)
|
||||
|
||||
# Estimators list for returning after identification
|
||||
estimators = IDExpression()
|
||||
|
||||
# Line 1
|
||||
# If no action has been taken, the effect on Y is just the marginal of the observational distribution P(v) on Y.
|
||||
if len(treatment_names) == 0:
|
||||
identifier = IDExpression()
|
||||
estimator = {}
|
||||
estimator['outcome_vars'] = node_names
|
||||
estimator['condition_vars'] = OrderedSet()
|
||||
identifier.add_product(estimator)
|
||||
identifier.add_sum(node_names.difference(outcome_names))
|
||||
estimators.add_product(identifier)
|
||||
return estimators
|
||||
|
||||
# Line 2
|
||||
# If we are interested in the effect on Y, it is sufficient to restrict our attention on the parts of the model ancestral to Y.
|
||||
ancestors = find_ancestor(outcome_names, node_names, adjacency_matrix, node2idx, idx2node)
|
||||
if len(node_names.difference(ancestors)) != 0: # If there are elements which are not the ancestor of the outcome variables
|
||||
# Modify list of valid nodes
|
||||
treatment_names = treatment_names.intersection(ancestors)
|
||||
node_names = node_names.intersection(ancestors)
|
||||
adjacency_matrix = induced_graph(node_set=node_names, adjacency_matrix=adjacency_matrix, node2idx=node2idx)
|
||||
return self.identify_effect(treatment_names=treatment_names, outcome_names=outcome_names, adjacency_matrix=adjacency_matrix, node_names=node_names)
|
||||
|
||||
# Line 3 - forces an action on any node where such an action would have no effect on Y – assuming we already acted on X.
|
||||
# Modify adjacency matrix to obtain that corresponding to do(X)
|
||||
adjacency_matrix_do_x = adjacency_matrix.copy()
|
||||
for x in treatment_names:
|
||||
x_idx = node2idx[x]
|
||||
for i in range(len(node_names)):
|
||||
adjacency_matrix_do_x[i, x_idx] = 0
|
||||
ancestors = find_ancestor(outcome_names, node_names, adjacency_matrix_do_x, node2idx, idx2node)
|
||||
W = node_names.difference(treatment_names).difference(ancestors)
|
||||
if len(W) != 0:
|
||||
return self.identify_effect(treatment_names = treatment_names.union(W), outcome_names=outcome_names, adjacency_matrix=adjacency_matrix, node_names=node_names)
|
||||
|
||||
# Line 4 - Decomposes the problem into a set of smaller problems using the key property of C-component factorization of causal models.
|
||||
# If the entire graph is a single C-component already, further problem decomposition is impossible, and we must provide base cases.
|
||||
# Modify adjacency matrix to remove treatment variables
|
||||
node_names_minus_x = node_names.difference(treatment_names)
|
||||
node2idx_minus_x, idx2node_minus_x = self._idx_node_mapping(node_names_minus_x)
|
||||
adjacency_matrix_minus_x = induced_graph(node_set=node_names_minus_x, adjacency_matrix=adjacency_matrix, node2idx=node2idx)
|
||||
c_components = find_c_components(adjacency_matrix=adjacency_matrix_minus_x, node_set=node_names_minus_x, idx2node=idx2node_minus_x)
|
||||
if len(c_components)>1:
|
||||
identifier = IDExpression()
|
||||
sum_over_set = node_names.difference(outcome_names.union(treatment_names))
|
||||
for component in c_components:
|
||||
expressions = self.identify_effect(treatment_names=node_names.difference(component), outcome_names=OrderedSet(list(component)), adjacency_matrix=adjacency_matrix, node_names=node_names)
|
||||
for expression in expressions.get_val(return_type="prod"):
|
||||
identifier.add_product(expression)
|
||||
identifier.add_sum(sum_over_set)
|
||||
estimators.add_product(identifier)
|
||||
return estimators
|
||||
|
||||
|
||||
# Line 5 - The algorithms fails due to the presence of a hedge - the graph G, and a subgraph S that does not contain any X nodes.
|
||||
S = c_components[0]
|
||||
c_components_G = find_c_components(adjacency_matrix=adjacency_matrix, node_set=node_names, idx2node=idx2node)
|
||||
if len(c_components_G)==1 and c_components_G[0] == node_names:
|
||||
return None
|
||||
|
||||
# Line 6 - If there are no bidirected arcs from X to the other nodes in the current subproblem under consideration, then we can replace acting on X by conditioning, and thus solve the subproblem.
|
||||
if S in c_components_G:
|
||||
sum_over_set = S.difference(outcome_names)
|
||||
prev_nodes = []
|
||||
for node in self._tsort_node_names:
|
||||
if node in S:
|
||||
identifier = IDExpression()
|
||||
estimator = {}
|
||||
estimator['outcome_vars'] = OrderedSet([node])
|
||||
estimator['condition_vars'] = OrderedSet(prev_nodes)
|
||||
identifier.add_product(estimator)
|
||||
identifier.add_sum(sum_over_set)
|
||||
estimators.add_product(identifier)
|
||||
prev_nodes.append(node)
|
||||
return estimators
|
||||
|
||||
|
||||
# Line 7 - This is the most complicated case in the algorithm. Explain in the second last paragraph on Pg 41 of the link provided in the docstring above.
|
||||
for component in c_components_G:
|
||||
C = S.difference(component)
|
||||
if C.is_empty() is None:
|
||||
return self.identify_effect(treatment_names=treatment_names.intersection(component), outcome_names=outcome_names, adjacency_matrix=induced_graph(node_set=component, adjacency_matrix=adjacency_matrix,node2idx=node2idx), node_names=node_names)
|
||||
|
||||
def _idx_node_mapping(self, node_names):
|
||||
'''
|
||||
Obtain the node name to index and index to node name mappings.
|
||||
|
||||
:param node_names: Name of all nodes in the graph.
|
||||
:return: node to index and index to node mappings.
|
||||
'''
|
||||
node2idx = {}
|
||||
idx2node = {}
|
||||
for i, node in enumerate(node_names.get_all()):
|
||||
node2idx[node] = i
|
||||
idx2node[i] = node
|
||||
return node2idx, idx2node
|
|
@ -12,6 +12,7 @@ import dowhy.utils.cli_helpers as cli
|
|||
from dowhy.causal_estimator import CausalEstimate
|
||||
from dowhy.causal_graph import CausalGraph
|
||||
from dowhy.causal_identifier import CausalIdentifier
|
||||
from dowhy.causal_identifiers.id_identifier import IDIdentifier
|
||||
from dowhy.utils.api import parse_state
|
||||
|
||||
init_printing() # To display symbolic math symbols
|
||||
|
@ -153,6 +154,7 @@ class CausalModel:
|
|||
method_name="default", proceed_when_unidentifiable=None):
|
||||
"""Identify the causal effect to be estimated, using properties of the causal graph.
|
||||
|
||||
:param method_name: Method name for identification algorithm. ("id-algorithm" or "default")
|
||||
:param proceed_when_unidentifiable: Binary flag indicating whether identification should proceed in the presence of (potential) unobserved confounders.
|
||||
:returns: a probability expression (estimand) for the causal effect if identified, else NULL
|
||||
|
||||
|
@ -161,13 +163,19 @@ class CausalModel:
|
|||
proceed_when_unidentifiable = self._proceed_when_unidentifiable
|
||||
if estimand_type is None:
|
||||
estimand_type = self._estimand_type
|
||||
|
||||
self.identifier = CausalIdentifier(self._graph,
|
||||
|
||||
if method_name == "id-algorithm":
|
||||
self.identifier = IDIdentifier(self._graph,
|
||||
estimand_type,
|
||||
method_name,
|
||||
proceed_when_unidentifiable=proceed_when_unidentifiable)
|
||||
else:
|
||||
self.identifier = CausalIdentifier(self._graph,
|
||||
estimand_type,
|
||||
method_name,
|
||||
proceed_when_unidentifiable=proceed_when_unidentifiable)
|
||||
identified_estimand = self.identifier.identify_effect()
|
||||
|
||||
|
||||
return identified_estimand
|
||||
|
||||
def estimate_effect(self, identified_estimand, method_name=None,
|
||||
|
|
|
@ -1,7 +1,16 @@
|
|||
import numpy as np
|
||||
import graphviz
|
||||
from queue import LifoQueue
|
||||
from dowhy.utils.ordered_set import OrderedSet
|
||||
|
||||
def adjacency_matrix_to_graph(adjacency_matrix, labels=None):
|
||||
'''
|
||||
Convert a given graph adjacency matrix to DOT format.
|
||||
|
||||
:param adjacency_matrix: A numpy array representing the graph adjacency matrix.
|
||||
:param labels: List of labels.
|
||||
:returns: Graph in DOT format.
|
||||
'''
|
||||
# Only consider edges have absolute edge weight > 0.01
|
||||
idx = np.abs(adjacency_matrix) > 0.01
|
||||
dirs = np.where(idx)
|
||||
|
@ -16,7 +25,97 @@ def adjacency_matrix_to_graph(adjacency_matrix, labels=None):
|
|||
def str_to_dot(string):
|
||||
'''
|
||||
Converts input string from graphviz library to valid DOT graph format.
|
||||
|
||||
:param string: Graph in DOT format.
|
||||
:returns: DOT string converted to a suitable format for the DoWhy library.
|
||||
'''
|
||||
graph = string.replace('\n', ';').replace('\t','')
|
||||
graph = graph[:9] + graph[10:-2] + graph[-1] # Removing unnecessary characters from string
|
||||
return graph
|
||||
return graph
|
||||
|
||||
def find_ancestor(node_set, node_names, adjacency_matrix, node2idx, idx2node):
|
||||
'''
|
||||
Finds ancestors of a given set of nodes in a given graph.
|
||||
|
||||
:param node_set: Set of nodes whos ancestors must be obtained.
|
||||
:param node_names: Name of all nodes in the graph.
|
||||
:param adjacency_matrix: Graph adjacency matrix.
|
||||
:param node2idx: A dictionary mapping node names to their row or column index in the adjacency matrix.
|
||||
:param idx2node: A dictionary mapping the row or column indices in the adjacency matrix to the corresponding node names.
|
||||
:returns: OrderedSet containing ancestors of all nodes in the node_set.
|
||||
'''
|
||||
|
||||
def find_ancestor_help(node_name, node_names, adjacency_matrix, node2idx, idx2node):
|
||||
ancestors = OrderedSet()
|
||||
nodes_to_visit = LifoQueue(maxsize = len(node_names))
|
||||
nodes_to_visit.put(node2idx[node_name])
|
||||
while not nodes_to_visit.empty():
|
||||
child = nodes_to_visit.get()
|
||||
ancestors.add(idx2node[child])
|
||||
for i in range(len(node_names)):
|
||||
if idx2node[i] not in ancestors and adjacency_matrix[i, child] == 1: # For edge a->b, a is along height and b is along width of adjacency matrix
|
||||
nodes_to_visit.put(i)
|
||||
return ancestors
|
||||
|
||||
ancestors = OrderedSet()
|
||||
for node_name in node_set.get_all():
|
||||
ancestors = ancestors.union(find_ancestor_help(node_name, node_names, adjacency_matrix, node2idx, idx2node))
|
||||
return ancestors
|
||||
|
||||
def induced_graph(node_set, adjacency_matrix, node2idx):
|
||||
'''
|
||||
To obtain the induced graph corresponding to a subset of nodes.
|
||||
|
||||
:param node_set: Set of nodes whos ancestors must be obtained.
|
||||
:param adjacency_matrix: Graph adjacency matrix.
|
||||
:param node2idx: A dictionary mapping node names to their row or column index in the adjacency matrix.
|
||||
:returns: Numpy array representing the adjacency matrix of the induced graph.
|
||||
'''
|
||||
node_idx_list = [node2idx[node] for node in node_set]
|
||||
node_idx_list.sort()
|
||||
adjacency_matrix_induced = adjacency_matrix.copy()
|
||||
adjacency_matrix_induced = adjacency_matrix_induced[node_idx_list]
|
||||
adjacency_matrix_induced = adjacency_matrix_induced[:, node_idx_list]
|
||||
return adjacency_matrix_induced
|
||||
|
||||
def find_c_components(adjacency_matrix, node_set, idx2node):
|
||||
'''
|
||||
Obtain C-components in a graph.
|
||||
|
||||
:param adjacency_matrix: Graph adjacency matrix.
|
||||
:param node_set: Set of nodes whos ancestors must be obtained.
|
||||
:param idx2node: A dictionary mapping the row or column indices in the adjacency matrix to the corresponding node names.
|
||||
:returns: List of C-components in the graph.
|
||||
'''
|
||||
num_nodes = len(node_set)
|
||||
adj_matrix = adjacency_matrix.copy()
|
||||
adjacency_list = [[] for _ in range(num_nodes)]
|
||||
|
||||
# Modify graph such that it only contains bidirected edges
|
||||
for h in range(0, num_nodes-1):
|
||||
for w in range(h+1, num_nodes):
|
||||
if adjacency_matrix[h, w]==1 and adjacency_matrix[w, h]==1:
|
||||
adjacency_list[h].append(w)
|
||||
adjacency_list[w].append(h)
|
||||
else:
|
||||
adj_matrix[h, w] = 0
|
||||
adj_matrix[w, h] = 0
|
||||
|
||||
# Find c components by finding connected components on the undirected graph
|
||||
visited = [False for _ in range(num_nodes)]
|
||||
|
||||
def dfs(node_idx, component):
|
||||
visited[node_idx] = True
|
||||
component.add(idx2node[node_idx])
|
||||
for neighbour in adjacency_list[node_idx]:
|
||||
if visited[neighbour] == False:
|
||||
dfs(neighbour, component)
|
||||
|
||||
c_components = []
|
||||
for i in range(num_nodes):
|
||||
if visited[i] == False:
|
||||
component = OrderedSet()
|
||||
dfs(i, component)
|
||||
c_components.append(component)
|
||||
|
||||
return c_components
|
|
@ -0,0 +1,113 @@
|
|||
class OrderedSet:
|
||||
'''
|
||||
Python class for ordered set.
|
||||
Code taken from https://github.com/buyalsky/ordered-hash-set/tree/5198b23e01faeac3f5398ab2c08cb013d14b3702.
|
||||
'''
|
||||
def __init__(self, elements=None):
|
||||
self._set = {}
|
||||
self._start = None
|
||||
self._end = None
|
||||
if elements is not None:
|
||||
for element in elements:
|
||||
self.add(element)
|
||||
|
||||
def add(self, element):
|
||||
"""
|
||||
Function to add an element to do set if it does not exit.
|
||||
|
||||
:param element: element to be added.
|
||||
"""
|
||||
if self._start is None:
|
||||
self._start = element
|
||||
|
||||
if element not in self._set.keys():
|
||||
self._set[element] = None
|
||||
if len(self._set) > 1:
|
||||
self._set[self._end] = element
|
||||
self._end = element
|
||||
|
||||
def get_all(self):
|
||||
"""
|
||||
Function to return list of all elements in the set.
|
||||
|
||||
:returns: List of all items in the set.
|
||||
"""
|
||||
return list(self)
|
||||
|
||||
def is_empty(self):
|
||||
"""
|
||||
Function to determine if the set is empty or not.
|
||||
|
||||
:returns: ``True`` if the set is empty, ``False`` otherwise.
|
||||
"""
|
||||
return self.__len__() == 0
|
||||
|
||||
def union(self, other_set):
|
||||
"""
|
||||
Function to compute the union of self._set and other_set.
|
||||
|
||||
:param other_set: The set to obtain union with. Can be a list, set or OrderedSet.
|
||||
:returns: New OrderedSet representing the set with elements from the OrderedSet object and other_set.
|
||||
"""
|
||||
new_set = OrderedSet()
|
||||
for element in self._set:
|
||||
new_set.add(element)
|
||||
for element in other_set:
|
||||
new_set.add(element)
|
||||
return new_set
|
||||
|
||||
def intersection(self, other_set):
|
||||
"""
|
||||
Function to compute the intersection of self._set and other_set.
|
||||
|
||||
:param other_set: The set to obtain intersection with. Can be a list, set or OrderedSet.
|
||||
:returns: New OrderedSet representing the set with elements common to the OrderedSet object and other_set.
|
||||
"""
|
||||
new_set = OrderedSet()
|
||||
for element in self._set:
|
||||
if element in other_set:
|
||||
new_set.add(element)
|
||||
return new_set
|
||||
|
||||
def difference(self, other_set):
|
||||
"""
|
||||
Function to remove elements in self._set which are also present in other_set.
|
||||
|
||||
:param other_set: The set to obtain difference with. Can be a list, set or OrderedSet.
|
||||
:returns: New OrderedSet representing the difference of elements in the self._set and other_set.
|
||||
"""
|
||||
new_set = OrderedSet()
|
||||
for element in self._set:
|
||||
if element not in other_set:
|
||||
new_set.add(element)
|
||||
return new_set
|
||||
|
||||
def __getitem__(self, index):
|
||||
if index >= self.__len__():
|
||||
raise IndexError("Index is out of range")
|
||||
return list(self)[index]
|
||||
|
||||
def __iter__(self):
|
||||
self._iter = self._start
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
element = self._iter
|
||||
if not element:
|
||||
raise StopIteration
|
||||
self._iter = self._set[element]
|
||||
return element
|
||||
|
||||
def __len__(self):
|
||||
return len(self._set)
|
||||
|
||||
def __str__(self):
|
||||
elements = [str(i) for i in self]
|
||||
string = "OrderedSet(" + ",".join(elements) + ")"
|
||||
return string
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
|
||||
return self._set == other._set
|
|
@ -0,0 +1,109 @@
|
|||
from numpy.core.fromnumeric import var
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from dowhy import CausalModel
|
||||
|
||||
class TestIDIdentifier(object):
|
||||
|
||||
def test_1(self):
|
||||
treatment = "T"
|
||||
outcome = "Y"
|
||||
causal_graph = "digraph{T->Y;}"
|
||||
columns = list(treatment) + list(outcome)
|
||||
df = pd.DataFrame(columns=columns)
|
||||
|
||||
# Calculate causal effect twice: once for unit (t=1, c=0), once for specific increase (t=100, c=50)
|
||||
causal_model = CausalModel(df, treatment, outcome, graph=causal_graph)
|
||||
identified_estimand = causal_model.identify_effect(method_name="id-algorithm")
|
||||
|
||||
# Only P(Y|T) should be present for test to succeed.
|
||||
identified_str = identified_estimand.__str__()
|
||||
gt_str = "Predictor: P(Y|T)"
|
||||
assert identified_str == gt_str
|
||||
|
||||
def test_2(self):
|
||||
'''
|
||||
Test undirected edge between treatment and outcome.
|
||||
'''
|
||||
treatment = "T"
|
||||
outcome = "Y"
|
||||
causal_graph = "digraph{T->Y; Y->T;}"
|
||||
columns = list(treatment) + list(outcome)
|
||||
df = pd.DataFrame(columns=columns)
|
||||
|
||||
# Calculate causal effect twice: once for unit (t=1, c=0), once for specific increase (t=100, c=50)
|
||||
causal_model = CausalModel(df, treatment, outcome, graph=causal_graph)
|
||||
|
||||
# Since undirected graph, identify effect must throw an error.
|
||||
with pytest.raises(Exception):
|
||||
identified_estimand = causal_model.identify_effect(method_name="id-algorithm")
|
||||
|
||||
def test_3(self):
|
||||
treatment = "T"
|
||||
outcome = "Y"
|
||||
variables = ["X1"]
|
||||
causal_graph = "digraph{T->X1;X1->Y;}"
|
||||
columns = list(treatment) + list(outcome) + list(variables)
|
||||
df = pd.DataFrame(columns=columns)
|
||||
|
||||
# Calculate causal effect twice: once for unit (t=1, c=0), once for specific increase (t=100, c=50)
|
||||
causal_model = CausalModel(df, treatment, outcome, graph=causal_graph)
|
||||
identified_estimand = causal_model.identify_effect(method_name="id-algorithm")
|
||||
|
||||
# Compare with ground truth
|
||||
identified_str = identified_estimand.__str__()
|
||||
gt_str = "Sum over {X1}:\n\tPredictor: P(X1|T)\n\tPredictor: P(Y|T,X1)"
|
||||
assert identified_str == gt_str
|
||||
|
||||
def test_4(self):
|
||||
treatment = "T"
|
||||
outcome = "Y"
|
||||
variables = ["X1"]
|
||||
causal_graph = "digraph{T->Y;T->X1;X1->Y;}"
|
||||
columns = list(treatment) + list(outcome) + list(variables)
|
||||
df = pd.DataFrame(columns=columns)
|
||||
|
||||
# Calculate causal effect twice: once for unit (t=1, c=0), once for specific increase (t=100, c=50)
|
||||
causal_model = CausalModel(df, treatment, outcome, graph=causal_graph)
|
||||
identified_estimand = causal_model.identify_effect(method_name="id-algorithm")
|
||||
|
||||
# Compare with ground truth
|
||||
identified_str = identified_estimand.__str__()
|
||||
gt_str = "Sum over {X1}:\n\tPredictor: P(Y|T,X1)\n\tPredictor: P(X1|T)"
|
||||
assert identified_str == gt_str
|
||||
|
||||
|
||||
def test_5(self):
|
||||
treatment = "T"
|
||||
outcome = "Y"
|
||||
variables = ["X1", "X2"]
|
||||
causal_graph = "digraph{T->Y;X1->T;X1->Y;X2->T;}"
|
||||
columns = list(treatment) + list(outcome) + list(variables)
|
||||
df = pd.DataFrame(columns=columns)
|
||||
|
||||
# Calculate causal effect twice: once for unit (t=1, c=0), once for specific increase (t=100, c=50)
|
||||
causal_model = CausalModel(df, treatment, outcome, graph=causal_graph)
|
||||
identified_estimand = causal_model.identify_effect(method_name="id-algorithm")
|
||||
|
||||
# Compare with ground truth
|
||||
identified_str = identified_estimand.__str__()
|
||||
gt_str = "Sum over {X1}:\n\tPredictor: P(Y|X2,X1,T)\n\tPredictor: P(X1)"
|
||||
assert identified_str == gt_str
|
||||
|
||||
def test_6(self):
|
||||
treatment = "T"
|
||||
outcome = "Y"
|
||||
variables = ["X1"]
|
||||
causal_graph = "digraph{T;X1->Y;}"
|
||||
columns = list(treatment) + list(outcome) + list(variables)
|
||||
df = pd.DataFrame(columns=columns)
|
||||
|
||||
# Calculate causal effect twice: once for unit (t=1, c=0), once for specific increase (t=100, c=50)
|
||||
causal_model = CausalModel(df, treatment, outcome, graph=causal_graph)
|
||||
identified_estimand = causal_model.identify_effect(method_name="id-algorithm")
|
||||
|
||||
# Compare with ground truth
|
||||
identified_str = identified_estimand.__str__()
|
||||
gt_str = "Sum over {X1}:\n\tPredictor: P(X1,Y)"
|
||||
assert identified_str == gt_str
|
Загрузка…
Ссылка в новой задаче