Causal model + identification speedup (#288)

* Added identify_vars and different functions to calculate common causes, instruments and effect modifiers
* Added class for Backdoor variable search

* Check for blocked paths during DFS
* Improved memoization

* Added Hitting Set Algorithm to generate backdoor variables per node pair

* Added test file and other changes

* Added 2 more test cases in pytest file

* Added notebook for optimized backdoor identification

* Added docstring

* Add condition to prevent adding colliders to backdoor paths
This commit is contained in:
Siddhant Haldar 2021-07-30 11:28:17 +05:30 коммит произвёл GitHub
Родитель 46614846b2
Коммит 854d6e2822
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
9 изменённых файлов: 708 добавлений и 22 удалений

Просмотреть файл

@ -0,0 +1,150 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Example to demonstrate optimized backdoor variable search for Causal Identification\n",
"\n",
"This notebook compares the performance between causal identification using vanilla backdoor search and the optimized backdoor search and demonstrates the performance gains obtained by using the latter."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 1,
"source": [
"import time\n",
"import random\n",
"from networkx.linalg.graphmatrix import adjacency_matrix\n",
"import numpy as np\n",
"import pandas as pd\n",
"import networkx as nx\n",
"\n",
"import dowhy\n",
"from dowhy import CausalModel\n",
"from dowhy.utils import graph_operations\n",
"import dowhy.datasets\n"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Create Random Graph \n",
"In this section, we create a random graph with the designated number of nodes (10 in this case)."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"n = 10\n",
"p = 0.5\n",
"\n",
"graph = nx.generators.random_graphs.fast_gnp_random_graph(n, p, directed=True)\n",
"nodes = []\n",
"for i in graph.nodes:\n",
"\tnodes.append(str(i))\n",
"adjacency_matrix = np.asarray(nx.to_numpy_matrix(graph))\n",
"graph_dot = graph_operations.adjacency_matrix_to_graph(adjacency_matrix, nodes)\n",
"graph_dot = graph_operations.str_to_dot(graph_dot.source)\n",
"print(\"Graph Generated.\")\n",
"\n",
"df = pd.DataFrame(columns=nodes)\n",
"print(\"Dataframe Generated.\")"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Graph Generated.\n",
"Dataframe Generated.\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Testing optimized backdoor search\n",
"\n",
"In this section, we compare the runtimes for causal identification using vanilla backdoor search and the optimized backdoor search."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"source": [
"start = time.time()\n",
"\n",
"# I. Create a causal model from the data and given graph.\n",
"model = CausalModel(data=df,treatment=str(random.randint(0,n-1)),outcome=str(random.randint(0,n-1)),graph=graph_dot)\n",
"time1 = time.time()\n",
"print(\"Time taken for initializing model =\", time1-start)\n",
"\n",
"# II. Identify causal effect and return target estimands\n",
"identified_estimand = model.identify_effect()\n",
"time2 = time.time()\n",
"print(\"Time taken for vanilla identification =\", time2-time1)\n",
"\n",
"# III. Identify causal effect using the optimized backdoor implementation\n",
"identified_estimand = model.identify_effect(optimize_backdoor=True)\n",
"end = time.time()\n",
"print(\"Time taken for optimized backdoor identification =\", end-time2)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Time taken for initializing model = 0.07566142082214355\n",
"Time taken for vanilla identification = 6.404623508453369\n",
"Time taken for optimized backdoor identification = 1.3513822555541992\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"It can be observed that the optimized backdoor search makes causal identification significantly faster as compared to the vanilla implementation."
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [],
"metadata": {}
}
],
"metadata": {
"orig_nbformat": 4,
"language_info": {
"name": "python",
"version": "3.6.9",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3.6.9 64-bit"
},
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

Просмотреть файл

@ -43,7 +43,7 @@ class CausalIdentifier:
self._proceed_when_unidentifiable = proceed_when_unidentifiable
self.logger = logging.getLogger(__name__)
def identify_effect(self):
def identify_effect(self, optimize_backdoor=False):
"""Main method that returns an identified estimand (if one exists).
If estimand_type is non-parametric ATE, then uses backdoor, instrumental variable and frontdoor identification methods, to check if an identified estimand exists, based on the causal graph.
@ -52,7 +52,7 @@ class CausalIdentifier:
:returns: target estimand, an instance of the IdentifiedEstimand class
"""
if self.estimand_type == CausalIdentifier.NONPARAMETRIC_ATE:
return self.identify_ate_effect()
return self.identify_ate_effect(optimize_backdoor=optimize_backdoor)
elif self.estimand_type == CausalIdentifier.NONPARAMETRIC_NDE:
return self.identify_nde_effect()
elif self.estimand_type == CausalIdentifier.NONPARAMETRIC_NIE:
@ -63,13 +63,18 @@ class CausalIdentifier:
CausalIdentifier.NONPARAMETRIC_NDE,
CausalIdentifier.NONPARAMETRIC_NIE))
def identify_ate_effect(self):
def identify_ate_effect(self, optimize_backdoor):
estimands_dict = {}
mediation_first_stage_confounders = None
mediation_second_stage_confounders = None
### 1. BACKDOOR IDENTIFICATION
# First, checking if there are any valid backdoor adjustment sets
if optimize_backdoor == False:
backdoor_sets = self.identify_backdoor(self.treatment_name, self.outcome_name)
else:
from dowhy.causal_identifiers.backdoor import Backdoor
path = Backdoor(self._graph._graph, self.treatment_name, self.outcome_name)
backdoor_sets = path.get_backdoor_vars()
estimands_dict, backdoor_variables_dict = self.build_backdoor_estimands_dict(
self.treatment_name,
self.outcome_name,

Просмотреть файл

@ -0,0 +1,298 @@
import networkx as nx
from dowhy.utils.graph_operations import adjacency_matrix_to_adjacency_list
class NodePair:
'''
Data structure to store backdoor variables between 2 nodes.
'''
def __init__(self, node1, node2):
self._node1 = node1
self._node2 = node2
self._is_blocked = None # To store if all paths between node1 and node2 are blocked
self._condition_vars = [] # To store variable to be conditioned on to block all paths between node1 and node2
self._complete = False # To store to paths between node pair have been completely explored.
def update(self, path, condition_vars=None):
if condition_vars is None:
'''path is a Path variable'''
if self._is_blocked is None:
self._is_blocked = path.is_blocked()
else:
self._is_blocked = self._is_blocked and path.is_blocked()
if not path.is_blocked():
self._condition_vars.append(path.get_condition_vars())
else:
'''path is a list'''
condition_vars = list(condition_vars)
self._condition_vars.append(set([*path[1:], *condition_vars]))
def get_condition_vars(self):
return self._condition_vars
def set_complete(self):
self._complete = True
def is_complete(self):
return self._complete
def __str__(self):
string = ""
string += "Blocked: " + str(self._is_blocked) + "\n"
if not self._is_blocked:
condition_vars = [str(s) for s in self._condition_vars]
string += "To block path, condition on: " + ",".join(condition_vars) + "\n"
return string
class Path:
'''
Data structure to store a particular path between 2 nodes.
'''
def __init__(self):
self._is_blocked = None # To store if path is blocked
self._condition_vars = set() # To store variables needed to block the path
def update(self, path, is_blocked):
'''
path is a list
'''
self._is_blocked = is_blocked
if not is_blocked:
self._condition_vars = self._condition_vars.union(set(path[1:-1]))
def is_blocked(self):
return self._is_blocked
def get_condition_vars(self):
return self._condition_vars
def __str__(self):
string = ""
string += "Blocked: " + str(self._is_blocked) + "\n"
if not self._is_blocked:
string += "To block path, condition on: " + ",".join(self._condition_vars) + "\n"
return string
class Backdoor:
'''
Class for optimized implementation of Backdoor variable search between the source nodes and the target nodes.
'''
def __init__(self, graph, nodes1, nodes2):
self._graph = graph
self._nodes1 = nodes1
self._nodes2 = nodes2
self._nodes12 = set(self._nodes1).union(self._nodes2) # Total set of nodes
self._colliders = set()
def get_backdoor_vars(self):
'''
Obtains sets of backdoor variable to condition on for each node pair.
:returns: List of sets with each set containing backdoor variable corresponding to a given node pair.
'''
undirected_graph = self._graph.to_undirected()
# Get adjacency list
adjlist = adjacency_matrix_to_adjacency_list(nx.to_numpy_matrix(undirected_graph), labels=list(undirected_graph.nodes))
path_dict = {}
backdoor_sets = [] # Put in backdoor sets format
for node1 in self._nodes1:
for node2 in self._nodes2:
if (node1, node2) in path_dict:
continue
self._path_search(adjlist, node1, node2, path_dict)
if len(path_dict) != 0:
obj = HittingSetAlgorithm(path_dict[(node1, node2)].get_condition_vars(), self._colliders)
backdoor_set = {}
backdoor_set['backdoor_set'] = tuple(obj.find_set())
backdoor_set['num_paths_blocked_by_observed_nodes'] = obj.num_sets()
backdoor_sets.append(backdoor_set)
return backdoor_sets
def is_backdoor(self, path):
'''
Check if path is a backdoor path.
:param path: List of nodes comprising the path.
'''
if len(path)<2:
return False
return True if self._graph.has_edge(path[1], path[0]) else False
def _path_search_util(self, graph, node1, node2, vis, path, path_dict, is_blocked=False, prev_arrow=None):
'''
:param graph: Adjacency list of the graph under consideration.
:param node1: Current node being considered.
:param node2: Target node.
:param vis: Set of already visited nodes.
:param path: List of nodes comprising the path upto node1.
:path path_dict: Dictionary of node pairs.
:param is_blocked: True is path is blocked by a collider, else False.
:param prev_arrow: Described state of previous arrow. True if arrow incoming, False if arrow outgoing.
'''
if is_blocked:
return
# If node pair has been fully explored
if ((node1, node2) in path_dict) and (path_dict[(node1, node2)].is_complete()):
for i in range(len(path)):
if (path[i], node2) not in path_dict:
path_dict[(path[i], node2)] = NodePair(path[i], node2)
obj = HittingSetAlgorithm(path_dict[(node1, node2)].get_condition_vars(), self._colliders)
# Add node1 to backdoor set of node_pair
s = set([node1])
s = s.union(obj.find_set())
path_dict[(path[i], node2)].update(path[i:], s)
else:
path.append(node1)
vis.add(node1)
if node1 == node2:
# Check if path is backdoor and does not have nodes1\node1 or nodes2\node2 as intermediate nodes
if self.is_backdoor(path) and len(self._nodes12.intersection(set(path[1:-1])))==0:
for i in range(len(path)-1):
if (path[i], node2) not in path_dict:
path_dict[(path[i], node2)] = NodePair(path[i], node2)
path_var = Path()
path_var.update(path[i:].copy(), is_blocked)
path_dict[(path[i], node2)].update(path_var)
else:
for neighbour in graph[node1]:
if neighbour not in vis:
# True if arrow incoming, False if arrow outgoing
next_arrow = False if self._graph.has_edge(node1, neighbour) else True
if next_arrow == True and prev_arrow == True:
is_blocked = True
self._colliders.add(node1)
self._path_search_util(graph, neighbour, node2, vis, path, path_dict, is_blocked, not next_arrow) # Since incoming for current node is outgoing for the next
path.pop()
vis.remove(node1)
# Mark pair (node1, node2) complete
if (node1, node2) in path_dict:
path_dict[(node1, node2)].set_complete()
def _path_search(self, graph, node1, node2, path_dict):
'''
Path search using DFS.
:param graph: Adjacency list of the graph under consideration.
:param node1: Current node being considered.
:param node2: Target node.
:path path_dict: Dictionary of node pairs.
'''
vis = set()
self._path_search_util(graph, node1, node2, vis, [], path_dict, is_blocked=False)
class HittingSetAlgorithm:
'''
Class for the Hitting Set Algorithm to obtain a approximate minimal set of backdoor variables
to condition on for each node pair.
'''
def __init__(self, list_of_sets, colliders=set()):
'''
:param list_of_sets: List of sets such that each set comprises nodes representing a single backdoor path between a source node and a target node.
'''
self._list_of_sets = list_of_sets
self._colliders = colliders
self._var_count = self._count_vars()
def num_sets(self):
'''
Obtain number of backdoor paths between a node pair.
'''
return len(self._list_of_sets)
def find_set(self):
'''
Find approximate minimal set of nodes such that there is atleast one node from each set in list_of_sets.
:returns: Approximate minimal set of nodes.
'''
var_set = set()
num_indices = len(self._list_of_sets)
indices_covered = set()
all_set_indices = set([i for i in range(num_indices)])
while not self._is_covered(indices_covered, num_indices):
set_index = all_set_indices - indices_covered
max_el = self._max_occurence_var(var_dict=self._var_count)
if max_el is None:
break
var_set.add(max_el)
# Modify variable count and indices covered
covered_present = self._indices_covered(el=max_el, set_index=set_index)
self._modify_count(covered_present)
indices_covered = indices_covered.union(covered_present)
return var_set
def _count_vars(self, set_index = None):
'''
Obtain count of number of sets each particular node belongs to.
:param set_index: Set of indices to consider for calculating the number of sets "hit" by a variable..
'''
var_dict = {}
if set_index == None:
set_index = set([i for i in range(len(self._list_of_sets))])
for idx in set_index:
s = self._list_of_sets[idx]
for el in s:
if el not in self._colliders:
if el not in var_dict:
var_dict[el] = 0
var_dict[el] += 1
return var_dict
def _modify_count(self, indices_covered):
'''
Modify count of number of sets each particular node belongs to based on nodes already covered in the previous iteration of the algorithm.
'''
for idx in indices_covered:
for el in self._list_of_sets[idx]:
if el not in self._colliders:
self._var_count[el] -= 1
def _max_occurence_var(self, var_dict):
'''
Find the node contained in most number of sets.
'''
max_el = None
max_count = 0
for key, val in var_dict.items():
if val>max_count:
max_count = val
max_el = key
return max_el
def _indices_covered(self, el, set_index=None):
'''
Obtain indices covered in a particular iteration of the algorithm.
'''
covered = set()
if set_index == None:
set_index = set([i for i in range(len(self._list_of_sets))])
for idx in set_index:
if el in self._list_of_sets[idx]:
covered.add(idx)
return covered
def _is_covered(self, indices_covered, num_indices):
'''
List of sets is covered by the variable set.
'''
covered = [False for i in range(num_indices)]
for idx in indices_covered:
covered[idx] = True
return all(covered)

Просмотреть файл

@ -30,6 +30,7 @@ class CausalModel:
estimand_type="nonparametric-ate",
proceed_when_unidentifiable=False,
missing_nodes_as_confounders=False,
identify_vars=False,
**kwargs):
"""Initialize data and create a causal graph instance.
@ -53,6 +54,7 @@ class CausalModel:
:param estimand_type: the type of estimand requested (currently only "nonparametric-ate" is supported). In the future, may support other specific parametric forms of identification.
:param proceed_when_unidentifiable: does the identification proceed by ignoring potential unobserved confounders. Binary flag.
:param missing_nodes_as_confounders: Binary flag indicating whether variables in the dataframe that are not included in the causal graph, should be automatically included as confounder nodes.
:param identify_vars: Variable deciding whether to compute common causes, instruments and effect modifiers while initializing the class. identify_vars should be set to False when user is providing common_causes, instruments or effect modifiers on their own(otherwise the identify_vars code can override the user provided values). Also it does not make sense if no graph is given.
:returns: an instance of CausalModel class
"""
@ -99,12 +101,12 @@ class CausalModel:
self._graph = None
else:
self.init_graph(graph=graph)
self.init_graph(graph=graph, identify_vars=identify_vars)
self._other_variables = kwargs
self.summary()
def init_graph(self, graph):
def init_graph(self, graph, identify_vars):
'''
Initialize self._graph using graph provided by the user.
@ -118,6 +120,8 @@ class CausalModel:
observed_node_names=self._data.columns.tolist(),
missing_nodes_as_confounders = self._missing_nodes_as_confounders
)
if identify_vars:
self._common_causes = self._graph.get_common_causes(self._treatment, self._outcome)
self._instruments = self._graph.get_instruments(self._treatment,
self._outcome)
@ -128,6 +132,18 @@ class CausalModel:
if self._effect_modifiers is None or not self._effect_modifiers:
self._effect_modifiers = self._graph.get_effect_modifiers(self._treatment, self._outcome)
def get_common_causes(self):
self._common_causes = self._graph.get_common_causes(self._treatment, self._outcome)
return self._common_causes
def get_instruments(self):
self._instruments = self._graph.get_instruments(self._treatment, self._outcome)
return self._instruments
def get_effect_modifiers(self):
self._effect_modifiers = self._graph.get_effect_modifiers(self._treatment, self._outcome)
return self._effect_modifiers
def learn_graph(self, method_name="cdt.causality.graph.LiNGAM", *args, **kwargs):
'''
Learn causal graph from the data. This function takes the method name as input and initializes the
@ -151,7 +167,7 @@ class CausalModel:
return self._graph
def identify_effect(self, estimand_type=None,
method_name="default", proceed_when_unidentifiable=None):
method_name="default", proceed_when_unidentifiable=None, optimize_backdoor=False):
"""Identify the causal effect to be estimated, using properties of the causal graph.
:param method_name: Method name for identification algorithm. ("id-algorithm" or "default")
@ -169,12 +185,13 @@ class CausalModel:
estimand_type,
method_name,
proceed_when_unidentifiable=proceed_when_unidentifiable)
identified_estimand = self.identifier.identify_effect()
else:
self.identifier = CausalIdentifier(self._graph,
estimand_type,
method_name,
proceed_when_unidentifiable=proceed_when_unidentifiable)
identified_estimand = self.identifier.identify_effect()
identified_estimand = self.identifier.identify_effect(optimize_backdoor=optimize_backdoor)
return identified_estimand
@ -217,6 +234,9 @@ class CausalModel:
"""
if effect_modifiers is None:
if self._effect_modifiers is None:
effect_modifiers = self.get_effect_modifiers()
else:
effect_modifiers = self._effect_modifiers
if method_name is None:
@ -300,7 +320,6 @@ class CausalModel:
pass
else:
str_arr = method_name.split(".", maxsplit=1)
print(str_arr)
identifier_name = str_arr[0]
estimator_name = str_arr[1]
identified_estimand.set_identifier_method(identifier_name)

Просмотреть файл

@ -3,6 +3,25 @@ import graphviz
from queue import LifoQueue
from dowhy.utils.ordered_set import OrderedSet
def adjacency_matrix_to_adjacency_list(adjacency_matrix, labels=None):
'''
Convert the adjacency matrix of a graph to an adjacency list.
:param adjacency_matrix: A numpy array representing the graph adjacency matrix.
:param labels: List of labels.
:returns: Adjacency list as a dictionary.
'''
adjlist = dict()
if labels is None:
labels = [str(i+1) for i in range(adjacency_matrix.shape[0])]
for i in range(adjacency_matrix.shape[0]):
adjlist[labels[i]] = list()
for j in range(adjacency_matrix.shape[1]):
if adjacency_matrix[i, j] != 0:
adjlist[labels[i]].append(labels[j])
return adjlist
def adjacency_matrix_to_graph(adjacency_matrix, labels=None):
'''
Convert a given graph adjacency matrix to DOT format.

Просмотреть файл

@ -97,7 +97,8 @@ class TestEconMLEstimator:
treatment=data["treatment_name"],
outcome=data["outcome_name"],
effect_modifiers=data["effect_modifier_names"],
graph=data["gml_graph"]
graph=data["gml_graph"],
identify_vars=True
)
identified_estimand = model.identify_effect(
proceed_when_unidentifiable=True)

Просмотреть файл

@ -0,0 +1,194 @@
import pandas as pd
from dowhy import CausalModel
from dowhy.causal_identifier import CausalIdentifier
from dowhy.utils.api import parse_state
from dowhy.causal_identifiers.backdoor import Backdoor
class TestOptimizeBackdoorIdentifier(object):
def test_1(self):
treatment = "T"
outcome = "Y"
variables = ["X1", "X2"]
causal_graph = "digraph{X1->T;X2->T;X1->X2;X2->Y;T->Y}"
vars = list(treatment) + list(outcome) + list(variables)
df = pd.DataFrame(columns=vars)
treatment_name = parse_state(treatment)
outcome_name = parse_state(outcome)
# Causal model initialization
causal_model = CausalModel(df, treatment, outcome, graph=causal_graph)
# Causal identifier identification
identifier = CausalIdentifier(causal_model._graph,
estimand_type=None,
method_name="default",
proceed_when_unidentifiable=None)
# Obtain backdoor sets
path = Backdoor(identifier._graph._graph, treatment_name, outcome_name)
backdoor_sets = path.get_backdoor_vars()
# Check if backdoor sets are valid i.e. if they block all paths between the treatment and the outcome
backdoor_paths = identifier._graph.get_backdoor_paths(treatment_name, outcome_name)
check_set = set(backdoor_sets[0]['backdoor_set'])
check = identifier._graph.check_valid_backdoor_set(treatment_name, outcome_name, check_set, backdoor_paths=backdoor_paths)
assert check["is_dseparated"]
def test_2(self):
treatment = "T"
outcome = "Y"
variables = ["X1", "X2"]
causal_graph = "digraph{T->X1;T->X2;X1->X2;X2->Y;T->Y}"
vars = list(treatment) + list(outcome) + list(variables)
df = pd.DataFrame(columns=vars)
treatment_name = parse_state(treatment)
outcome_name = parse_state(outcome)
# Causal model initialization
causal_model = CausalModel(df, treatment, outcome, graph=causal_graph)
# Causal identifier identification
identifier = CausalIdentifier(causal_model._graph,
estimand_type=None,
method_name="default",
proceed_when_unidentifiable=None)
# Obtain backdoor sets
path = Backdoor(identifier._graph._graph, treatment_name, outcome_name)
backdoor_sets = path.get_backdoor_vars()
assert len(backdoor_sets) == 0
def test_3(self):
treatment = "T"
outcome = "Y"
variables = ["X1", "X2", "X3"]
causal_graph = "digraph{X1->T;X1->X2;Y->X2;X3->T;X3->Y;T->Y}"
vars = list(treatment) + list(outcome) + list(variables)
df = pd.DataFrame(columns=vars)
treatment_name = parse_state(treatment)
outcome_name = parse_state(outcome)
# Causal model initialization
causal_model = CausalModel(df, treatment, outcome, graph=causal_graph)
# Causal identifier identification
identifier = CausalIdentifier(causal_model._graph,
estimand_type=None,
method_name="default",
proceed_when_unidentifiable=None)
# Obtain backdoor sets
path = Backdoor(identifier._graph._graph, treatment_name, outcome_name)
backdoor_sets = path.get_backdoor_vars()
# Check if backdoor sets are valid i.e. if they block all paths between the treatment and the outcome
backdoor_paths = identifier._graph.get_backdoor_paths(treatment_name, outcome_name)
check_set = set(backdoor_sets[0]['backdoor_set'])
check = identifier._graph.check_valid_backdoor_set(treatment_name, outcome_name, check_set, backdoor_paths=backdoor_paths)
assert check["is_dseparated"]
def test_4(self):
treatment = "T"
outcome = "Y"
variables = ["X1", "X2"]
causal_graph = "digraph{T->Y;X1->T;X1->Y;X2->T;}"
vars = list(treatment) + list(outcome) + list(variables)
df = pd.DataFrame(columns=vars)
treatment_name = parse_state(treatment)
outcome_name = parse_state(outcome)
# Causal model initialization
causal_model = CausalModel(df, treatment, outcome, graph=causal_graph)
# Causal identifier identification
identifier = CausalIdentifier(causal_model._graph,
estimand_type=None,
method_name="default",
proceed_when_unidentifiable=None)
# Obtain backdoor sets
path = Backdoor(identifier._graph._graph, treatment_name, outcome_name)
backdoor_sets = path.get_backdoor_vars()
# Check if backdoor sets are valid i.e. if they block all paths between the treatment and the outcome
backdoor_paths = identifier._graph.get_backdoor_paths(treatment_name, outcome_name)
check_set = set(backdoor_sets[0]['backdoor_set'])
check = identifier._graph.check_valid_backdoor_set(treatment_name, outcome_name, check_set, backdoor_paths=backdoor_paths)
assert check["is_dseparated"]
def test_5(self):
treatment = "T"
outcome = "Y"
variables = ["X1", "X2", "X3", "X4"]
causal_graph = "digraph{X1->T;X1->X2;X2->Y;X3->T;X3->X4;X4->Y;T->Y}"
vars = list(treatment) + list(outcome) + list(variables)
df = pd.DataFrame(columns=vars)
treatment_name = parse_state(treatment)
outcome_name = parse_state(outcome)
# Causal model initialization
causal_model = CausalModel(df, treatment, outcome, graph=causal_graph)
# Causal identifier identification
identifier = CausalIdentifier(causal_model._graph,
estimand_type=None,
method_name="default",
proceed_when_unidentifiable=None)
# Obtain backdoor sets
path = Backdoor(identifier._graph._graph, treatment_name, outcome_name)
backdoor_sets = path.get_backdoor_vars()
# Check if backdoor sets are valid i.e. if they block all paths between the treatment and the outcome
backdoor_paths = identifier._graph.get_backdoor_paths(treatment_name, outcome_name)
check_set = set(backdoor_sets[0]['backdoor_set'])
check = identifier._graph.check_valid_backdoor_set(treatment_name, outcome_name, check_set, backdoor_paths=backdoor_paths)
assert check["is_dseparated"]
def test_6(self):
treatment = "T"
outcome = "Y"
variables = ["X1", "X2", "X3", "X4"]
causal_graph = "digraph{X1->T;X1->X2;Y->X2;X3->T;X3->X4;X4->Y;T->Y}"
vars = list(treatment) + list(outcome) + list(variables)
df = pd.DataFrame(columns=vars)
treatment_name = parse_state(treatment)
outcome_name = parse_state(outcome)
# Causal model initialization
causal_model = CausalModel(df, treatment, outcome, graph=causal_graph)
# Causal identifier identification
identifier = CausalIdentifier(causal_model._graph,
estimand_type=None,
method_name="default",
proceed_when_unidentifiable=None)
# Obtain backdoor sets
path = Backdoor(identifier._graph._graph, treatment_name, outcome_name)
backdoor_sets = path.get_backdoor_vars()
# Check if backdoor sets are valid i.e. if they block all paths between the treatment and the outcome
backdoor_paths = identifier._graph.get_backdoor_paths(treatment_name, outcome_name)
check_set = set(backdoor_sets[0]['backdoor_set'])
check = identifier._graph.check_valid_backdoor_set(treatment_name, outcome_name, check_set, backdoor_paths=backdoor_paths)
assert check["is_dseparated"]

Просмотреть файл

@ -36,5 +36,5 @@ class TestCausalModel(object):
test_significance=None,
missing_nodes_as_confounders=True
)
assert all(node_name in model._common_causes for node_name in ["X1", "X2"])
common_causes = model.get_common_causes()
assert all(node_name in common_causes for node_name in ["X1", "X2"])