Added support for effect modifiers in the non-graph causal model interface, and improved the graph-drawing functions to show effect modifiers.

This commit is contained in:
Amit Sharma 2019-12-02 14:23:03 +05:30
Родитель 081f2c3563
Коммит 98b6ed2a82
5 изменённых файлов: 246 добавлений и 172 удалений

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -7,23 +7,32 @@ import itertools
class CausalGraph:
"""Class for creating and modifying the causal graph.
Accepts a graph string (or a text file) in gml format (preferred) and dot format. Graphviz-like attributes can be set for edges and nodes. E.g. style="dashed" as an edge attribute ensures that the edge is drawn with a dashed line.
If a graph string is not given, the names of the treatment, outcome, confounders, instruments and effect modifiers (if any) can be provided to create the graph.
"""
def __init__(self,
treatment_name, outcome_name,
graph=None,
common_cause_names=None,
instrument_names=None,
effect_modifier_names=None,
observed_node_names=None,
missing_nodes_as_confounders=False):
self.treatment_name = parse_state(treatment_name)
self.outcome_name = parse_state(outcome_name)
instrument_names = parse_state(instrument_names)
common_cause_names = parse_state(common_cause_names)
effect_modifier_names = parse_state(effect_modifier_names)
self.logger = logging.getLogger(__name__)
if graph is None:
self._graph = nx.DiGraph()
self._graph = self.build_graph(common_cause_names,
instrument_names)
instrument_names, effect_modifier_names)
elif re.match(r".*\.dot", graph):
# load dot file
try:
@ -74,13 +83,35 @@ class CausalGraph:
self.logger.warning("Warning: Pygraphviz cannot be loaded. Check that graphviz and pygraphviz are installed.")
self.logger.info("Using Matplotlib for plotting")
import matplotlib.pyplot as plt
solid_edges = [(n1,n2) for n1,n2, e in self._graph.edges(data=True) if 'style' not in e ]
dashed_edges =[(n1,n2) for n1,n2, e in self._graph.edges(data=True) if ('style' in e and e['style']=="dashed") ]
plt.clf()
nx.draw_networkx(self._graph, pos=nx.shell_layout(self._graph))
pos = nx.layout.shell_layout(self._graph)
nodes = nx.draw_networkx_nodes(self._graph, pos)
edges = nx.draw_networkx_edges(
self._graph,
pos,
edgelist=solid_edges,
arrowstyle="fancy",
arrowsize=1)
edges = nx.draw_networkx_edges(
self._graph,
pos,
edgelist=dashed_edges,
arrowstyle="->",
style="dashed",
arrowsize=2)
labels = nx.draw_networkx_labels(self._graph, pos)
#nx.draw_networkx(self._graph, pos=nx.shell_layout(self._graph))
plt.axis('off')
plt.savefig(out_filename)
plt.draw()
def build_graph(self, common_cause_names, instrument_names):
def build_graph(self, common_cause_names, instrument_names, effect_modifier_names):
""" Creates nodes and edges based on variable names and their semantics.
Currently only considers the graphical representation of "direct" effect modifiers. Thus, all effect modifiers are assumed to be "direct" unless otherwise expressed using a graph. Based on the taxonomy of effect modifiers by VanderWeele and Robins, "Four types of effect modification: A classification based on directed acyclic graphs", Epidemiology, 2007.
"""
for treatment in self.treatment_name:
self._graph.add_node(treatment, observed="yes")
for outcome in self.outcome_name:
@ -95,6 +126,7 @@ class CausalGraph:
self._graph.add_node(node_name, observed="yes")
self._graph.add_edge(node_name, treatment)
self._graph.add_edge(node_name, outcome)
# Adding instruments
if instrument_names:
if type(instrument_names[0]) != tuple:
@ -107,6 +139,14 @@ class CausalGraph:
for instrument, treatment in itertools.product(instrument_names):
self._graph.add_node(instrument, observed="yes")
self._graph.add_edge(instrument, treatment)
# Adding effect modifiers
if effect_modifier_names is not None:
for node_name in effect_modifier_names:
for outcome in self.outcome_name:
self._graph.add_node(node_name, observed="yes")
self._graph.add_edge(node_name, outcome, style = "dotted", headport="s", tailport="n")
self._graph.add_edge(outcome, node_name, style = "dotted", headport="n", tailport="s") # TODO make the ports more general so that they apply not just to top-bottom node configurations
return self._graph
def add_node_attributes(self, observed_node_names):

Просмотреть файл

@ -25,7 +25,9 @@ class CausalModel:
"""
def __init__(self, data, treatment, outcome, graph=None,
common_causes=None, instruments=None, estimand_type="ate",
common_causes=None, instruments=None,
effect_modifiers=None,
estimand_type="ate",
proceed_when_unidentifiable=False,
missing_nodes_as_confounders=False,
**kwargs):
@ -46,6 +48,7 @@ class CausalModel:
:param common_causes: names of common causes of treatment and outcome
:param instruments: names of instrumental variables for the effect of
treatment on outcome
:param effect_modifiers: names of variables that can modify the treatment effect (useful for heterogeneous treatment effect estimation)
:returns: an instance of CausalModel class
"""
@ -67,12 +70,14 @@ class CausalModel:
self.logger.warning("Causal Graph not provided. DoWhy will construct a graph based on data inputs.")
self._common_causes = parse_state(common_causes)
self._instruments = parse_state(instruments)
self._effect_modifiers = parse_state(effect_modifiers)
if common_causes is not None and instruments is not None:
self._graph = CausalGraph(
self._treatment,
self._outcome,
common_cause_names=self._common_causes,
instrument_names=self._instruments,
effect_modifier_names = self._effect_modifiers,
observed_node_names=self._data.columns.tolist()
)
elif common_causes is not None:
@ -80,6 +85,7 @@ class CausalModel:
self._treatment,
self._outcome,
common_cause_names=self._common_causes,
effect_modifier_names = self._effect_modifiers,
observed_node_names=self._data.columns.tolist()
)
elif instruments is not None:
@ -87,6 +93,7 @@ class CausalModel:
self._treatment,
self._outcome,
instrument_names=self._instruments,
effect_modifier_names = self._effect_modifiers,
observed_node_names=self._data.columns.tolist()
)
else:

Просмотреть файл

@ -105,6 +105,7 @@ def linear_dataset(beta, num_common_causes, num_samples, num_instruments=0,
"outcome_name": outcome,
"common_causes_names": common_causes,
"instrument_names": instruments,
"effect_modifier_names": effect_modifiers,
"dot_graph": dot_graph,
"gml_graph": gml_graph,
"ate": ate