Remove existing causal mechanisms when creating GCM
Before, when a causal graph had causal mechanisms assigned, they were also used when creating a new GCM object based on it. Now, they are removed (from a copied version of the graph). Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
This commit is contained in:
Родитель
57fd684450
Коммит
474c5844ec
|
@ -32,12 +32,17 @@ class ProbabilisticCausalModel:
|
|||
causal mechanisms can be any general stochastic models."""
|
||||
|
||||
def __init__(
|
||||
self, graph: Optional[DirectedGraph] = None, graph_copier: Callable[[DirectedGraph], DirectedGraph] = nx.DiGraph
|
||||
self,
|
||||
graph: Optional[DirectedGraph] = None,
|
||||
graph_copier: Callable[[DirectedGraph], DirectedGraph] = nx.DiGraph,
|
||||
remove_existing_mechanisms: bool = False,
|
||||
):
|
||||
"""
|
||||
:param graph: Optional graph object to be used as causal graph.
|
||||
:param graph_copier: Optional function that can copy a causal graph. Defaults to a networkx.DiGraph
|
||||
constructor.
|
||||
:param remove_existing_mechanisms: If True, removes existing causal mechanisms assigned to nodes if they exist.
|
||||
Otherwise, does not modify graph.
|
||||
"""
|
||||
# Todo: Remove after https://github.com/py-why/dowhy/pull/943.
|
||||
from dowhy.causal_graph import CausalGraph
|
||||
|
@ -50,6 +55,11 @@ class ProbabilisticCausalModel:
|
|||
elif isinstance(graph, CausalGraph):
|
||||
graph = graph_copier(graph._graph)
|
||||
|
||||
if remove_existing_mechanisms:
|
||||
for node in graph.nodes:
|
||||
if CAUSAL_MECHANISM in graph.nodes[node]:
|
||||
del graph.nodes[node][CAUSAL_MECHANISM]
|
||||
|
||||
self.graph = graph
|
||||
self.graph_copier = graph_copier
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче