added support for effect modifiers in the non-graph causal model interface and improved the graph drawing functions to show effect modifiers
This commit is contained in:
Родитель
081f2c3563
Коммит
98b6ed2a82
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -7,23 +7,32 @@ import itertools
|
|||
|
||||
class CausalGraph:
|
||||
|
||||
"""Class for creating and modifying the causal graph.
|
||||
|
||||
Accepts a graph string (or a text file) in gml format (preferred) or dot format. Graphviz-like attributes can be set for edges and nodes. For example, setting style="dashed" as an edge attribute ensures that the edge is drawn with a dashed line.
|
||||
|
||||
If a graph string is not given, the names of the treatment, outcome, confounders, instruments, and effect modifiers (if any) can be provided to create the graph.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
treatment_name, outcome_name,
|
||||
graph=None,
|
||||
common_cause_names=None,
|
||||
instrument_names=None,
|
||||
effect_modifier_names=None,
|
||||
observed_node_names=None,
|
||||
missing_nodes_as_confounders=False):
|
||||
self.treatment_name = parse_state(treatment_name)
|
||||
self.outcome_name = parse_state(outcome_name)
|
||||
instrument_names = parse_state(instrument_names)
|
||||
common_cause_names = parse_state(common_cause_names)
|
||||
effect_modifier_names = parse_state(effect_modifier_names)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
if graph is None:
|
||||
self._graph = nx.DiGraph()
|
||||
self._graph = self.build_graph(common_cause_names,
|
||||
instrument_names)
|
||||
instrument_names, effect_modifier_names)
|
||||
elif re.match(r".*\.dot", graph):
|
||||
# load dot file
|
||||
try:
|
||||
|
@ -74,13 +83,35 @@ class CausalGraph:
|
|||
self.logger.warning("Warning: Pygraphviz cannot be loaded. Check that graphviz and pygraphviz are installed.")
|
||||
self.logger.info("Using Matplotlib for plotting")
|
||||
import matplotlib.pyplot as plt
|
||||
solid_edges = [(n1,n2) for n1,n2, e in self._graph.edges(data=True) if 'style' not in e ]
|
||||
dashed_edges =[(n1,n2) for n1,n2, e in self._graph.edges(data=True) if ('style' in e and e['style']=="dashed") ]
|
||||
plt.clf()
|
||||
nx.draw_networkx(self._graph, pos=nx.shell_layout(self._graph))
|
||||
pos = nx.layout.shell_layout(self._graph)
|
||||
nodes = nx.draw_networkx_nodes(self._graph, pos)
|
||||
edges = nx.draw_networkx_edges(
|
||||
self._graph,
|
||||
pos,
|
||||
edgelist=solid_edges,
|
||||
arrowstyle="fancy",
|
||||
arrowsize=1)
|
||||
edges = nx.draw_networkx_edges(
|
||||
self._graph,
|
||||
pos,
|
||||
edgelist=dashed_edges,
|
||||
arrowstyle="->",
|
||||
style="dashed",
|
||||
arrowsize=2)
|
||||
labels = nx.draw_networkx_labels(self._graph, pos)
|
||||
#nx.draw_networkx(self._graph, pos=nx.shell_layout(self._graph))
|
||||
plt.axis('off')
|
||||
plt.savefig(out_filename)
|
||||
plt.draw()
|
||||
|
||||
def build_graph(self, common_cause_names, instrument_names):
|
||||
def build_graph(self, common_cause_names, instrument_names, effect_modifier_names):
|
||||
""" Creates nodes and edges based on variable names and their semantics.
|
||||
Currently only considers the graphical representation of "direct" effect modifiers. Thus, all effect modifiers are assumed to be "direct" unless otherwise expressed using a graph. Based on the taxonomy of effect modifiers by VanderWeele and Robins: "Four types of effect modification: A classification based on directed acyclic graphs." Epidemiology, 2007.
|
||||
"""
|
||||
|
||||
for treatment in self.treatment_name:
|
||||
self._graph.add_node(treatment, observed="yes")
|
||||
for outcome in self.outcome_name:
|
||||
|
@ -95,6 +126,7 @@ class CausalGraph:
|
|||
self._graph.add_node(node_name, observed="yes")
|
||||
self._graph.add_edge(node_name, treatment)
|
||||
self._graph.add_edge(node_name, outcome)
|
||||
|
||||
# Adding instruments
|
||||
if instrument_names:
|
||||
if type(instrument_names[0]) != tuple:
|
||||
|
@ -107,6 +139,14 @@ class CausalGraph:
|
|||
for instrument, treatment in itertools.product(instrument_names):
|
||||
self._graph.add_node(instrument, observed="yes")
|
||||
self._graph.add_edge(instrument, treatment)
|
||||
|
||||
# Adding effect modifiers
|
||||
if effect_modifier_names is not None:
|
||||
for node_name in effect_modifier_names:
|
||||
for outcome in self.outcome_name:
|
||||
self._graph.add_node(node_name, observed="yes")
|
||||
self._graph.add_edge(node_name, outcome, style = "dotted", headport="s", tailport="n")
|
||||
self._graph.add_edge(outcome, node_name, style = "dotted", headport="n", tailport="s") # TODO make the ports more general so that they apply not just to top-bottom node configurations
|
||||
return self._graph
|
||||
|
||||
def add_node_attributes(self, observed_node_names):
|
||||
|
|
|
@ -25,7 +25,9 @@ class CausalModel:
|
|||
"""
|
||||
|
||||
def __init__(self, data, treatment, outcome, graph=None,
|
||||
common_causes=None, instruments=None, estimand_type="ate",
|
||||
common_causes=None, instruments=None,
|
||||
effect_modifiers=None,
|
||||
estimand_type="ate",
|
||||
proceed_when_unidentifiable=False,
|
||||
missing_nodes_as_confounders=False,
|
||||
**kwargs):
|
||||
|
@ -46,6 +48,7 @@ class CausalModel:
|
|||
:param common_causes: names of common causes of treatment and outcome
|
||||
:param instruments: names of instrumental variables for the effect of
|
||||
treatment on outcome
|
||||
:param effect_modifiers: names of variables that can modify the treatment effect (useful for heterogeneous treatment effect estimation)
|
||||
:returns: an instance of CausalModel class
|
||||
|
||||
"""
|
||||
|
@ -67,12 +70,14 @@ class CausalModel:
|
|||
self.logger.warning("Causal Graph not provided. DoWhy will construct a graph based on data inputs.")
|
||||
self._common_causes = parse_state(common_causes)
|
||||
self._instruments = parse_state(instruments)
|
||||
self._effect_modifiers = parse_state(effect_modifiers)
|
||||
if common_causes is not None and instruments is not None:
|
||||
self._graph = CausalGraph(
|
||||
self._treatment,
|
||||
self._outcome,
|
||||
common_cause_names=self._common_causes,
|
||||
instrument_names=self._instruments,
|
||||
effect_modifier_names = self._effect_modifiers,
|
||||
observed_node_names=self._data.columns.tolist()
|
||||
)
|
||||
elif common_causes is not None:
|
||||
|
@ -80,6 +85,7 @@ class CausalModel:
|
|||
self._treatment,
|
||||
self._outcome,
|
||||
common_cause_names=self._common_causes,
|
||||
effect_modifier_names = self._effect_modifiers,
|
||||
observed_node_names=self._data.columns.tolist()
|
||||
)
|
||||
elif instruments is not None:
|
||||
|
@ -87,6 +93,7 @@ class CausalModel:
|
|||
self._treatment,
|
||||
self._outcome,
|
||||
instrument_names=self._instruments,
|
||||
effect_modifier_names = self._effect_modifiers,
|
||||
observed_node_names=self._data.columns.tolist()
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -105,6 +105,7 @@ def linear_dataset(beta, num_common_causes, num_samples, num_instruments=0,
|
|||
"outcome_name": outcome,
|
||||
"common_causes_names": common_causes,
|
||||
"instrument_names": instruments,
|
||||
"effect_modifier_names": effect_modifiers,
|
||||
"dot_graph": dot_graph,
|
||||
"gml_graph": gml_graph,
|
||||
"ate": ate
|
||||
|
|
Загрузка…
Ссылка в новой задаче