Remove existing causal mechanisms when creating GCM

Before, when a causal graph had causal mechanisms assigned, they were also used when creating a new GCM object based on it. Now, they are removed (from a copied version of the graph).

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
This commit is contained in:
Patrick Bloebaum 2024-06-17 10:27:05 -07:00 коммит произвёл Patrick Blöbaum
Родитель 57fd684450
Коммит 474c5844ec
1 изменённых файлов: 11 добавлений и 1 удалений

Просмотреть файл

@ -32,12 +32,17 @@ class ProbabilisticCausalModel:
causal mechanisms can be any general stochastic models."""
def __init__(
self, graph: Optional[DirectedGraph] = None, graph_copier: Callable[[DirectedGraph], DirectedGraph] = nx.DiGraph
self,
graph: Optional[DirectedGraph] = None,
graph_copier: Callable[[DirectedGraph], DirectedGraph] = nx.DiGraph,
remove_existing_mechanisms: bool = False,
):
"""
:param graph: Optional graph object to be used as causal graph.
:param graph_copier: Optional function that can copy a causal graph. Defaults to a networkx.DiGraph
constructor.
:param remove_existing_mechanisms: If True, removes existing causal mechanisms assigned to nodes if they exist.
Otherwise, does not modify graph.
"""
# Todo: Remove after https://github.com/py-why/dowhy/pull/943.
from dowhy.causal_graph import CausalGraph
@ -50,6 +55,11 @@ class ProbabilisticCausalModel:
elif isinstance(graph, CausalGraph):
graph = graph_copier(graph._graph)
if remove_existing_mechanisms:
for node in graph.nodes:
if CAUSAL_MECHANISM in graph.nodes[node]:
del graph.nodes[node][CAUSAL_MECHANISM]
self.graph = graph
self.graph_copier = graph_copier