diff --git a/dowhy/gcm/causal_models.py b/dowhy/gcm/causal_models.py index bcfa2f5bc..e185126d2 100644 --- a/dowhy/gcm/causal_models.py +++ b/dowhy/gcm/causal_models.py @@ -32,12 +32,17 @@ class ProbabilisticCausalModel: causal mechanisms can be any general stochastic models.""" def __init__( - self, graph: Optional[DirectedGraph] = None, graph_copier: Callable[[DirectedGraph], DirectedGraph] = nx.DiGraph + self, + graph: Optional[DirectedGraph] = None, + graph_copier: Callable[[DirectedGraph], DirectedGraph] = nx.DiGraph, + remove_existing_mechanisms: bool = False, ): """ :param graph: Optional graph object to be used as causal graph. :param graph_copier: Optional function that can copy a causal graph. Defaults to a networkx.DiGraph constructor. + :param remove_existing_mechanisms: If True, removes existing causal mechanisms assigned to nodes if they exist. + Otherwise, does not modify graph. """ # Todo: Remove after https://github.com/py-why/dowhy/pull/943. from dowhy.causal_graph import CausalGraph @@ -50,6 +55,11 @@ class ProbabilisticCausalModel: elif isinstance(graph, CausalGraph): graph = graph_copier(graph._graph) + if remove_existing_mechanisms: + for node in graph.nodes: + if CAUSAL_MECHANISM in graph.nodes[node]: + del graph.nodes[node][CAUSAL_MECHANISM] + self.graph = graph self.graph_copier = graph_copier