Updates to causal graph to include mediators and a few aesthetic changes (#209)
* better looking graphs * updated the getting started notebook
This commit is contained in:
Родитель
7c89a8ecac
Коммит
aeb1981a6c
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -20,18 +20,23 @@ class CausalGraph:
|
|||
common_cause_names=None,
|
||||
instrument_names=None,
|
||||
effect_modifier_names=None,
|
||||
mediator_names=None,
|
||||
observed_node_names=None,
|
||||
missing_nodes_as_confounders=False):
|
||||
self.treatment_name = parse_state(treatment_name)
|
||||
self.outcome_name = parse_state(outcome_name)
|
||||
instrument_names = parse_state(instrument_names)
|
||||
common_cause_names = parse_state(common_cause_names)
|
||||
effect_modifier_names = parse_state(effect_modifier_names)
|
||||
mediator_names = parse_state(mediator_names)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
if graph is None:
|
||||
self._graph = nx.DiGraph()
|
||||
self._graph = self.build_graph(common_cause_names,
|
||||
instrument_names, effect_modifier_names)
|
||||
instrument_names,
|
||||
effect_modifier_names,
|
||||
mediator_names)
|
||||
elif re.match(r".*\.dot", graph):
|
||||
# load dot file
|
||||
try:
|
||||
|
@ -104,25 +109,27 @@ class CausalGraph:
|
|||
arrowstyle="-|>",
|
||||
style="dashed",
|
||||
arrowsize=12)
|
||||
|
||||
|
||||
labels = nx.draw_networkx_labels(self._graph, pos)
|
||||
|
||||
plt.axis('off')
|
||||
plt.savefig(out_filename)
|
||||
plt.draw()
|
||||
|
||||
def build_graph(self, common_cause_names, instrument_names, effect_modifier_names):
|
||||
def build_graph(self, common_cause_names, instrument_names,
|
||||
effect_modifier_names, mediator_names):
|
||||
""" Creates nodes and edges based on variable names and their semantics.
|
||||
|
||||
|
||||
Currently only considers the graphical representation of "direct" effect modifiers. Thus, all effect modifiers are assumed to be "direct" unless otherwise expressed using a graph. Based on the taxonomy of effect modifiers by VanderWeele and Robins: "Four types of effect modification: A classification based on directed acyclic graphs. Epidemiology. 2007."
|
||||
"""
|
||||
|
||||
for treatment in self.treatment_name:
|
||||
self._graph.add_node(treatment, observed="yes")
|
||||
self._graph.add_node(treatment, observed="yes", penwidth=2)
|
||||
for outcome in self.outcome_name:
|
||||
self._graph.add_node(outcome, observed="yes")
|
||||
self._graph.add_node(outcome, observed="yes", penwidth=2)
|
||||
for treatment, outcome in itertools.product(self.treatment_name, self.outcome_name):
|
||||
self._graph.add_edge(treatment, outcome)
|
||||
# adding penwidth to make the edge bold
|
||||
self._graph.add_edge(treatment, outcome, penwidth=2)
|
||||
|
||||
# Adding common causes
|
||||
if common_cause_names is not None:
|
||||
|
@ -153,6 +160,12 @@ class CausalGraph:
|
|||
self._graph.add_node(node_name, observed="yes")
|
||||
self._graph.add_edge(node_name, outcome, style = "dotted", headport="s", tailport="n")
|
||||
self._graph.add_edge(outcome, node_name, style = "dotted", headport="n", tailport="s") # TODO make the ports more general so that they apply not just to top-bottom node configurations
|
||||
if mediator_names is not None:
|
||||
for node_name in mediator_names:
|
||||
for treatment, outcome in itertools.product(self.treatment_name, self.outcome_name):
|
||||
self._graph.add_node(node_name, observed="yes")
|
||||
self._graph.add_edge(treatment, node_name)
|
||||
self._graph.add_edge(node_name, outcome)
|
||||
return self._graph
|
||||
|
||||
def add_node_attributes(self, observed_node_names):
|
||||
|
@ -172,7 +185,7 @@ class CausalGraph:
|
|||
self._graph.add_edge(node_name, treatment_outcome_node)
|
||||
return self._graph
|
||||
|
||||
def add_unobserved_common_cause(self, observed_node_names):
|
||||
def add_unobserved_common_cause(self, observed_node_names, color="gray"):
|
||||
# Adding unobserved confounders
|
||||
current_common_causes = self.get_common_causes(self.treatment_name,
|
||||
self.outcome_name)
|
||||
|
@ -182,7 +195,8 @@ class CausalGraph:
|
|||
create_new_common_cause = False
|
||||
if create_new_common_cause:
|
||||
uc_label = "Unobserved Confounders"
|
||||
self._graph.add_node('U', label=uc_label, observed="no")
|
||||
self._graph.add_node('U', label=uc_label, observed="no",
|
||||
color=color, style="filled", fillcolor=color)
|
||||
for node in self.treatment_name + self.outcome_name:
|
||||
self._graph.add_edge('U', node)
|
||||
self.logger.info('If this is observed data (not from a randomized experiment), there might always be missing confounders. Adding a node named "Unobserved Confounders" to reflect this.')
|
||||
|
|
|
@ -63,22 +63,22 @@ class AddUnobservedCommonCause(CausalRefuter):
|
|||
new_effect = new_estimator.estimate_effect()
|
||||
refute = CausalRefutation(self._estimate.value, new_effect.value,
|
||||
refutation_type="Refute: Add an Unobserved Common Cause")
|
||||
|
||||
|
||||
refute.new_effect = np.array(new_effect.value)
|
||||
refute.add_refuter(self)
|
||||
return refute
|
||||
|
||||
else: # Deal with multiple value inputs
|
||||
|
||||
|
||||
if isinstance(self.kappa_t, np.ndarray) and isinstance(self.kappa_y, np.ndarray): # Deal with range inputs
|
||||
|
||||
|
||||
# Get a 2D matrix of values
|
||||
x,y = np.meshgrid(self.kappa_t, self.kappa_y) # x,y are both MxN
|
||||
|
||||
|
||||
results_matrix = np.random.rand(len(x),len(y)) # Matrix to hold all the results of NxM
|
||||
print(results_matrix.shape)
|
||||
orig_data = copy.deepcopy(self._data)
|
||||
|
||||
|
||||
for i in range(0,len(x[0])):
|
||||
for j in range(0,len(y)):
|
||||
new_data = self.include_confounders_effect(orig_data, x[0][i], y[j][0])
|
||||
|
@ -88,10 +88,10 @@ class AddUnobservedCommonCause(CausalRefuter):
|
|||
refutation_type="Refute: Add an Unobserved Common Cause")
|
||||
self.logger.debug(refute)
|
||||
results_matrix[i][j] = refute.estimated_effect # Populate the results
|
||||
|
||||
|
||||
fig = plt.figure(figsize=(6,5))
|
||||
left, bottom, width, height = 0.1, 0.1, 0.8, 0.8
|
||||
ax = fig.add_axes([left, bottom, width, height])
|
||||
ax = fig.add_axes([left, bottom, width, height])
|
||||
|
||||
cp = plt.contourf(x, y, results_matrix)
|
||||
plt.colorbar(cp)
|
||||
|
@ -120,13 +120,13 @@ class AddUnobservedCommonCause(CausalRefuter):
|
|||
|
||||
fig = plt.figure(figsize=(6,5))
|
||||
left, bottom, width, height = 0.1, 0.1, 0.8, 0.8
|
||||
ax = fig.add_axes([left, bottom, width, height])
|
||||
ax = fig.add_axes([left, bottom, width, height])
|
||||
|
||||
plt.plot(self.kappa_t, outcomes)
|
||||
ax.set_title('Effect of Unobserved Common Cause')
|
||||
ax.set_xlabel('Value of Linear Constant on Treatment')
|
||||
ax.set_ylabel('New Effect')
|
||||
plt.show()
|
||||
plt.show()
|
||||
|
||||
refute.new_effect = outcomes
|
||||
refute.add_refuter(self)
|
||||
|
@ -144,7 +144,7 @@ class AddUnobservedCommonCause(CausalRefuter):
|
|||
refutation_type="Refute: Add an Unobserved Common Cause")
|
||||
self.logger.debug(refute)
|
||||
outcomes[i] = refute.estimated_effect # Populate the results
|
||||
|
||||
|
||||
fig = plt.figure(figsize=(6,5))
|
||||
left, bottom, width, height = 0.1, 0.1, 0.8, 0.8
|
||||
ax = fig.add_axes([left, bottom, width, height])
|
||||
|
@ -153,7 +153,7 @@ class AddUnobservedCommonCause(CausalRefuter):
|
|||
ax.set_title('Effect of Unobserved Common Cause')
|
||||
ax.set_xlabel('Value of Linear Constant on Outcome')
|
||||
ax.set_ylabel('New Effect')
|
||||
plt.show()
|
||||
plt.show()
|
||||
|
||||
refute.new_effect = outcomes
|
||||
refute.add_refuter(self)
|
||||
|
@ -161,7 +161,7 @@ class AddUnobservedCommonCause(CausalRefuter):
|
|||
|
||||
def include_confounders_effect(self, new_data, kappa_t, kappa_y):
|
||||
"""
|
||||
This function deals with the change in the value of the data due to the effect of the unobserved confounder.
|
||||
This function deals with the change in the value of the data due to the effect of the unobserved confounder.
|
||||
In the case of a binary flip, we flip only if the random number is greater than the threshold set.
|
||||
In the case of a linear effect, we use the variable as the linear regression constant.
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче