Updates to causal graph to include mediators and a few aesthetic changes (#209)

* better looking graphs

* updated the getting started notebook
This commit is contained in:
Amit Sharma 2020-12-12 14:47:53 +05:30 коммит произвёл GitHub
Родитель 7c89a8ecac
Коммит aeb1981a6c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 701 добавлений и 1401 удалений

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -20,18 +20,23 @@ class CausalGraph:
common_cause_names=None,
instrument_names=None,
effect_modifier_names=None,
mediator_names=None,
observed_node_names=None,
missing_nodes_as_confounders=False):
self.treatment_name = parse_state(treatment_name)
self.outcome_name = parse_state(outcome_name)
instrument_names = parse_state(instrument_names)
common_cause_names = parse_state(common_cause_names)
effect_modifier_names = parse_state(effect_modifier_names)
mediator_names = parse_state(mediator_names)
self.logger = logging.getLogger(__name__)
if graph is None:
self._graph = nx.DiGraph()
self._graph = self.build_graph(common_cause_names,
instrument_names, effect_modifier_names)
instrument_names,
effect_modifier_names,
mediator_names)
elif re.match(r".*\.dot", graph):
# load dot file
try:
@ -104,25 +109,27 @@ class CausalGraph:
arrowstyle="-|>",
style="dashed",
arrowsize=12)
labels = nx.draw_networkx_labels(self._graph, pos)
plt.axis('off')
plt.savefig(out_filename)
plt.draw()
def build_graph(self, common_cause_names, instrument_names, effect_modifier_names):
def build_graph(self, common_cause_names, instrument_names,
effect_modifier_names, mediator_names):
""" Creates nodes and edges based on variable names and their semantics.
Currently only considers the graphical representation of "direct" effect modifiers. Thus, all effect modifiers are assumed to be "direct" unless otherwise expressed using a graph. Based on the taxonomy of effect modifiers by VanderWeele and Robins: "Four types of effect modification: A classification based on directed acyclic graphs." Epidemiology, 2007.
"""
for treatment in self.treatment_name:
self._graph.add_node(treatment, observed="yes")
self._graph.add_node(treatment, observed="yes", penwidth=2)
for outcome in self.outcome_name:
self._graph.add_node(outcome, observed="yes")
self._graph.add_node(outcome, observed="yes", penwidth=2)
for treatment, outcome in itertools.product(self.treatment_name, self.outcome_name):
self._graph.add_edge(treatment, outcome)
# adding penwidth to make the edge bold
self._graph.add_edge(treatment, outcome, penwidth=2)
# Adding common causes
if common_cause_names is not None:
@ -153,6 +160,12 @@ class CausalGraph:
self._graph.add_node(node_name, observed="yes")
self._graph.add_edge(node_name, outcome, style = "dotted", headport="s", tailport="n")
self._graph.add_edge(outcome, node_name, style = "dotted", headport="n", tailport="s") # TODO make the ports more general so that they apply not just to top-bottom node configurations
if mediator_names is not None:
for node_name in mediator_names:
for treatment, outcome in itertools.product(self.treatment_name, self.outcome_name):
self._graph.add_node(node_name, observed="yes")
self._graph.add_edge(treatment, node_name)
self._graph.add_edge(node_name, outcome)
return self._graph
def add_node_attributes(self, observed_node_names):
@ -172,7 +185,7 @@ class CausalGraph:
self._graph.add_edge(node_name, treatment_outcome_node)
return self._graph
def add_unobserved_common_cause(self, observed_node_names):
def add_unobserved_common_cause(self, observed_node_names, color="gray"):
# Adding unobserved confounders
current_common_causes = self.get_common_causes(self.treatment_name,
self.outcome_name)
@ -182,7 +195,8 @@ class CausalGraph:
create_new_common_cause = False
if create_new_common_cause:
uc_label = "Unobserved Confounders"
self._graph.add_node('U', label=uc_label, observed="no")
self._graph.add_node('U', label=uc_label, observed="no",
color=color, style="filled", fillcolor=color)
for node in self.treatment_name + self.outcome_name:
self._graph.add_edge('U', node)
self.logger.info('If this is observed data (not from a randomized experiment), there might always be missing confounders. Adding a node named "Unobserved Confounders" to reflect this.')

Просмотреть файл

@ -63,22 +63,22 @@ class AddUnobservedCommonCause(CausalRefuter):
new_effect = new_estimator.estimate_effect()
refute = CausalRefutation(self._estimate.value, new_effect.value,
refutation_type="Refute: Add an Unobserved Common Cause")
refute.new_effect = np.array(new_effect.value)
refute.add_refuter(self)
return refute
else: # Deal with multiple value inputs
if isinstance(self.kappa_t, np.ndarray) and isinstance(self.kappa_y, np.ndarray): # Deal with range inputs
# Get a 2D matrix of values
x,y = np.meshgrid(self.kappa_t, self.kappa_y) # x,y are both MxN
results_matrix = np.random.rand(len(x),len(y)) # Matrix to hold all the results of NxM
print(results_matrix.shape)
orig_data = copy.deepcopy(self._data)
for i in range(0,len(x[0])):
for j in range(0,len(y)):
new_data = self.include_confounders_effect(orig_data, x[0][i], y[j][0])
@ -88,10 +88,10 @@ class AddUnobservedCommonCause(CausalRefuter):
refutation_type="Refute: Add an Unobserved Common Cause")
self.logger.debug(refute)
results_matrix[i][j] = refute.estimated_effect # Populate the results
fig = plt.figure(figsize=(6,5))
left, bottom, width, height = 0.1, 0.1, 0.8, 0.8
ax = fig.add_axes([left, bottom, width, height])
ax = fig.add_axes([left, bottom, width, height])
cp = plt.contourf(x, y, results_matrix)
plt.colorbar(cp)
@ -120,13 +120,13 @@ class AddUnobservedCommonCause(CausalRefuter):
fig = plt.figure(figsize=(6,5))
left, bottom, width, height = 0.1, 0.1, 0.8, 0.8
ax = fig.add_axes([left, bottom, width, height])
ax = fig.add_axes([left, bottom, width, height])
plt.plot(self.kappa_t, outcomes)
ax.set_title('Effect of Unobserved Common Cause')
ax.set_xlabel('Value of Linear Constant on Treatment')
ax.set_ylabel('New Effect')
plt.show()
plt.show()
refute.new_effect = outcomes
refute.add_refuter(self)
@ -144,7 +144,7 @@ class AddUnobservedCommonCause(CausalRefuter):
refutation_type="Refute: Add an Unobserved Common Cause")
self.logger.debug(refute)
outcomes[i] = refute.estimated_effect # Populate the results
fig = plt.figure(figsize=(6,5))
left, bottom, width, height = 0.1, 0.1, 0.8, 0.8
ax = fig.add_axes([left, bottom, width, height])
@ -153,7 +153,7 @@ class AddUnobservedCommonCause(CausalRefuter):
ax.set_title('Effect of Unobserved Common Cause')
ax.set_xlabel('Value of Linear Constant on Outcome')
ax.set_ylabel('New Effect')
plt.show()
plt.show()
refute.new_effect = outcomes
refute.add_refuter(self)
@ -161,7 +161,7 @@ class AddUnobservedCommonCause(CausalRefuter):
def include_confounders_effect(self, new_data, kappa_t, kappa_y):
"""
This function deals with the change in the value of the data due to the effect of the unobserved confounder.
This function deals with the change in the value of the data due to the effect of the unobserved confounder.
In the case of a binary flip, we flip only if the random number is greater than the threshold set.
In the case of a linear effect, we use the variable as the linear regression constant.