Effect estimation over timeseries data (#1218)
* Library functions for temporal causal functionality
* shifting plotter function
* printing graph: best practices
* added docstrings
* moved datasets
* updated tutorial notebook
* sphinx documentation
* updated shifting columns with 0,1,..,max_lag
* support for dot format
* tigramite support
* updated filter to be a hidden function
* black and isort utils
* black and isort timeseries
* updated notebook text

Signed-off-by: Amit Sharma <amit_sharma@live.com>

* integer range fix
* correction in timestamp : notebook text
* time lagged causal estimation
* removed cell outputs
* find ancestors
* include ancestors in notebook
* formatting changes
* comments : notebook
* multiple time lags : csv graph
* multiple time lags
* unrolled graph using bfs
* cleanup of functions
* removed find parents and ancestors
* tests for causal graph creation
* tests for adding lagged edges
* tests for shifting columns
* tigramite dependency added

---------

Signed-off-by: Amit Sharma <amit_sharma@live.com>
Co-authored-by: Amit Sharma <amit_sharma@live.com>
Parent
e783e37db0
Commit
becbf7f502
|
@ -0,0 +1,15 @@
|
|||
V1,V2,V3,V4,V5,V6,V7
|
||||
1,2,3,4,5,6,7
|
||||
2,3,4,5,6,7,8
|
||||
3,4,5,6,7,8,9
|
||||
4,5,6,7,8,9,10
|
||||
0,1,5,7,8,9,7
|
||||
3,5,4,1,2,6,5
|
||||
6,7,1,2,4,5,9
|
||||
12,3,5,7,3,8,9
|
||||
3,2,1,6,3,8,9
|
||||
4,6,3,5,8,9,1
|
||||
3,5,9,6,2,1,3
|
||||
5,2,6,8,11,3,4
|
||||
2,2,4,1,1,4,6
|
||||
5,6,4,3,4,6,2
|
|
|
@ -0,0 +1,8 @@
|
|||
node1,node2,time_lag
|
||||
V1,V2,3
|
||||
V2,V3,4
|
||||
V5,V6,1
|
||||
V4,V7,4
|
||||
V4,V5,2
|
||||
V7,V6,3
|
||||
V7,V6,5
|
|
|
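As a rough illustration of how the rows of this edge list turn into edges, the sketch below groups rows by `(node1, node2)` and collects their lags into a tuple, which is the shape `create_graph_from_csv` stores under `time_lag`. The helper name `collect_time_lags` is hypothetical, and the sketch uses only the standard library instead of networkx, so it is not the library's implementation.

```python
import csv
from io import StringIO

# Rows copied from the temporal_graph.csv file added in this commit.
CSV_TEXT = """node1,node2,time_lag
V1,V2,3
V2,V3,4
V5,V6,1
V4,V7,4
V4,V5,2
V7,V6,3
V7,V6,5
"""


def collect_time_lags(csv_text):
    """Group rows by (node1, node2) and collect their lags in file order.

    Hypothetical helper: a plain-dict stand-in for the edge attributes.
    """
    lags = {}
    for row in csv.DictReader(StringIO(csv_text)):
        edge = (row["node1"], row["node2"])
        # Repeated edges extend the existing tuple instead of overwriting it.
        lags[edge] = lags.get(edge, ()) + (int(row["time_lag"]),)
    return lags


edge_lags = collect_time_lags(CSV_TEXT)
print(edge_lags[("V7", "V6")])
```

The two `V7,V6` rows end up as a single edge with two lags, which is why the file has seven rows but describes six edges.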
@ -0,0 +1,8 @@
|
|||
digraph G {
|
||||
V1 -> V2 [label="(3)"];
|
||||
V2 -> V3 [label="(4)"];
|
||||
V5 -> V6 [label="(1)"];
|
||||
V4 -> V7 [label="(4)"];
|
||||
V4 -> V5 [label="(2)"];
|
||||
V7 -> V6 [label="(3, 5)"];
|
||||
}
|
|
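The edge labels in this DOT file, such as `"(3, 5)"`, encode one or more lags. The sketch below mirrors the parsing expression used in `create_graph_from_dot_format` and shows one order-preserving way to merge lags from parallel edges. The names `parse_lag_label` and `merge_lags` are hypothetical and exist only for this illustration.

```python
def parse_lag_label(label):
    """Turn a label such as "(3, 5)" into a tuple of ints, here (3, 5).

    Raises ValueError when a part is not an integer, e.g. "(five)".
    """
    return tuple(int(part) for part in label.strip("()").split(","))


def merge_lags(existing, new):
    """Concatenate two lag tuples, dropping duplicates but keeping order."""
    return tuple(dict.fromkeys(existing + new))


single = parse_lag_label("(4)")
multiple = parse_lag_label("(3, 5)")
merged = merge_lags(multiple, (5, 1))
print(single, multiple, merged)
```

`int()` tolerates surrounding whitespace, so `"(3, 5)"` and `"(3,5)"` parse to the same tuple.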
@ -0,0 +1,306 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Effect inference with timeseries data\n",
|
||||
"\n",
|
||||
"In this notebook, we will look at an example of causal effect inference from timeseries data. We will use DoWhy's functionality to add temporal dependencies to a causal graph and estimate the causal effect based on the augmented graph."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import networkx as nx\n",
|
||||
"import pandas as pd\n",
|
||||
"from dowhy.utils.timeseries import create_graph_from_csv, create_graph_from_user\n",
|
||||
"from dowhy.utils.plotting import plot, pretty_print_graph"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Loading timeseries data and causal graph"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset_path=\"../datasets/temporal_dataset.csv\"\n",
|
||||
"\n",
|
||||
"dataframe=pd.read_csv(dataset_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In temporal causal inference, accurately estimating causal effects often requires accounting for time lags between nodes in a graph. For instance, if $node_1$ influences $node_2$ with a time lag of 5 timestamps, we represent this dependency as $node_1^{t-5}$ -> $node_2^{t}$.\n",
|
||||
"\n",
|
||||
"We can provide the causal graph as a networkx DAG or as a dot file. Each edge should carry the exact time lag or lags associated with it (if any): as a `time_lag` edge attribute in networkx, or as an edge `label` such as `\"(3, 5)\"` in a dot file."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dowhy.utils.timeseries import create_graph_from_dot_format\n",
|
||||
"\n",
|
||||
"file_path = \"../datasets/temporal_graph.dot\"\n",
|
||||
"\n",
|
||||
"graph = create_graph_from_dot_format(file_path)\n",
|
||||
"plot(graph)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can also create a csv file with the edges in the temporal graph. The columns in the csv are node1, node2, time_lag, where each row represents a directed edge node1 -> node2 with a time lag of time_lag. Let us consider the following graph as the input:\n",
|
||||
"\n",
|
||||
"| node1 | node2 | time_lag |\n",
|
||||
"|--------|--------|----------|\n",
|
||||
"| V1 | V2 | 3 |\n",
|
||||
"| V2 | V3 | 4 |\n",
|
||||
"| V5 | V6 | 1 |\n",
|
||||
"| V4 | V7 | 4 |\n",
|
||||
"| V4 | V5 | 2 |\n",
|
||||
"| V7 | V6 | 3 |\n",
|
||||
"| V7 | V6 | 5 |"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Input a csv file with the edges in the graph with the columns: node1,node2,time_lag\n",
|
||||
"file_path = \"../datasets/temporal_graph.csv\"\n",
|
||||
"\n",
|
||||
"# Create the graph from the CSV file\n",
|
||||
"graph = create_graph_from_csv(file_path)\n",
|
||||
"plot(graph)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Dataset Shifting and Filtering\n",
|
||||
"\n",
|
||||
"To prepare the dataset for temporal causal inference, we need to shift the columns by the given time lag.\n",
|
||||
"\n",
|
||||
"For example, consider the dependency $node_1^{t-5}$ -> $node_2^{t}$, which has a lag of 5. When $node_2$ is the target node, the data for $node_1$ should be shifted down by 5 timestamps, so that each row pairs the value of $node_2$ at time $t$ with the value of $node_1$ at time $t-5$. Shifting the data in this manner creates additional columns and allows downstream estimators to access the correct values in the same row of a dataframe."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dowhy.timeseries.temporal_shift import shift_columns_by_lag_using_unrolled_graph, add_lagged_edges"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# the outcome node for which the effect is to be estimated: V6\n",
|
||||
"target_node = 'V6'\n",
|
||||
"unrolled_graph = add_lagged_edges(graph, target_node)\n",
|
||||
"plot(unrolled_graph)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"time_shifted_df = shift_columns_by_lag_using_unrolled_graph(dataframe, unrolled_graph)\n",
|
||||
"time_shifted_df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Causal Effect Estimation\n",
|
||||
"\n",
|
||||
"Once you have the new dataframe, causal effect estimation can be performed on the target node with respect to the action nodes."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"target_node = 'V6_0'\n",
|
||||
"# include all the treatments\n",
|
||||
"treatment_columns = list(time_shifted_df.columns)\n",
|
||||
"treatment_columns.remove(target_node)\n",
|
||||
"treatment_columns"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# perform causal effect estimation on this new dataset\n",
|
||||
"import dowhy\n",
|
||||
"from dowhy import CausalModel\n",
|
||||
"\n",
|
||||
"model = CausalModel(\n",
|
||||
" data=time_shifted_df,\n",
|
||||
" treatment='V5_-1',\n",
|
||||
" outcome=target_node,\n",
|
||||
" graph = unrolled_graph\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"identified_estimand = model.identify_effect()\n",
|
||||
"\n",
|
||||
"estimate = model.estimate_effect(identified_estimand,\n",
|
||||
" method_name=\"backdoor.linear_regression\",\n",
|
||||
" test_significance=True)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(estimate)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Importing a temporal causal graph from the Tigramite library\n",
|
||||
"\n",
|
||||
"Tigramite is a popular temporal causal discovery library. In this section, we show how a causal graph can be obtained by applying the PCMCI+ algorithm from Tigramite and then imported into DoWhy."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install tigramite"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tigramite\n",
|
||||
"import tigramite.data_processing as pp\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"# convert the dataframe values to float\n",
|
||||
"dataframe = dataframe.astype(float)\n",
|
||||
"var_names = dataframe.columns\n",
|
||||
"dataframe = pp.DataFrame(dataframe.values, var_names=var_names)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tigramite import plotting as tp\n",
|
||||
"tp.plot_timeseries(dataframe, figsize=(15, 5)); plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tigramite.pcmci import PCMCI\n",
|
||||
"from tigramite.independence_tests.parcorr import ParCorr\n",
|
||||
"import numpy as np\n",
|
||||
"parcorr = ParCorr(significance='analytic')\n",
|
||||
"pcmci = PCMCI(\n",
|
||||
" dataframe=dataframe, \n",
|
||||
" cond_ind_test=parcorr,\n",
|
||||
" verbosity=1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"correlations = pcmci.run_bivci(tau_max=3, val_only=True)['val_matrix']\n",
|
||||
"matrix_lags = np.argmax(np.abs(correlations), axis=2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tau_max = 3\n",
|
||||
"pc_alpha = None\n",
|
||||
"pcmci.verbosity = 2\n",
|
||||
"\n",
|
||||
"results = pcmci.run_pcmciplus(tau_min=0, tau_max=tau_max, pc_alpha=pc_alpha)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dowhy.utils.timeseries import create_graph_from_networkx_array\n",
|
||||
"\n",
|
||||
"graph = create_graph_from_networkx_array(results['graph'], var_names)\n",
|
||||
"\n",
|
||||
"plot(graph)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
|
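The shifting step described in the notebook can be sketched without pandas. The code below assumes lagged column names of the form `base_-lag` (for example `V5_-1`), as used in the notebook, and pads the top of a shifted column with 0, which is what `shift(lag, fill_value=0)` does in the library function. The helper names `split_lagged_name` and `shift_down` are hypothetical.

```python
def split_lagged_name(node):
    """Split a name such as "V5_-1" into its base column and a positive lag."""
    base, lag_str = node.rsplit("_", 1)
    return base, -int(lag_str)


def shift_down(values, lag, fill_value=0):
    """Move values `lag` rows down, padding the top with fill_value.

    Assumes lag >= 0; the result has the same length as the input.
    """
    if lag == 0:
        return list(values)
    return [fill_value] * min(lag, len(values)) + list(values[:-lag])


# First five values of column V5 from temporal_dataset.csv.
columns = {"V5": [5, 6, 7, 8, 8]}

shifted = {}
for node in ["V5_0", "V5_-1"]:
    base, lag = split_lagged_name(node)
    shifted[node] = shift_down(columns[base], lag)

print(shifted)
```

After the shift, row `t` of `V5_-1` holds the value that `V5` had at row `t - 1`, so an estimator can read both from the same row.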
@ -0,0 +1,102 @@
|
|||
from collections import deque
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def add_lagged_edges(graph: nx.DiGraph, start_node: str) -> nx.DiGraph:
|
||||
"""
|
||||
Perform a reverse BFS starting from the node and proceed to parents level-wise,
|
||||
adding edges from the ancestor to the current node with the accumulated time lag if it does not already exist.
|
||||
Additionally, create lagged nodes for each time lag encountered.
|
||||
|
||||
:param graph: The directed graph object.
|
||||
:type graph: networkx.DiGraph
|
||||
:param start_node: The node from which to start the reverse BFS.
|
||||
:type start_node: string
|
||||
:return: A new graph with added edges based on accumulated time lags and lagged nodes.
|
||||
:rtype: networkx.DiGraph
|
||||
"""
|
||||
new_graph = nx.DiGraph()
|
||||
queue = deque([start_node])
|
||||
lagged_node_mapping = {} # Maps original nodes to their corresponding lagged nodes
|
||||
|
||||
while queue:
|
||||
current_node = queue.popleft()
|
||||
|
||||
for parent in graph.predecessors(current_node):
|
||||
edge_data = graph.get_edge_data(parent, current_node)
|
||||
if "time_lag" in edge_data:
|
||||
parent_time_lag = edge_data["time_lag"]
|
||||
|
||||
# Ensure parent_time_lag is in tuple form
|
||||
if not isinstance(parent_time_lag, tuple):
|
||||
parent_time_lag = (parent_time_lag,)
|
||||
|
||||
for lag in parent_time_lag:
|
||||
# Find or create the lagged node for the current node
|
||||
if current_node in lagged_node_mapping:
|
||||
lagged_nodes = lagged_node_mapping[current_node]
|
||||
else:
|
||||
lagged_nodes = set()
|
||||
lagged_nodes.add(f"{current_node}_0")
|
||||
new_graph.add_node(f"{current_node}_0")
|
||||
lagged_node_mapping[current_node] = lagged_nodes
|
||||
|
||||
# For each lagged node, create new time-lagged parent nodes and add edges
|
||||
new_lagged_nodes = set()
|
||||
for lagged_node in lagged_nodes:
|
||||
total_lag = -int(lagged_node.split("_")[-1]) + lag
|
||||
new_lagged_parent_node = f"{parent}_{-total_lag}"
|
||||
new_lagged_nodes.add(new_lagged_parent_node)
|
||||
|
||||
if not new_graph.has_node(new_lagged_parent_node):
|
||||
new_graph.add_node(new_lagged_parent_node)
|
||||
|
||||
new_graph.add_edge(new_lagged_parent_node, lagged_node)
|
||||
|
||||
# Add the parent to the queue for further exploration
|
||||
queue.append(parent)
|
||||
|
||||
# append the lagged nodes
|
||||
if parent in lagged_node_mapping:
|
||||
lagged_node_mapping[parent] = lagged_node_mapping[parent].union(new_lagged_nodes)
|
||||
else:
|
||||
lagged_node_mapping[parent] = new_lagged_nodes
|
||||
|
||||
for original_node, lagged_nodes in lagged_node_mapping.items():
|
||||
sorted_lagged_nodes = sorted(lagged_nodes, key=lambda x: int(x.split("_")[-1]))
|
||||
for i in range(len(sorted_lagged_nodes) - 1):
|
||||
lesser_lagged_node = sorted_lagged_nodes[i]
|
||||
more_lagged_node = sorted_lagged_nodes[i + 1]
|
||||
new_graph.add_edge(lesser_lagged_node, more_lagged_node)
|
||||
|
||||
return new_graph
|
||||
|
||||
|
||||
def shift_columns_by_lag_using_unrolled_graph(df: pd.DataFrame, unrolled_graph: nx.DiGraph) -> pd.DataFrame:
|
||||
"""
|
||||
Given a dataframe and an unrolled graph, this function shifts the columns in the dataframe by the corresponding time lags mentioned in the node names of the unrolled graph,
|
||||
creating a new unique column for each shifted version.
|
||||
|
||||
:param df: The dataframe to shift.
|
||||
:type df: pandas.DataFrame
|
||||
:param unrolled_graph: The unrolled graph with nodes containing time lags in their names.
|
||||
:type unrolled_graph: networkx.DiGraph
|
||||
:return: The dataframe with the columns shifted by the corresponding time lags.
|
||||
:rtype: pandas.DataFrame
|
||||
"""
|
||||
new_df = pd.DataFrame()
|
||||
for node in unrolled_graph.nodes:
|
||||
if "_" in node:
|
||||
base_node, lag_str = node.rsplit("_", 1)
|
||||
try:
|
||||
lag = -int(lag_str)
|
||||
if base_node in df.columns:
|
||||
new_column_name = f"{base_node}_{-lag}"
|
||||
new_df[new_column_name] = df[base_node].shift(lag, axis=0, fill_value=0)
|
||||
except ValueError:
|
||||
print(f"Warning: Cannot extract lag from node name {node}. Expected format 'baseNode_lag'")
|
||||
|
||||
return new_df
|
|
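To make the unrolling idea easier to follow, here is a simplified sketch of a reverse breadth-first search that accumulates lags from a target node back through its ancestors. It is not the implementation of `add_lagged_edges`: it only computes which lagged copies of each variable are needed, it represents the graph as a plain dict (`parents[child][parent] = tuple of lags`), and it assumes the graph is acyclic. The name `unroll_lags` is hypothetical.

```python
from collections import deque


def unroll_lags(parents, target):
    """Reverse BFS from target, returning {node: set of accumulated lags}."""
    lags = {target: {0}}
    queue = deque([(target, 0)])
    while queue:
        node, lag_so_far = queue.popleft()
        for parent, edge_lags in parents.get(node, {}).items():
            for lag in edge_lags:
                total = lag_so_far + lag
                # Visit each (parent, accumulated lag) pair only once.
                if total not in lags.setdefault(parent, set()):
                    lags[parent].add(total)
                    queue.append((parent, total))
    return lags


# Ancestors of V6 in the temporal graph added by this commit.
parents = {
    "V6": {"V5": (1,), "V7": (3, 5)},
    "V5": {"V4": (2,)},
    "V7": {"V4": (4,)},
}
lags = unroll_lags(parents, "V6")
names = sorted(f"{node}_{-lag}" for node, node_lags in lags.items() for lag in node_lags)
print(names)
```

Here `V4` reaches `V6` along three lagged paths (2 + 1, 4 + 3 and 4 + 5), so it needs three shifted copies.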
@ -199,3 +199,17 @@ def _plot_as_pyplot_figure(pygraphviz_graph: Any, figure_size: Optional[Tuple[in
|
|||
|
||||
if figure_size is not None:
|
||||
plt.rcParams["figure.figsize"] = org_fig_size
|
||||
|
||||
|
||||
def pretty_print_graph(graph: nx.DiGraph) -> None:
|
||||
"""
|
||||
Pretty print the graph edges with time lags.
|
||||
|
||||
:param graph: The networkx graph.
|
||||
:type graph: networkx.DiGraph
|
||||
:return: None
|
||||
:rtype: None
|
||||
"""
|
||||
print("\nGraph edges with time lags:")
|
||||
for edge in graph.edges(data=True):
|
||||
print(f"{edge[0]} -> {edge[1]} with time-lagged dependency {edge[2]['time_lag']}")
|
||||
|
|
|
@ -0,0 +1,233 @@
|
|||
import networkx as nx
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def create_graph_from_user() -> nx.DiGraph:
|
||||
"""
|
||||
Creates a directed graph based on user input from the console.
|
||||
|
||||
The time_lag parameter of the networkx graph represents the exact causal lag of an edge between any 2 nodes in the graph.
|
||||
Each edge can contain multiple time lags, therefore each combination of (node1,node2,time_lag) must be input individually by the user.
|
||||
|
||||
The user is prompted to enter edges one by one in the format 'node1 node2 time_lag',
|
||||
where 'node1' and 'node2' are the nodes connected by the edge, and 'time_lag' is a numerical
|
||||
value giving the time lag of the edge (it must parse as an integer). The user should enter 'done' to finish inputting edges.
|
||||
|
||||
:return: A directed graph created from the user's input.
|
||||
:rtype: nx.DiGraph
|
||||
|
||||
Example user input:
|
||||
Enter an edge: A B 4
|
||||
Enter an edge: B C 2
|
||||
Enter an edge: done
|
||||
"""
|
||||
graph = nx.DiGraph()
|
||||
|
||||
print("Enter the graph as a list of edges with time lags. Enter 'done' when you are finished.")
|
||||
print("Each edge should be entered in the format 'node1 node2 time_lag'. For example: 'A B 4'")
|
||||
|
||||
while True:
|
||||
edge = input("Enter an edge: ")
|
||||
if edge.lower() == "done":
|
||||
break
|
||||
edge = edge.split()
|
||||
if len(edge) != 3:
|
||||
print("Invalid edge. Please enter an edge in the format 'node1 node2 time_lag'.")
|
||||
continue
|
||||
node1, node2, time_lag = edge
|
||||
try:
|
||||
time_lag = int(time_lag)
|
||||
except ValueError:
|
||||
print("Invalid time lag. Please enter an integer value for the time_lag.")
|
||||
continue
|
||||
|
||||
# Check if the edge already exists
|
||||
if graph.has_edge(node1, node2):
|
||||
# If the edge exists, append the time_lag to the existing tuple
|
||||
current_time_lag = graph[node1][node2]["time_lag"]
|
||||
if isinstance(current_time_lag, tuple):
|
||||
graph[node1][node2]["time_lag"] = current_time_lag + (time_lag,)
|
||||
else:
|
||||
graph[node1][node2]["time_lag"] = (current_time_lag, time_lag)
|
||||
else:
|
||||
# If the edge does not exist, create a new edge with a tuple containing the time_lag
|
||||
graph.add_edge(node1, node2, time_lag=(time_lag,))
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def create_graph_from_csv(file_path: str) -> nx.DiGraph:
|
||||
"""
|
||||
Creates a directed graph from a CSV file.
|
||||
|
||||
The time_lag parameter of the networkx graph represents the exact causal lag of an edge between any 2 nodes in the graph.
|
||||
Each edge can contain multiple time lags, therefore each combination of (node1,node2,time_lag) must be input individually in the CSV file.
|
||||
|
||||
The CSV file should have at least three columns: 'node1', 'node2', and 'time_lag'.
|
||||
Each row represents an edge from 'node1' to 'node2' with a 'time_lag' attribute.
|
||||
|
||||
:param file_path: The path to the CSV file.
|
||||
:type file_path: str
|
||||
:return: A directed graph created from the CSV file.
|
||||
:rtype: nx.DiGraph
|
||||
|
||||
Example:
|
||||
Example CSV content:
|
||||
|
||||
.. code-block:: csv
|
||||
|
||||
node1,node2,time_lag
|
||||
A,B,5
|
||||
B,C,2
|
||||
A,C,7
|
||||
"""
|
||||
# Read the CSV file into a DataFrame
|
||||
df = pd.read_csv(file_path)
|
||||
|
||||
# Initialize an empty directed graph
|
||||
graph = nx.DiGraph()
|
||||
|
||||
# Add edges with time lag to the graph
|
||||
for index, row in df.iterrows():
|
||||
# Add validation for the time lag column to be a number
|
||||
try:
|
||||
time_lag = int(row["time_lag"])
|
||||
except ValueError:
|
||||
print(
|
||||
"Invalid time lag. Please enter an integer value for the time_lag for the edge between {} and {}.".format(
|
||||
row["node1"], row["node2"]
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
# Check if the edge already exists
|
||||
if graph.has_edge(row["node1"], row["node2"]):
|
||||
# If the edge exists, append the time_lag to the existing tuple
|
||||
current_time_lag = graph[row["node1"]][row["node2"]]["time_lag"]
|
||||
if isinstance(current_time_lag, tuple):
|
||||
graph[row["node1"]][row["node2"]]["time_lag"] = current_time_lag + (time_lag,)
|
||||
else:
|
||||
graph[row["node1"]][row["node2"]]["time_lag"] = (current_time_lag, time_lag)
|
||||
else:
|
||||
# If the edge does not exist, create a new edge with a tuple containing the time_lag
|
||||
graph.add_edge(row["node1"], row["node2"], time_lag=(time_lag,))
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def create_graph_from_dot_format(file_path: str) -> nx.DiGraph:
|
||||
"""
|
||||
Creates a directed graph from a DOT file and ensures it is a DiGraph.
|
||||
|
||||
The time_lag parameter of the networkx graph represents the exact causal lag of an edge between any 2 nodes in the graph.
|
||||
Each edge can contain multiple valid time lags.
|
||||
|
||||
The DOT file should contain a graph in DOT format.
|
||||
|
||||
:param file_path: The path to the DOT file.
|
||||
:type file_path: str
|
||||
:return: A directed graph (DiGraph) created from the DOT file.
|
||||
:rtype: nx.DiGraph
|
||||
"""
|
||||
# Read the DOT file into a MultiDiGraph
|
||||
multi_graph = nx.drawing.nx_agraph.read_dot(file_path)
|
||||
|
||||
# Initialize a new DiGraph
|
||||
graph = nx.DiGraph()
|
||||
|
||||
# Iterate over edges of the MultiDiGraph
|
||||
for u, v, data in multi_graph.edges(data=True):
|
||||
if "label" in data:
|
||||
try:
|
||||
# Convert the label to a tuple of time lags
|
||||
time_lag_tuple = tuple(map(int, data["label"].strip("()").split(",")))
|
||||
|
||||
if graph.has_edge(u, v):
|
||||
existing_data = graph.get_edge_data(u, v)
|
||||
if "time_lag" in existing_data:
|
||||
# Merge the existing time lags with the new ones
|
||||
existing_time_lags = existing_data["time_lag"]
|
||||
new_time_lags = existing_time_lags + time_lag_tuple
|
||||
# Remove duplicates while preserving the original order of the lags
|
||||
graph[u][v]["time_lag"] = tuple(dict.fromkeys(new_time_lags))
|
||||
else:
|
||||
graph[u][v]["time_lag"] = time_lag_tuple
|
||||
else:
|
||||
graph.add_edge(u, v, time_lag=time_lag_tuple)
|
||||
|
||||
except ValueError:
|
||||
print(f"Invalid time lag for the edge between {u} and {v}.")
|
||||
return None
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def create_graph_from_networkx_array(array: np.ndarray, var_names: list) -> nx.DiGraph:
|
||||
"""
|
||||
Create a NetworkX directed graph from a numpy array with time lag information.
|
||||
|
||||
The time_lag parameter of the networkx graph represents the exact causal lag of an edge between any 2 nodes in the graph.
|
||||
Each edge can contain multiple valid time lags.
|
||||
|
||||
The resulting graph will be a directed graph with edge attributes indicating
|
||||
the type of link based on the array values.
|
||||
|
||||
:param array: A numpy array of shape (n, n, tau) representing the causal links.
|
||||
:type array: np.ndarray
|
||||
:param var_names: A list of variable names.
|
||||
:type var_names: list
|
||||
:return: A directed graph with edge attributes based on the array values.
|
||||
:rtype: nx.DiGraph
|
||||
"""
|
||||
n = array.shape[0] # Number of variables
|
||||
assert n == array.shape[1], "The array must be square."
|
||||
tau = array.shape[2] # Number of time lags
|
||||
|
||||
# Initialize a directed graph
|
||||
graph = nx.DiGraph()
|
||||
|
||||
# Add nodes with names
|
||||
graph.add_nodes_from(var_names)
|
||||
|
||||
# Iterate over all pairs of nodes
|
||||
for i in range(n):
|
||||
for j in range(n):
|
||||
if i == j:
|
||||
continue # Skip self-loops
|
||||
|
||||
for t in range(tau):
|
||||
# Check for directed links
|
||||
if array[i, j, t] == "-->":
|
||||
if graph.has_edge(var_names[i], var_names[j]):
|
||||
# Append the time lag to the existing tuple
|
||||
current_time_lag = graph[var_names[i]][var_names[j]].get("time_lag", ())
|
||||
graph[var_names[i]][var_names[j]]["time_lag"] = current_time_lag + (t,)
|
||||
else:
|
||||
# Create a new edge with a tuple containing the time lag
|
||||
graph.add_edge(var_names[i], var_names[j], time_lag=(t,))
|
||||
|
||||
elif array[i, j, t] == "<--":
|
||||
if graph.has_edge(var_names[j], var_names[i]):
|
||||
# Append the time lag to the existing tuple
|
||||
current_time_lag = graph[var_names[j]][var_names[i]].get("time_lag", ())
|
||||
graph[var_names[j]][var_names[i]]["time_lag"] = current_time_lag + (t,)
|
||||
else:
|
||||
# Create a new edge with a tuple containing the time lag
|
||||
graph.add_edge(var_names[j], var_names[i], time_lag=(t,))
|
||||
|
||||
elif array[i, j, t] == "o-o":
|
||||
raise ValueError(
|
||||
"Unsupported link type 'o-o' found between {} and {} at lag {}.".format(
|
||||
var_names[i], var_names[j], t
|
||||
)
|
||||
)
|
||||
|
||||
elif array[i, j, t] == "x-x":
|
||||
raise ValueError(
|
||||
"Unsupported link type 'x-x' found between {} and {} at lag {}.".format(
|
||||
var_names[i], var_names[j], t
|
||||
)
|
||||
)
|
||||
|
||||
return graph
|
|
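The conversion performed by `create_graph_from_networkx_array` can be illustrated with nested lists in place of a numpy array. The sketch below follows the same rules as the function above (`-->` adds an edge i -> j, `<--` adds an edge j -> i, `o-o` and `x-x` are rejected, and self-loops are skipped) but returns a plain dict instead of a networkx graph. The name `links_to_edge_lags` is hypothetical.

```python
def links_to_edge_lags(links, var_names):
    """Map a links[i][j][t] table of link strings to {(source, target): lags}."""
    edge_lags = {}
    n = len(var_names)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue  # self-loops are skipped, as in the function above
            for t, link in enumerate(links[i][j]):
                if link == "-->":
                    edge = (var_names[i], var_names[j])
                elif link == "<--":
                    edge = (var_names[j], var_names[i])
                elif link in ("o-o", "x-x"):
                    raise ValueError(
                        f"Unsupported link type {link!r} between "
                        f"{var_names[i]} and {var_names[j]} at lag {t}."
                    )
                else:
                    continue  # no link at this lag
                edge_lags[edge] = edge_lags.get(edge, ()) + (t,)
    return edge_lags


names = ["X1", "X2", "X3"]
# Build an empty 3 x 3 x 3 table, then mark three directed links.
links = [[[""] * 3 for _ in names] for _ in names]
links[0][1][0] = "-->"
links[0][1][1] = "-->"
links[1][2][2] = "-->"

result = links_to_edge_lags(links, names)
print(result)
```

If a table lists the same lag in both orientations (for example `-->` at `[i][j][0]` and `<--` at `[j][i][0]`), this sketch records that lag twice, because it appends without deduplicating.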
@ -0,0 +1,227 @@
|
|||
import unittest
|
||||
from io import StringIO
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from dowhy.utils.timeseries import create_graph_from_csv, create_graph_from_dot_format, create_graph_from_networkx_array
|
||||
|
||||
|
||||
class TestCreateGraphFromCSV(unittest.TestCase):
|
||||
|
||||
def test_basic_functionality(self):
|
||||
csv_content = """node1,node2,time_lag
|
||||
A,B,5
|
||||
B,C,2
|
||||
A,C,7"""
|
||||
df = pd.read_csv(StringIO(csv_content))
|
||||
df.to_csv("test.csv", index=False)
|
||||
|
||||
graph = create_graph_from_csv("test.csv")
|
||||
|
||||
self.assertIsInstance(graph, nx.DiGraph)
|
||||
self.assertEqual(len(graph.edges()), 3)
|
||||
self.assertEqual(graph["A"]["B"]["time_lag"], (5,))
|
||||
self.assertEqual(graph["B"]["C"]["time_lag"], (2,))
|
||||
self.assertEqual(graph["A"]["C"]["time_lag"], (7,))
|
||||
|
||||
def test_multiple_time_lags(self):
|
||||
csv_content = """node1,node2,time_lag
|
||||
A,B,5
|
||||
A,B,3
|
||||
A,C,7"""
|
||||
df = pd.read_csv(StringIO(csv_content))
|
||||
df.to_csv("test.csv", index=False)
|
||||
|
||||
graph = create_graph_from_csv("test.csv")
|
||||
|
||||
self.assertIsInstance(graph, nx.DiGraph)
|
||||
self.assertEqual(len(graph.edges()), 2)
|
||||
self.assertEqual(graph["A"]["B"]["time_lag"], (5, 3))
|
||||
self.assertEqual(graph["A"]["C"]["time_lag"], (7,))
|
||||
|
||||
def test_invalid_time_lag(self):
|
||||
csv_content = """node1,node2,time_lag
|
||||
A,B,five
|
||||
B,C,2
|
||||
A,C,7"""
|
||||
df = pd.read_csv(StringIO(csv_content))
|
||||
df.to_csv("test.csv", index=False)
|
||||
|
||||
graph = create_graph_from_csv("test.csv")
|
||||
|
||||
self.assertIsNone(graph)
|
||||
|
||||
def test_empty_csv(self):
|
||||
csv_content = """node1,node2,time_lag"""
|
||||
df = pd.read_csv(StringIO(csv_content))
|
||||
df.to_csv("test.csv", index=False)
|
||||
|
||||
graph = create_graph_from_csv("test.csv")
|
||||
|
||||
self.assertIsInstance(graph, nx.DiGraph)
|
||||
self.assertEqual(len(graph.edges()), 0)
|
||||
|
||||
def test_self_loop(self):
|
||||
csv_content = """node1,node2,time_lag
|
||||
A,A,5"""
|
||||
df = pd.read_csv(StringIO(csv_content))
|
||||
df.to_csv("test.csv", index=False)
|
||||
|
||||
graph = create_graph_from_csv("test.csv")
|
||||
|
||||
self.assertIsInstance(graph, nx.DiGraph)
|
||||
self.assertEqual(len(graph.edges()), 1)
|
||||
self.assertEqual(graph["A"]["A"]["time_lag"], (5,))
|
||||
|
||||
|
||||
class TestCreateGraphFromDotFormat(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Helper method to create a DOT file from string content
|
||||
def create_dot_file(dot_content, file_name="test.dot"):
|
||||
with open(file_name, "w") as f:
|
||||
f.write(dot_content)
|
||||
|
||||
self.create_dot_file = create_dot_file
|
||||
|
||||
def test_basic_functionality(self):
|
||||
dot_content = """digraph G {
|
||||
A -> B [label="(5)"];
|
||||
B -> C [label="(2)"];
|
||||
A -> C [label="(7)"];
|
||||
}"""
|
||||
self.create_dot_file(dot_content)
|
||||
graph = create_graph_from_dot_format("test.dot")
|
||||
|
||||
self.assertIsInstance(graph, nx.DiGraph)
|
||||
self.assertEqual(len(graph.edges()), 3)
|
||||
self.assertEqual(graph["A"]["B"]["time_lag"], (5,))
|
||||
self.assertEqual(graph["B"]["C"]["time_lag"], (2,))
|
||||
self.assertEqual(graph["A"]["C"]["time_lag"], (7,))
|
||||
|
||||
def test_multiple_time_lags(self):
|
||||
dot_content = """digraph G {
|
||||
A -> B [label="(5,3)"];
|
||||
A -> C [label="(7)"];
|
||||
}"""
|
||||
self.create_dot_file(dot_content)
|
||||
graph = create_graph_from_dot_format("test.dot")
|
||||
|
||||
self.assertIsInstance(graph, nx.DiGraph)
|
||||
self.assertEqual(len(graph.edges()), 2)
|
||||
self.assertEqual(graph["A"]["B"]["time_lag"], (5, 3))
|
||||
self.assertEqual(graph["A"]["C"]["time_lag"], (7,))
|
||||
|
||||
def test_invalid_time_lag(self):
|
||||
dot_content = """digraph G {
|
||||
A -> B [label="(five)"];
|
||||
B -> C [label="(2)"];
|
||||
A -> C [label="(7)"];
|
||||
}"""
|
||||
self.create_dot_file(dot_content)
|
||||
graph = create_graph_from_dot_format("test.dot")
|
||||
|
||||
self.assertIsNone(graph)
|
||||
    def test_empty_dot(self):
        dot_content = """digraph G {}"""
        self.create_dot_file(dot_content)
        graph = create_graph_from_dot_format("test.dot")

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 0)

    def test_self_loop(self):
        dot_content = """digraph G {
        A -> A [label="(5)"];
        }"""
        self.create_dot_file(dot_content)
        graph = create_graph_from_dot_format("test.dot")

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 1)
        self.assertEqual(graph["A"]["A"]["time_lag"], (5,))


class TestCreateGraphFromNetworkxArray(unittest.TestCase):

    def test_basic_functionality(self):
        array = np.zeros((3, 3, 2), dtype=object)
        array[0, 1, 0] = "-->"
        array[1, 2, 0] = "-->"
        array[0, 2, 1] = "-->"

        var_names = ["X1", "X2", "X3"]

        graph = create_graph_from_networkx_array(array, var_names)

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 3)
        self.assertTrue(graph.has_edge("X1", "X2"))
        self.assertEqual(graph["X1"]["X2"]["time_lag"], (0,))
        self.assertTrue(graph.has_edge("X1", "X3"))
        self.assertEqual(graph["X1"]["X3"]["time_lag"], (1,))
        self.assertTrue(graph.has_edge("X2", "X3"))
        self.assertEqual(graph["X2"]["X3"]["time_lag"], (0,))

    def test_multiple_time_lags(self):
        array = np.zeros((3, 3, 3), dtype=object)
        array[0, 1, 0] = "-->"
        array[0, 1, 1] = "-->"
        array[1, 2, 2] = "-->"

        var_names = ["X1", "X2", "X3"]

        graph = create_graph_from_networkx_array(array, var_names)

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 2)
        self.assertTrue(graph.has_edge("X1", "X2"))
        self.assertEqual(graph["X1"]["X2"]["time_lag"], (0, 1))
        self.assertTrue(graph.has_edge("X2", "X3"))
        self.assertEqual(graph["X2"]["X3"]["time_lag"], (2,))

    def test_invalid_link_type_oo(self):
        array = np.zeros((2, 2, 1), dtype=object)
        array[0, 1, 0] = "o-o"

        var_names = ["X1", "X2"]

        with self.assertRaises(ValueError):
            create_graph_from_networkx_array(array, var_names)

    def test_invalid_link_type_xx(self):
        array = np.zeros((2, 2, 1), dtype=object)
        array[0, 1, 0] = "x-x"

        var_names = ["X1", "X2"]

        with self.assertRaises(ValueError):
            create_graph_from_networkx_array(array, var_names)

    def test_empty_array(self):
        array = np.zeros((0, 0, 0), dtype=object)
        var_names = []

        graph = create_graph_from_networkx_array(array, var_names)

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.nodes()), 0)
        self.assertEqual(len(graph.edges()), 0)

    def test_self_loop(self):
        array = np.zeros((2, 2, 1), dtype=object)
        array[0, 0, 0] = "-->"

        var_names = ["X1", "X2"]

        graph = create_graph_from_networkx_array(array, var_names)

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 0)


if __name__ == "__main__":
    unittest.main()
|
|
@ -0,0 +1,142 @@
|
|||
import unittest
from typing import List, Optional

import networkx as nx
import pandas as pd
from pandas.testing import assert_frame_equal

from dowhy.timeseries.temporal_shift import add_lagged_edges, shift_columns_by_lag_using_unrolled_graph


class TestAddLaggedEdges(unittest.TestCase):

    def test_basic_functionality(self):
        graph = nx.DiGraph()
        graph.add_edge("A", "B", time_lag=(1,))
        graph.add_edge("B", "C", time_lag=(2,))

        new_graph = add_lagged_edges(graph, "C")

        self.assertIsInstance(new_graph, nx.DiGraph)
        self.assertTrue(new_graph.has_node("C_0"))
        self.assertTrue(new_graph.has_node("B_-2"))
        self.assertTrue(new_graph.has_node("A_-3"))
        self.assertTrue(new_graph.has_edge("B_-2", "C_0"))
        self.assertTrue(new_graph.has_edge("A_-3", "B_-2"))

    def test_multiple_time_lags(self):
        graph = nx.DiGraph()
        graph.add_edge("A", "B", time_lag=(1, 2))
        graph.add_edge("B", "C", time_lag=(1, 3))

        new_graph = add_lagged_edges(graph, "C")

        self.assertIsInstance(new_graph, nx.DiGraph)
        self.assertTrue(new_graph.has_node("C_0"))
        self.assertTrue(new_graph.has_node("B_-1"))
        self.assertTrue(new_graph.has_node("B_-3"))
        self.assertTrue(new_graph.has_node("A_-2"))
        self.assertTrue(new_graph.has_node("A_-3"))
        self.assertTrue(new_graph.has_node("A_-4"))
        self.assertTrue(new_graph.has_node("A_-5"))
        self.assertTrue(new_graph.has_edge("B_-1", "C_0"))
        self.assertTrue(new_graph.has_edge("B_-3", "C_0"))
        self.assertTrue(new_graph.has_edge("A_-2", "B_-1"))
        self.assertTrue(new_graph.has_edge("A_-4", "B_-3"))
        self.assertTrue(new_graph.has_edge("B_-3", "B_-1"))
        self.assertTrue(new_graph.has_edge("A_-5", "A_-4"))
        self.assertTrue(new_graph.has_edge("A_-4", "A_-3"))
        self.assertTrue(new_graph.has_edge("A_-3", "A_-2"))

    def test_complex_graph_structure(self):
        graph = nx.DiGraph()
        graph.add_edge("A", "B", time_lag=(1,))
        graph.add_edge("B", "C", time_lag=(2,))
        graph.add_edge("A", "C", time_lag=(3,))

        new_graph = add_lagged_edges(graph, "C")

        self.assertIsInstance(new_graph, nx.DiGraph)
        self.assertTrue(new_graph.has_node("C_0"))
        self.assertTrue(new_graph.has_node("B_-2"))
        self.assertTrue(new_graph.has_node("A_-3"))
        self.assertTrue(new_graph.has_edge("B_-2", "C_0"))
        self.assertTrue(new_graph.has_edge("A_-3", "B_-2"))
        self.assertTrue(new_graph.has_edge("A_-3", "C_0"))

    def test_no_time_lag(self):
        graph = nx.DiGraph()
        graph.add_edge("A", "B")
        graph.add_edge("B", "C")

        new_graph = add_lagged_edges(graph, "C")

        self.assertIsInstance(new_graph, nx.DiGraph)
        self.assertEqual(len(new_graph.nodes()), 0)
        self.assertEqual(len(new_graph.edges()), 0)


class TestShiftColumnsByLagUsingUnrolledGraph(unittest.TestCase):

    def test_basic_functionality(self):
        df = pd.DataFrame({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1]})

        unrolled_graph = nx.DiGraph()
        unrolled_graph.add_nodes_from(["A_0", "A_-1", "B_0", "B_-2"])

        expected_df = pd.DataFrame(
            {"A_0": [1, 2, 3, 4, 5], "A_-1": [0, 1, 2, 3, 4], "B_0": [5, 4, 3, 2, 1], "B_-2": [0, 0, 5, 4, 3]}
        )

        result_df = shift_columns_by_lag_using_unrolled_graph(df, unrolled_graph)

        assert_frame_equal(result_df, expected_df)

    def test_complex_graph_structure(self):
        df = pd.DataFrame({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1], "C": [1, 3, 5, 7, 9]})

        unrolled_graph = nx.DiGraph()
        unrolled_graph.add_nodes_from(["A_0", "A_-1", "B_0", "B_-2", "C_-1", "C_-3"])

        expected_df = pd.DataFrame(
            {
                "A_0": [1, 2, 3, 4, 5],
                "A_-1": [0, 1, 2, 3, 4],
                "B_0": [5, 4, 3, 2, 1],
                "B_-2": [0, 0, 5, 4, 3],
                "C_-1": [0, 1, 3, 5, 7],
                "C_-3": [0, 0, 0, 1, 3],
            }
        )

        result_df = shift_columns_by_lag_using_unrolled_graph(df, unrolled_graph)

        assert_frame_equal(result_df, expected_df)

    def test_invalid_node_format(self):
        df = pd.DataFrame({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1]})

        unrolled_graph = nx.DiGraph()
        unrolled_graph.add_nodes_from(["A_0", "B_invalid"])

        expected_df = pd.DataFrame({"A_0": [1, 2, 3, 4, 5]})

        result_df = shift_columns_by_lag_using_unrolled_graph(df, unrolled_graph)

        assert_frame_equal(result_df, expected_df)

    def test_non_matching_columns(self):
        df = pd.DataFrame({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1]})

        unrolled_graph = nx.DiGraph()
        unrolled_graph.add_nodes_from(["C_0", "C_-1"])

        expected_df = pd.DataFrame()

        result_df = shift_columns_by_lag_using_unrolled_graph(df, unrolled_graph)

        assert_frame_equal(result_df, expected_df)


if __name__ == "__main__":
    unittest.main()