From 474c5844ec106b06d637b12cb7a86474d8c698bb Mon Sep 17 00:00:00 2001 From: Patrick Bloebaum Date: Mon, 17 Jun 2024 10:27:05 -0700 Subject: [PATCH] Remove existing causal mechanisms when creating GCM Before, when a causal graph had causal mechanisms assigned, they were also used when creating a new GCM object based on it. Now, they are removed (from a copied version of the graph). Signed-off-by: Patrick Bloebaum --- dowhy/gcm/causal_models.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/dowhy/gcm/causal_models.py b/dowhy/gcm/causal_models.py index bcfa2f5bc..e185126d2 100644 --- a/dowhy/gcm/causal_models.py +++ b/dowhy/gcm/causal_models.py @@ -32,12 +32,17 @@ class ProbabilisticCausalModel: causal mechanisms can be any general stochastic models.""" def __init__( - self, graph: Optional[DirectedGraph] = None, graph_copier: Callable[[DirectedGraph], DirectedGraph] = nx.DiGraph + self, + graph: Optional[DirectedGraph] = None, + graph_copier: Callable[[DirectedGraph], DirectedGraph] = nx.DiGraph, + remove_existing_mechanisms: bool = False, ): """ :param graph: Optional graph object to be used as causal graph. :param graph_copier: Optional function that can copy a causal graph. Defaults to a networkx.DiGraph constructor. + :param remove_existing_mechanisms: If True, removes existing causal mechanisms assigned to nodes if they exist. + Otherwise, does not modify graph. """ # Todo: Remove after https://github.com/py-why/dowhy/pull/943. from dowhy.causal_graph import CausalGraph @@ -50,6 +55,11 @@ class ProbabilisticCausalModel: elif isinstance(graph, CausalGraph): graph = graph_copier(graph._graph) + if remove_existing_mechanisms: + for node in graph.nodes: + if CAUSAL_MECHANISM in graph.nodes[node]: + del graph.nodes[node][CAUSAL_MECHANISM] + self.graph = graph self.graph_copier = graph_copier