diff --git a/docs/source/example_notebooks/datasets/temporal_dataset.csv b/docs/source/example_notebooks/datasets/temporal_dataset.csv new file mode 100644 index 000000000..d29e7b27e --- /dev/null +++ b/docs/source/example_notebooks/datasets/temporal_dataset.csv @@ -0,0 +1,15 @@ +V1,V2,V3,V4,V5,V6,V7 +1,2,3,4,5,6,7 +2,3,4,5,6,7,8 +3,4,5,6,7,8,9 +4,5,6,7,8,9,10 +0,1,5,7,8,9,7 +3,5,4,1,2,6,5 +6,7,1,2,4,5,9 +12,3,5,7,3,8,9 +3,2,1,6,3,8,9 +4,6,3,5,8,9,1 +3,5,9,6,2,1,3 +5,2,6,8,11,3,4 +2,2,4,1,1,4,6 +5,6,4,3,4,6,2 \ No newline at end of file diff --git a/docs/source/example_notebooks/datasets/temporal_graph.csv b/docs/source/example_notebooks/datasets/temporal_graph.csv new file mode 100644 index 000000000..59048d5be --- /dev/null +++ b/docs/source/example_notebooks/datasets/temporal_graph.csv @@ -0,0 +1,8 @@ +node1,node2,time_lag +V1,V2,3 +V2,V3,4 +V5,V6,1 +V4,V7,4 +V4,V5,2 +V7,V6,3 +V7,V6,5 \ No newline at end of file diff --git a/docs/source/example_notebooks/datasets/temporal_graph.dot b/docs/source/example_notebooks/datasets/temporal_graph.dot new file mode 100644 index 000000000..da6b6ed67 --- /dev/null +++ b/docs/source/example_notebooks/datasets/temporal_graph.dot @@ -0,0 +1,8 @@ +digraph G { + V1 -> V2 [label="(3)"]; + V2 -> V3 [label="(4)"]; + V5 -> V6 [label="(1)"]; + V4 -> V7 [label="(4)"]; + V4 -> V5 [label="(2)"]; + V7 -> V6 [label="(3, 5)"]; +} \ No newline at end of file diff --git a/docs/source/example_notebooks/timeseries/effect_inference_timeseries_data.ipynb b/docs/source/example_notebooks/timeseries/effect_inference_timeseries_data.ipynb new file mode 100644 index 000000000..1d77bef6b --- /dev/null +++ b/docs/source/example_notebooks/timeseries/effect_inference_timeseries_data.ipynb @@ -0,0 +1,306 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Effect inference with timeseries data\n", + "\n", + "In this notebook, we will look at an example of causal effect inference from timeseries data. We will use DoWhy's functionality to add temporal dependencies to a causal graph and estimate causal effect based on the augmented graph. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import networkx as nx\n", + "import pandas as pd\n", + "from dowhy.utils.timeseries import create_graph_from_csv,create_graph_from_user\n", + "from dowhy.utils.plotting import plot, pretty_print_graph" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading timeseries data and causal graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_path=\"../datasets/temporal_dataset.csv\"\n", + "\n", + "dataframe=pd.read_csv(dataset_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In temporal causal inference, accurately estimating causal effects often requires accounting for time lags between nodes in a graph. For instance, if $node_1$ influences $node_2$ with a time lag of 5 timestamps, we represent this dependency as $node_1^{t-5}$ -> $node_2^{t}$.\n", + "\n", + "We can provide the causal graph as a networkx DAG or as a dot file. The edge attributes should mention the exact `time_lag` that is associated with each edge (if any)." 
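+    ,
+    "\n",
+    "\n",
+    "As a minimal sketch of the first option (illustrative only; `manual_graph` is not used later in this notebook), such a graph can be built directly with tuple-valued `time_lag` edge attributes, the same convention that the CSV and DOT loaders below produce:\n",
+    "\n",
+    "```python\n",
+    "manual_graph = nx.DiGraph()\n",
+    "manual_graph.add_edge(\"V1\", \"V2\", time_lag=(3,))    # single lag of 3\n",
+    "manual_graph.add_edge(\"V7\", \"V6\", time_lag=(3, 5))  # one edge carrying two lags\n",
+    "pretty_print_graph(manual_graph)\n",
+    "```"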
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from dowhy.utils.timeseries import create_graph_from_dot_format\n",
+    "\n",
+    "file_path = \"../datasets/temporal_graph.dot\"\n",
+    "\n",
+    "graph = create_graph_from_dot_format(file_path)\n",
+    "plot(graph)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can also provide the edges of the temporal graph as a CSV file with the columns node1, node2 and time_lag, where each row represents a directed edge node1 -> node2 with the given time_lag. Let us consider the following graph as the input:\n",
+    "\n",
+    "| node1 | node2 | time_lag |\n",
+    "|--------|--------|----------|\n",
+    "| V1 | V2 | 3 |\n",
+    "| V2 | V3 | 4 |\n",
+    "| V5 | V6 | 1 |\n",
+    "| V4 | V7 | 4 |\n",
+    "| V4 | V5 | 2 |\n",
+    "| V7 | V6 | 3 |\n",
+    "| V7 | V6 | 5 |"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Input a csv file with the edges in the graph with the columns: node1,node2,time_lag\n",
+    "file_path = \"../datasets/temporal_graph.csv\"\n",
+    "\n",
+    "# Create the graph from the CSV file\n",
+    "graph = create_graph_from_csv(file_path)\n",
+    "plot(graph)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Dataset Shifting and Filtering\n",
+    "\n",
+    "To prepare the dataset for temporal causal inference, we need to shift the columns by the given time lags.\n",
+    "\n",
+    "For example, consider an edge $node_1^{t-5}$ -> $node_2^{t}$ with a lag of 5. When $node_2$ is the target node, the data for $node_1$ should be shifted down by 5 timestamps so that the edge $node_1$ -> $node_2$ accurately represents the lagged dependency. Shifting the data in this manner creates additional columns and allows downstream estimators to access the correct values in the same row of the dataframe."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from dowhy.timeseries.temporal_shift import shift_columns_by_lag_using_unrolled_graph, add_lagged_edges"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# the outcome node for which effect estimation has to be done (V6)\n",
+    "target_node = 'V6'\n",
+    "unrolled_graph = add_lagged_edges(graph, target_node)\n",
+    "plot(unrolled_graph)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "time_shifted_df = shift_columns_by_lag_using_unrolled_graph(dataframe, unrolled_graph)\n",
+    "time_shifted_df.head()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Causal Effect Estimation\n",
+    "\n",
+    "Once we have the time-shifted dataframe, causal effect estimation can be performed on the target node with respect to the action (treatment) nodes."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "target_node = 'V6_0'\n",
+    "# include all other columns as potential treatments\n",
+    "treatment_columns = list(time_shifted_df.columns)\n",
+    "treatment_columns.remove(target_node)\n",
+    "treatment_columns"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# perform causal effect estimation on this new dataset\n",
+    "from dowhy import CausalModel\n",
+    "\n",
+    "model = CausalModel(\n",
+    "    data=time_shifted_df,\n",
+    "    treatment='V5_-1',\n",
+    "    outcome=target_node,\n",
+    "    graph=unrolled_graph\n",
+    ")\n",
+    "\n",
+    "identified_estimand = model.identify_effect()\n",
+    "\n",
+    "estimate = model.estimate_effect(identified_estimand,\n",
+    "                                 method_name=\"backdoor.linear_regression\",\n",
+    "                                 test_significance=True)\n",
+    "\n",
+    "print(estimate)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Importing a temporal causal graph from the Tigramite library\n",
+    "\n",
+    "Tigramite is a popular temporal causal discovery library. In this section, we show how a causal graph can be obtained by applying the PCMCI+ algorithm from Tigramite and then imported into DoWhy."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install tigramite"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import tigramite\n",
+    "import tigramite.data_processing as pp\n",
+    "import matplotlib.pyplot as plt\n",
+    "import pandas as pd\n",
+    "\n",
+    "# convert the dataframe values to float and wrap them in a tigramite DataFrame\n",
+    "dataframe = dataframe.astype(float)\n",
+    "var_names = dataframe.columns\n",
+    "dataframe = pp.DataFrame(dataframe.values, var_names=var_names)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from tigramite import plotting as tp\n",
+    "tp.plot_timeseries(dataframe, figsize=(15, 5)); plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from tigramite.pcmci import PCMCI\n",
+    "from tigramite.independence_tests.parcorr import ParCorr\n",
+    "import numpy as np\n",
+    "\n",
+    "parcorr = ParCorr(significance='analytic')\n",
+    "pcmci = PCMCI(\n",
+    "    dataframe=dataframe,\n",
+    "    cond_ind_test=parcorr,\n",
+    "    verbosity=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "correlations = pcmci.run_bivci(tau_max=3, val_only=True)['val_matrix']\n",
+    "matrix_lags = np.argmax(np.abs(correlations), axis=2)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tau_max = 3\n",
+    "pc_alpha = None\n",
+    "pcmci.verbosity = 2\n",
+    "\n",
+    "results = pcmci.run_pcmciplus(tau_min=0, tau_max=tau_max, pc_alpha=pc_alpha)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from dowhy.utils.timeseries import create_graph_from_networkx_array\n",
+    "\n",
+    "graph = create_graph_from_networkx_array(results['graph'], var_names)\n",
+    "\n",
+    "plot(graph)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
"mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/dowhy/timeseries/temporal_shift.py b/dowhy/timeseries/temporal_shift.py new file mode 100644 index 000000000..82c933cdf --- /dev/null +++ b/dowhy/timeseries/temporal_shift.py @@ -0,0 +1,102 @@ +from collections import deque +from typing import List, Optional, Tuple + +import networkx as nx +import pandas as pd + + +def add_lagged_edges(graph: nx.DiGraph, start_node: str) -> nx.DiGraph: + """ + Perform a reverse BFS starting from the node and proceed to parents level-wise, + adding edges from the ancestor to the current node with the accumulated time lag if it does not already exist. + Additionally, create lagged nodes for each time lag encountered. + + :param graph: The directed graph object. + :type graph: networkx.DiGraph + :param start_node: The node from which to start the reverse BFS. + :type start_node: string + :return: A new graph with added edges based on accumulated time lags and lagged nodes. + :rtype: networkx.DiGraph + """ + new_graph = nx.DiGraph() + queue = deque([start_node]) + lagged_node_mapping = {} # Maps original nodes to their corresponding lagged nodes + + while queue: + current_node = queue.popleft() + + for parent in graph.predecessors(current_node): + edge_data = graph.get_edge_data(parent, current_node) + if "time_lag" in edge_data: + parent_time_lag = edge_data["time_lag"] + + # Ensure parent_time_lag is in tuple form + if not isinstance(parent_time_lag, tuple): + parent_time_lag = (parent_time_lag,) + + for lag in parent_time_lag: + # Find or create the lagged node for the current node + if current_node in lagged_node_mapping: + lagged_nodes = lagged_node_mapping[current_node] + else: + lagged_nodes = set() + lagged_nodes.add(f"{current_node}_0") + new_graph.add_node(f"{current_node}_0") + lagged_node_mapping[current_node] = lagged_nodes + + # For each lagged node, create new time-lagged parent nodes and add edges + new_lagged_nodes = set() + for lagged_node in lagged_nodes: + total_lag = -int(lagged_node.split("_")[-1]) + lag + new_lagged_parent_node = f"{parent}_{-total_lag}" + new_lagged_nodes.add(new_lagged_parent_node) + + if not new_graph.has_node(new_lagged_parent_node): + new_graph.add_node(new_lagged_parent_node) + + new_graph.add_edge(new_lagged_parent_node, lagged_node) + + # Add the parent to the queue for further exploration + queue.append(parent) + + # append the lagged nodes + if parent in lagged_node_mapping: + lagged_node_mapping[parent] = lagged_node_mapping[parent].union(new_lagged_nodes) + else: + lagged_node_mapping[parent] = new_lagged_nodes + + for original_node, lagged_nodes in lagged_node_mapping.items(): + sorted_lagged_nodes = sorted(lagged_nodes, key=lambda x: int(x.split("_")[-1])) + for i in range(len(sorted_lagged_nodes) - 1): + lesser_lagged_node = sorted_lagged_nodes[i] + more_lagged_node = sorted_lagged_nodes[i + 1] + new_graph.add_edge(lesser_lagged_node, more_lagged_node) + + return new_graph + + +def shift_columns_by_lag_using_unrolled_graph(df: pd.DataFrame, unrolled_graph: nx.DiGraph) -> pd.DataFrame: + """ + Given a dataframe and an unrolled graph, this function shifts the columns in the dataframe by the corresponding time lags mentioned in the node names of the unrolled graph, + creating a new unique column for each shifted version. + + :param df: The dataframe to shift. 
+ :type df: pandas.DataFrame + :param unrolled_graph: The unrolled graph with nodes containing time lags in their names. + :type unrolled_graph: networkx.DiGraph + :return: The dataframe with the columns shifted by the corresponding time lags. + :rtype: pandas.DataFrame + """ + new_df = pd.DataFrame() + for node in unrolled_graph.nodes: + if "_" in node: + base_node, lag_str = node.rsplit("_", 1) + try: + lag = -int(lag_str) + if base_node in df.columns: + new_column_name = f"{base_node}_{-lag}" + new_df[new_column_name] = df[base_node].shift(lag, axis=0, fill_value=0) + except ValueError: + print(f"Warning: Cannot extract lag from node name {node}. Expected format 'baseNode_lag'") + + return new_df diff --git a/dowhy/utils/plotting.py b/dowhy/utils/plotting.py index 2f67001cc..44779f276 100644 --- a/dowhy/utils/plotting.py +++ b/dowhy/utils/plotting.py @@ -199,3 +199,17 @@ def _plot_as_pyplot_figure(pygraphviz_graph: Any, figure_size: Optional[Tuple[in if figure_size is not None: plt.rcParams["figure.figsize"] = org_fig_size + + +def pretty_print_graph(graph: nx.DiGraph) -> None: + """ + Pretty print the graph edges with time lags. + + :param graph: The networkx graph. + :type graph: networkx.Graph + :return: None + :rtype: None + """ + print("\nGraph edges with time lags:") + for edge in graph.edges(data=True): + print(f"{edge[0]} -> {edge[1]} with time-lagged dependency {edge[2]['time_lag']}") diff --git a/dowhy/utils/timeseries.py b/dowhy/utils/timeseries.py new file mode 100644 index 000000000..a815736d2 --- /dev/null +++ b/dowhy/utils/timeseries.py @@ -0,0 +1,233 @@ +import networkx as nx +import numpy as np +import pandas as pd + + +def create_graph_from_user() -> nx.DiGraph: + """ + Creates a directed graph based on user input from the console. + + The time_lag parameter of the networkx graph represents the exact causal lag of an edge between any 2 nodes in the graph. + Each edge can contain multiple time lags, therefore each combination of (node1,node2,time_lag) must be input individually by the user. + + The user is prompted to enter edges one by one in the format 'node1 node2 time_lag', + where 'node1' and 'node2' are the nodes connected by the edge, and 'time_lag' is a numerical + value representing the weight of the edge. The user should enter 'done' to finish inputting edges. + + :return: A directed graph created from the user's input. + :rtype: nx.DiGraph + + Example user input: + Enter an edge: A B 4 + Enter an edge: B C 2 + Enter an edge: done + """ + graph = nx.DiGraph() + + print("Enter the graph as a list of edges with time lags. Enter 'done' when you are finished.") + print("Each edge should be entered in the format 'node1 node2 time_lag'. For example: 'A B 4'") + + while True: + edge = input("Enter an edge: ") + if edge.lower() == "done": + break + edge = edge.split() + if len(edge) != 3: + print("Invalid edge. Please enter an edge in the format 'node1 node2 time_lag'.") + continue + node1, node2, time_lag = edge + try: + time_lag = int(time_lag) + except ValueError: + print("Invalid weight. 
Please enter a numerical value for the time_lag.") + continue + + # Check if the edge already exists + if graph.has_edge(node1, node2): + # If the edge exists, append the time_lag to the existing tuple + current_time_lag = graph[node1][node2]["time_lag"] + if isinstance(current_time_lag, tuple): + graph[node1][node2]["time_lag"] = current_time_lag + (time_lag,) + else: + graph[node1][node2]["time_lag"] = (current_time_lag, time_lag) + else: + # If the edge does not exist, create a new edge with a tuple containing the time_lag + graph.add_edge(node1, node2, time_lag=(time_lag,)) + + return graph + + +def create_graph_from_csv(file_path: str) -> nx.DiGraph: + """ + Creates a directed graph from a CSV file. + + The time_lag parameter of the networkx graph represents the exact causal lag of an edge between any 2 nodes in the graph. + Each edge can contain multiple time lags, therefore each combination of (node1,node2,time_lag) must be input individually in the CSV file. + + The CSV file should have at least three columns: 'node1', 'node2', and 'time_lag'. + Each row represents an edge from 'node1' to 'node2' with a 'time_lag' attribute. + + :param file_path: The path to the CSV file. + :type file_path: str + :return: A directed graph created from the CSV file. + :rtype: nx.DiGraph + + Example: + Example CSV content: + + .. code-block:: csv + + node1,node2,time_lag + A,B,5 + B,C,2 + A,C,7 + """ + # Read the CSV file into a DataFrame + df = pd.read_csv(file_path) + + # Initialize an empty directed graph + graph = nx.DiGraph() + + # Add edges with time lag to the graph + for index, row in df.iterrows(): + # Add validation for the time lag column to be a number + try: + time_lag = int(row["time_lag"]) + except ValueError: + print( + "Invalid weight. Please enter a numerical value for the time_lag for the edge between {} and {}.".format( + row["node1"], row["node2"] + ) + ) + return None + + # Check if the edge already exists + if graph.has_edge(row["node1"], row["node2"]): + # If the edge exists, append the time_lag to the existing tuple + current_time_lag = graph[row["node1"]][row["node2"]]["time_lag"] + if isinstance(current_time_lag, tuple): + graph[row["node1"]][row["node2"]]["time_lag"] = current_time_lag + (time_lag,) + else: + graph[row["node1"]][row["node2"]]["time_lag"] = (current_time_lag, time_lag) + else: + # If the edge does not exist, create a new edge with a tuple containing the time_lag + graph.add_edge(row["node1"], row["node2"], time_lag=(time_lag,)) + + return graph + + +def create_graph_from_dot_format(file_path: str) -> nx.DiGraph: + """ + Creates a directed graph from a DOT file and ensures it is a DiGraph. + + The time_lag parameter of the networkx graph represents the exact causal lag of an edge between any 2 nodes in the graph. + Each edge can contain multiple valid time lags. + + The DOT file should contain a graph in DOT format. + + :param file_path: The path to the DOT file. + :type file_path: str + :return: A directed graph (DiGraph) created from the DOT file. 
+ :rtype: nx.DiGraph + """ + # Read the DOT file into a MultiDiGraph + multi_graph = nx.drawing.nx_agraph.read_dot(file_path) + + # Initialize a new DiGraph + graph = nx.DiGraph() + + # Iterate over edges of the MultiDiGraph + for u, v, data in multi_graph.edges(data=True): + if "label" in data: + try: + # Convert the label to a tuple of time lags + time_lag_tuple = tuple(map(int, data["label"].strip("()").split(","))) + + if graph.has_edge(u, v): + existing_data = graph.get_edge_data(u, v) + if "time_lag" in existing_data: + # Merge the existing time lags with the new ones + existing_time_lags = existing_data["time_lag"] + new_time_lags = existing_time_lags + time_lag_tuple + # Remove duplicates by converting to a set and back to a tuple + graph[u][v]["time_lag"] = tuple(set(new_time_lags)) + else: + graph[u][v]["time_lag"] = time_lag_tuple + else: + graph.add_edge(u, v, time_lag=time_lag_tuple) + + except ValueError: + print(f"Invalid weight for the edge between {u} and {v}.") + return None + + return graph + + +def create_graph_from_networkx_array(array: np.ndarray, var_names: list) -> nx.DiGraph: + """ + Create a NetworkX directed graph from a numpy array with time lag information. + + The time_lag parameter of the networkx graph represents the exact causal lag of an edge between any 2 nodes in the graph. + Each edge can contain multiple valid time lags. + + The resulting graph will be a directed graph with edge attributes indicating + the type of link based on the array values. + + :param array: A numpy array of shape (n, n, tau) representing the causal links. + :type array: np.ndarray + :param var_names: A list of variable names. + :type var_names: list + :return: A directed graph with edge attributes based on the array values. + :rtype: nx.DiGraph + """ + n = array.shape[0] # Number of variables + assert n == array.shape[1], "The array must be square." 
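+    # Note: the array is assumed to follow the tigramite graph convention, where
+    # array[i, j, t] holds a link symbol such as "-->", "<--", "o-o" or "x-x"
+    # describing the relation between var_names[i] and var_names[j] at lag t.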
+ tau = array.shape[2] # Number of time lags + + # Initialize a directed graph + graph = nx.DiGraph() + + # Add nodes with names + graph.add_nodes_from(var_names) + + # Iterate over all pairs of nodes + for i in range(n): + for j in range(n): + if i == j: + continue # Skip self-loops + + for t in range(tau): + # Check for directed links + if array[i, j, t] == "-->": + if graph.has_edge(var_names[i], var_names[j]): + # Append the time lag to the existing tuple + current_time_lag = graph[var_names[i]][var_names[j]].get("time_lag", ()) + graph[var_names[i]][var_names[j]]["time_lag"] = current_time_lag + (t,) + else: + # Create a new edge with a tuple containing the time lag + graph.add_edge(var_names[i], var_names[j], time_lag=(t,)) + + elif array[i, j, t] == "<--": + if graph.has_edge(var_names[j], var_names[i]): + # Append the time lag to the existing tuple + current_time_lag = graph[var_names[j]][var_names[i]].get("time_lag", ()) + graph[var_names[j]][var_names[i]]["time_lag"] = current_time_lag + (t,) + else: + # Create a new edge with a tuple containing the time lag + graph.add_edge(var_names[j], var_names[i], time_lag=(t,)) + + elif array[i, j, t] == "o-o": + raise ValueError( + "Unsupported link type 'o-o' found between {} and {} at lag {}.".format( + var_names[i], var_names[j], t + ) + ) + + elif array[i, j, t] == "x-x": + raise ValueError( + "Unsupported link type 'x-x' found between {} and {} at lag {}.".format( + var_names[i], var_names[j], t + ) + ) + + return graph diff --git a/tests/timeseries/test_temporal_causal_graph_creation.py b/tests/timeseries/test_temporal_causal_graph_creation.py new file mode 100644 index 000000000..40a82b56b --- /dev/null +++ b/tests/timeseries/test_temporal_causal_graph_creation.py @@ -0,0 +1,227 @@ +import unittest +from io import StringIO + +import networkx as nx +import numpy as np +import pandas as pd + +from dowhy.utils.timeseries import create_graph_from_csv, create_graph_from_dot_format, create_graph_from_networkx_array + + +class TestCreateGraphFromCSV(unittest.TestCase): + + def test_basic_functionality(self): + csv_content = """node1,node2,time_lag +A,B,5 +B,C,2 +A,C,7""" + df = pd.read_csv(StringIO(csv_content)) + df.to_csv("test.csv", index=False) + + graph = create_graph_from_csv("test.csv") + + self.assertIsInstance(graph, nx.DiGraph) + self.assertEqual(len(graph.edges()), 3) + self.assertEqual(graph["A"]["B"]["time_lag"], (5,)) + self.assertEqual(graph["B"]["C"]["time_lag"], (2,)) + self.assertEqual(graph["A"]["C"]["time_lag"], (7,)) + + def test_multiple_time_lags(self): + csv_content = """node1,node2,time_lag +A,B,5 +A,B,3 +A,C,7""" + df = pd.read_csv(StringIO(csv_content)) + df.to_csv("test.csv", index=False) + + graph = create_graph_from_csv("test.csv") + + self.assertIsInstance(graph, nx.DiGraph) + self.assertEqual(len(graph.edges()), 2) + self.assertEqual(graph["A"]["B"]["time_lag"], (5, 3)) + self.assertEqual(graph["A"]["C"]["time_lag"], (7,)) + + def test_invalid_time_lag(self): + csv_content = """node1,node2,time_lag +A,B,five +B,C,2 +A,C,7""" + df = pd.read_csv(StringIO(csv_content)) + df.to_csv("test.csv", index=False) + + graph = create_graph_from_csv("test.csv") + + self.assertIsNone(graph) + + def test_empty_csv(self): + csv_content = """node1,node2,time_lag""" + df = pd.read_csv(StringIO(csv_content)) + df.to_csv("test.csv", index=False) + + graph = create_graph_from_csv("test.csv") + + self.assertIsInstance(graph, nx.DiGraph) + self.assertEqual(len(graph.edges()), 0) + + def test_self_loop(self): + csv_content = 
"""node1,node2,time_lag +A,A,5""" + df = pd.read_csv(StringIO(csv_content)) + df.to_csv("test.csv", index=False) + + graph = create_graph_from_csv("test.csv") + + self.assertIsInstance(graph, nx.DiGraph) + self.assertEqual(len(graph.edges()), 1) + self.assertEqual(graph["A"]["A"]["time_lag"], (5,)) + + +class TestCreateGraphFromDotFormat(unittest.TestCase): + + def setUp(self): + # Helper method to create a DOT file from string content + def create_dot_file(dot_content, file_name="test.dot"): + with open(file_name, "w") as f: + f.write(dot_content) + + self.create_dot_file = create_dot_file + + def test_basic_functionality(self): + dot_content = """digraph G { +A -> B [label="(5)"]; +B -> C [label="(2)"]; +A -> C [label="(7)"]; +}""" + self.create_dot_file(dot_content) + graph = create_graph_from_dot_format("test.dot") + + self.assertIsInstance(graph, nx.DiGraph) + self.assertEqual(len(graph.edges()), 3) + self.assertEqual(graph["A"]["B"]["time_lag"], (5,)) + self.assertEqual(graph["B"]["C"]["time_lag"], (2,)) + self.assertEqual(graph["A"]["C"]["time_lag"], (7,)) + + def test_multiple_time_lags(self): + dot_content = """digraph G { +A -> B [label="(5,3)"]; +A -> C [label="(7)"]; +}""" + self.create_dot_file(dot_content) + graph = create_graph_from_dot_format("test.dot") + + self.assertIsInstance(graph, nx.DiGraph) + self.assertEqual(len(graph.edges()), 2) + self.assertEqual(graph["A"]["B"]["time_lag"], (5, 3)) + self.assertEqual(graph["A"]["C"]["time_lag"], (7,)) + + def test_invalid_time_lag(self): + dot_content = """digraph G { +A -> B [label="(five)"]; +B -> C [label="(2)"]; +A -> C [label="(7)"]; +}""" + self.create_dot_file(dot_content) + graph = create_graph_from_dot_format("test.dot") + + self.assertIsNone(graph) + + def test_empty_dot(self): + dot_content = """digraph G {}""" + self.create_dot_file(dot_content) + graph = create_graph_from_dot_format("test.dot") + + self.assertIsInstance(graph, nx.DiGraph) + self.assertEqual(len(graph.edges()), 0) + + def test_self_loop(self): + dot_content = """digraph G { +A -> A [label="(5)"]; +}""" + self.create_dot_file(dot_content) + graph = create_graph_from_dot_format("test.dot") + + self.assertIsInstance(graph, nx.DiGraph) + self.assertEqual(len(graph.edges()), 1) + self.assertEqual(graph["A"]["A"]["time_lag"], (5,)) + + +class TestCreateGraphFromNetworkxArray(unittest.TestCase): + + def test_basic_functionality(self): + array = np.zeros((3, 3, 2), dtype=object) + array[0, 1, 0] = "-->" + array[1, 2, 0] = "-->" + array[0, 2, 1] = "-->" + + var_names = ["X1", "X2", "X3"] + + graph = create_graph_from_networkx_array(array, var_names) + + self.assertIsInstance(graph, nx.DiGraph) + self.assertEqual(len(graph.edges()), 3) + self.assertTrue(graph.has_edge("X1", "X2")) + self.assertEqual(graph["X1"]["X2"]["time_lag"], (0,)) + self.assertTrue(graph.has_edge("X1", "X3")) + self.assertEqual(graph["X1"]["X3"]["time_lag"], (1,)) + self.assertTrue(graph.has_edge("X2", "X3")) + self.assertEqual(graph["X2"]["X3"]["time_lag"], (0,)) + + def test_multiple_time_lags(self): + array = np.zeros((3, 3, 3), dtype=object) + array[0, 1, 0] = "-->" + array[0, 1, 1] = "-->" + array[1, 2, 2] = "-->" + + var_names = ["X1", "X2", "X3"] + + graph = create_graph_from_networkx_array(array, var_names) + + self.assertIsInstance(graph, nx.DiGraph) + self.assertEqual(len(graph.edges()), 2) + self.assertTrue(graph.has_edge("X1", "X2")) + self.assertEqual(graph["X1"]["X2"]["time_lag"], (0, 1)) + self.assertTrue(graph.has_edge("X2", "X3")) + 
self.assertEqual(graph["X2"]["X3"]["time_lag"], (2,)) + + def test_invalid_link_type_oo(self): + array = np.zeros((2, 2, 1), dtype=object) + array[0, 1, 0] = "o-o" + + var_names = ["X1", "X2"] + + with self.assertRaises(ValueError): + create_graph_from_networkx_array(array, var_names) + + def test_invalid_link_type_xx(self): + array = np.zeros((2, 2, 1), dtype=object) + array[0, 1, 0] = "x-x" + + var_names = ["X1", "X2"] + + with self.assertRaises(ValueError): + create_graph_from_networkx_array(array, var_names) + + def test_empty_array(self): + array = np.zeros((0, 0, 0), dtype=object) + var_names = [] + + graph = create_graph_from_networkx_array(array, var_names) + + self.assertIsInstance(graph, nx.DiGraph) + self.assertEqual(len(graph.nodes()), 0) + self.assertEqual(len(graph.edges()), 0) + + def test_self_loop(self): + array = np.zeros((2, 2, 1), dtype=object) + array[0, 0, 0] = "-->" + + var_names = ["X1", "X2"] + + graph = create_graph_from_networkx_array(array, var_names) + + self.assertIsInstance(graph, nx.DiGraph) + self.assertEqual(len(graph.edges()), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/timeseries/test_temporal_shift.py b/tests/timeseries/test_temporal_shift.py new file mode 100644 index 000000000..b663609b4 --- /dev/null +++ b/tests/timeseries/test_temporal_shift.py @@ -0,0 +1,142 @@ +import unittest +from typing import List, Optional + +import networkx as nx +import pandas as pd +from pandas.testing import assert_frame_equal + +from dowhy.timeseries.temporal_shift import add_lagged_edges, shift_columns_by_lag_using_unrolled_graph + + +class TestAddLaggedEdges(unittest.TestCase): + + def test_basic_functionality(self): + graph = nx.DiGraph() + graph.add_edge("A", "B", time_lag=(1,)) + graph.add_edge("B", "C", time_lag=(2,)) + + new_graph = add_lagged_edges(graph, "C") + + self.assertIsInstance(new_graph, nx.DiGraph) + self.assertTrue(new_graph.has_node("C_0")) + self.assertTrue(new_graph.has_node("B_-2")) + self.assertTrue(new_graph.has_node("A_-3")) + self.assertTrue(new_graph.has_edge("B_-2", "C_0")) + self.assertTrue(new_graph.has_edge("A_-3", "B_-2")) + + def test_multiple_time_lags(self): + graph = nx.DiGraph() + graph.add_edge("A", "B", time_lag=(1, 2)) + graph.add_edge("B", "C", time_lag=(1, 3)) + + new_graph = add_lagged_edges(graph, "C") + + self.assertIsInstance(new_graph, nx.DiGraph) + self.assertTrue(new_graph.has_node("C_0")) + self.assertTrue(new_graph.has_node("B_-1")) + self.assertTrue(new_graph.has_node("B_-3")) + self.assertTrue(new_graph.has_node("A_-2")) + self.assertTrue(new_graph.has_node("A_-3")) + self.assertTrue(new_graph.has_node("A_-4")) + self.assertTrue(new_graph.has_node("A_-5")) + self.assertTrue(new_graph.has_edge("B_-1", "C_0")) + self.assertTrue(new_graph.has_edge("B_-3", "C_0")) + self.assertTrue(new_graph.has_edge("A_-2", "B_-1")) + self.assertTrue(new_graph.has_edge("A_-4", "B_-3")) + self.assertTrue(new_graph.has_edge("B_-3", "B_-1")) + self.assertTrue(new_graph.has_edge("A_-5", "A_-4")) + self.assertTrue(new_graph.has_edge("A_-4", "A_-3")) + self.assertTrue(new_graph.has_edge("A_-3", "A_-2")) + + def test_complex_graph_structure(self): + graph = nx.DiGraph() + graph.add_edge("A", "B", time_lag=(1,)) + graph.add_edge("B", "C", time_lag=(2,)) + graph.add_edge("A", "C", time_lag=(3,)) + + new_graph = add_lagged_edges(graph, "C") + + self.assertIsInstance(new_graph, nx.DiGraph) + self.assertTrue(new_graph.has_node("C_0")) + self.assertTrue(new_graph.has_node("B_-2")) + 
self.assertTrue(new_graph.has_node("A_-3")) + self.assertTrue(new_graph.has_edge("B_-2", "C_0")) + self.assertTrue(new_graph.has_edge("A_-3", "B_-2")) + self.assertTrue(new_graph.has_edge("A_-3", "C_0")) + + def test_no_time_lag(self): + graph = nx.DiGraph() + graph.add_edge("A", "B") + graph.add_edge("B", "C") + + new_graph = add_lagged_edges(graph, "C") + + self.assertIsInstance(new_graph, nx.DiGraph) + self.assertEqual(len(new_graph.nodes()), 0) + self.assertEqual(len(new_graph.edges()), 0) + + +class TestShiftColumnsByLagUsingUnrolledGraph(unittest.TestCase): + + def test_basic_functionality(self): + df = pd.DataFrame({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1]}) + + unrolled_graph = nx.DiGraph() + unrolled_graph.add_nodes_from(["A_0", "A_-1", "B_0", "B_-2"]) + + expected_df = pd.DataFrame( + {"A_0": [1, 2, 3, 4, 5], "A_-1": [0, 1, 2, 3, 4], "B_0": [5, 4, 3, 2, 1], "B_-2": [0, 0, 5, 4, 3]} + ) + + result_df = shift_columns_by_lag_using_unrolled_graph(df, unrolled_graph) + + assert_frame_equal(result_df, expected_df) + + def test_complex_graph_structure(self): + df = pd.DataFrame({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1], "C": [1, 3, 5, 7, 9]}) + + unrolled_graph = nx.DiGraph() + unrolled_graph.add_nodes_from(["A_0", "A_-1", "B_0", "B_-2", "C_-1", "C_-3"]) + + expected_df = pd.DataFrame( + { + "A_0": [1, 2, 3, 4, 5], + "A_-1": [0, 1, 2, 3, 4], + "B_0": [5, 4, 3, 2, 1], + "B_-2": [0, 0, 5, 4, 3], + "C_-1": [0, 1, 3, 5, 7], + "C_-3": [0, 0, 0, 1, 3], + } + ) + + result_df = shift_columns_by_lag_using_unrolled_graph(df, unrolled_graph) + + assert_frame_equal(result_df, expected_df) + + def test_invalid_node_format(self): + df = pd.DataFrame({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1]}) + + unrolled_graph = nx.DiGraph() + unrolled_graph.add_nodes_from(["A_0", "B_invalid"]) + + expected_df = pd.DataFrame({"A_0": [1, 2, 3, 4, 5]}) + + result_df = shift_columns_by_lag_using_unrolled_graph(df, unrolled_graph) + + assert_frame_equal(result_df, expected_df) + + def test_non_matching_columns(self): + df = pd.DataFrame({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1]}) + + unrolled_graph = nx.DiGraph() + unrolled_graph.add_nodes_from(["C_0", "C_-1"]) + + expected_df = pd.DataFrame() + + result_df = shift_columns_by_lag_using_unrolled_graph(df, unrolled_graph) + + assert_frame_equal(result_df, expected_df) + + +if __name__ == "__main__": + unittest.main()