Effect estimation over timeseries data (#1218)

* Library functions for temporal causal functionality

* shifting plotter function

* printing graph: best practices

* added docstrings

* moved datasets

* updated tutorial notebook

* sphinx documentation

* updated shifting columns with 0,1,..,max_lag

* support for dot format

* tigramite support

* updated filter to be a hidden function

* black and isort utils

* black and isort timeseries

* updated notebook text

Signed-off-by: Amit Sharma <amit_sharma@live.com>

* integer range fix

* correction in timestamp : notebook text

* time lagged causal estimation

* removed cell outputs

* find ancestors

* include ancestors in notebook

* formatting changes

* comments : notebook

* multiple time lags : csv graph

* multiple time lags

* unrolled graph using bfs

* cleanup of functions

* removed find parents and ancestors

* tests for causal graph creation

* tests for adding lagged edges

* tests for shifting columns


* tigramite dependency added

---------

Signed-off-by: Amit Sharma <amit_sharma@live.com>
Co-authored-by: Amit Sharma <amit_sharma@live.com>
Ashutosh Srivastava, 2024-08-03 22:54:02 +05:30, committed by GitHub
Parent e783e37db0
Commit becbf7f502
No known key found for this signature
GPG key ID: B5690EEEBB952194
9 changed files: 1055 additions and 0 deletions

@ -0,0 +1,15 @@
V1,V2,V3,V4,V5,V6,V7
1,2,3,4,5,6,7
2,3,4,5,6,7,8
3,4,5,6,7,8,9
4,5,6,7,8,9,10
0,1,5,7,8,9,7
3,5,4,1,2,6,5
6,7,1,2,4,5,9
12,3,5,7,3,8,9
3,2,1,6,3,8,9
4,6,3,5,8,9,1
3,5,9,6,2,1,3
5,2,6,8,11,3,4
2,2,4,1,1,4,6
5,6,4,3,4,6,2

@ -0,0 +1,8 @@
node1,node2,time_lag
V1,V2,3
V2,V3,4
V5,V6,1
V4,V7,4
V4,V5,2
V7,V6,3
V7,V6,5
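
As a rough illustration of how these rows are read: every `(node1, node2)` pair collects all of its `time_lag` values into one tuple, so the two `V7,V6` rows end up as a single edge with lags `(3, 5)`. The sketch below uses only the standard library; `collect_lags` and `EDGE_CSV` are hypothetical names made up for this note and are not part of DoWhy.

```python
import csv
from io import StringIO

# Same content as the edge file above, inlined so the sketch is self-contained.
EDGE_CSV = """node1,node2,time_lag
V1,V2,3
V2,V3,4
V5,V6,1
V4,V7,4
V4,V5,2
V7,V6,3
V7,V6,5
"""


def collect_lags(csv_text: str) -> dict:
    """Group time lags by directed edge (hypothetical helper, not a DoWhy API)."""
    lags = {}
    for row in csv.DictReader(StringIO(csv_text)):
        edge = (row["node1"], row["node2"])
        # A repeated edge accumulates its lags in file order.
        lags[edge] = lags.get(edge, ()) + (int(row["time_lag"]),)
    return lags


edge_lags = collect_lags(EDGE_CSV)
print(edge_lags[("V7", "V6")])
```

Seven rows therefore describe six distinct edges, which is the same structure the DOT file in this commit expresses with the label `(3, 5)`.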

@ -0,0 +1,8 @@
digraph G {
V1 -> V2 [label="(3)"];
V2 -> V3 [label="(4)"];
V5 -> V6 [label="(1)"];
V4 -> V7 [label="(4)"];
V4 -> V5 [label="(2)"];
V7 -> V6 [label="(3, 5)"];
}
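
The edge labels in this file carry the time lags as a parenthesised, comma-separated list. A minimal sketch of that parsing step, mirroring the `strip("()").split(",")` expression used in `create_graph_from_dot_format`, is shown below. `parse_lag_label` is a hypothetical name used only for this illustration, and the sketch skips reading the DOT file itself.

```python
def parse_lag_label(label: str) -> tuple:
    """Turn a DOT edge label such as "(3, 5)" into a tuple of integer lags.

    Hypothetical helper for illustration; int() tolerates the space after
    the comma, and a non-numeric entry raises ValueError.
    """
    return tuple(int(part) for part in label.strip("()").split(","))


single = parse_lag_label("(1)")
multiple = parse_lag_label("(3, 5)")
print(single, multiple)
```

A label such as `(five)` raises `ValueError`, which is the case the loader reports as an invalid weight.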

@ -0,0 +1,306 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Effect inference with timeseries data\n",
"\n",
"In this notebook, we will look at an example of causal effect inference from timeseries data. We will use DoWhy's functionality to add temporal dependencies to a causal graph and estimate causal effect based on the augmented graph. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import networkx as nx\n",
"import pandas as pd\n",
"from dowhy.utils.timeseries import create_graph_from_csv,create_graph_from_user\n",
"from dowhy.utils.plotting import plot, pretty_print_graph"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading timeseries data and causal graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset_path=\"../datasets/temporal_dataset.csv\"\n",
"\n",
"dataframe=pd.read_csv(dataset_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In temporal causal inference, accurately estimating causal effects often requires accounting for time lags between nodes in a graph. For instance, if $node_1$ influences $node_2$ with a time lag of 5 timestamps, we represent this dependency as $node_1^{t-5}$ -> $node_2^{t}$.\n",
"\n",
"We can provide the causal graph as a networkx DAG or as a dot file. The edge attributes should mention the exact `time_lag` that is associated with each edge (if any)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dowhy.utils.timeseries import create_graph_from_dot_format\n",
"\n",
"file_path = \"../datasets/temporal_graph.dot\"\n",
"\n",
"graph = create_graph_from_dot_format(file_path)\n",
"plot(graph)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also create a csv file with the edges in the temporal graph. The columns in the csv are `node1`, `node2` and `time_lag`; each row represents a directed edge `node1` -> `node2` with a lag of `time_lag` timestamps. An edge with several lags is listed once per lag. Let us consider the following graph as the input:\n",
"\n",
"| node1 | node2 | time_lag |\n",
"|--------|--------|----------|\n",
"| V1 | V2 | 3 |\n",
"| V2 | V3 | 4 |\n",
"| V5 | V6 | 1 |\n",
"| V4 | V7 | 4 |\n",
"| V4 | V5 | 2 |\n",
"| V7 | V6 | 3 |\n",
"| V7 | V6 | 5 |"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Input a csv file with the edges in the graph with the columns: node1,node2,time_lag\n",
"file_path = \"../datasets/temporal_graph.csv\"\n",
"\n",
"# Create the graph from the CSV file\n",
"graph = create_graph_from_csv(file_path)\n",
"plot(graph)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset Shifting and Filtering\n",
"\n",
"To prepare the dataset for temporal causal inference, we need to shift the columns by the given time lag.\n",
"\n",
"For example, suppose the causal graph contains the edge $node_1^{t-5}$ -> $node_2^{t}$, i.e. a lag of 5. When considering $node_2$ as the target node, the data for $node_1$ should be shifted down by 5 timestamps, so that each row pairs $node_2$ at time $t$ with $node_1$ at time $t-5$. Shifting the data in this manner creates one additional column per lagged node (rows without an earlier observation are filled with 0) and allows downstream estimators to access the correct values in the same row of a dataframe. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dowhy.timeseries.temporal_shift import shift_columns_by_lag_using_unrolled_graph, add_lagged_edges"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the outcome node for which effect estimation has to be done: V6\n",
"target_node = 'V6'\n",
"unrolled_graph = add_lagged_edges(graph, target_node)\n",
"plot(unrolled_graph)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"time_shifted_df = shift_columns_by_lag_using_unrolled_graph(dataframe, unrolled_graph)\n",
"time_shifted_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Causal Effect Estimation\n",
"\n",
"Once you have the new dataframe, causal effect estimation can be performed on the target node with respect to the action nodes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"target_node = 'V6_0'\n",
"# include all the treatments\n",
"treatment_columns = list(time_shifted_df.columns)\n",
"treatment_columns.remove(target_node)\n",
"treatment_columns"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# perform causal effect estimation on this new dataset\n",
"import dowhy\n",
"from dowhy import CausalModel\n",
"\n",
"model = CausalModel(\n",
" data=time_shifted_df,\n",
" treatment='V5_-1',\n",
" outcome=target_node,\n",
" graph = unrolled_graph\n",
")\n",
"\n",
"identified_estimand = model.identify_effect()\n",
"\n",
"estimate = model.estimate_effect(identified_estimand,\n",
" method_name=\"backdoor.linear_regression\",\n",
" test_significance=True)\n",
"\n",
"\n",
"print(estimate)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Importing a temporal causal graph from the Tigramite library\n",
"\n",
"Tigramite is a popular temporal causal discovery library. In this section, we highlight how the causal graph can be obtained by applying the PCMCI+ algorithm from Tigramite and then imported into DoWhy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install tigramite"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tigramite\n",
"import tigramite.data_processing as pp\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"# convert the dataframe values to float\n",
"dataframe = dataframe.astype(float)\n",
"var_names = dataframe.columns\n",
"# wrap the values in a tigramite DataFrame\n",
"dataframe = pp.DataFrame(dataframe.values, var_names=var_names)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tigramite import plotting as tp\n",
"tp.plot_timeseries(dataframe, figsize=(15, 5)); plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tigramite.pcmci import PCMCI\n",
"from tigramite.independence_tests.parcorr import ParCorr\n",
"import numpy as np\n",
"parcorr = ParCorr(significance='analytic')\n",
"pcmci = PCMCI(\n",
" dataframe=dataframe, \n",
" cond_ind_test=parcorr,\n",
" verbosity=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"correlations = pcmci.run_bivci(tau_max=3, val_only=True)['val_matrix']\n",
"matrix_lags = np.argmax(np.abs(correlations), axis=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tau_max = 3\n",
"pc_alpha = None\n",
"pcmci.verbosity = 2\n",
"\n",
"results = pcmci.run_pcmciplus(tau_min=0, tau_max=tau_max, pc_alpha=pc_alpha)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dowhy.utils.timeseries import create_graph_from_networkx_array\n",
"\n",
"graph = create_graph_from_networkx_array(results['graph'], var_names)\n",
"\n",
"plot(graph)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -0,0 +1,102 @@
from collections import deque
from typing import List, Optional, Tuple

import networkx as nx
import pandas as pd


def add_lagged_edges(graph: nx.DiGraph, start_node: str) -> nx.DiGraph:
    """
    Perform a reverse BFS starting from the node and proceed to parents level-wise,
    adding edges from the ancestor to the current node with the accumulated time lag if it does not already exist.
    Additionally, create lagged nodes for each time lag encountered.

    :param graph: The directed graph object.
    :type graph: networkx.DiGraph
    :param start_node: The node from which to start the reverse BFS.
    :type start_node: string
    :return: A new graph with added edges based on accumulated time lags and lagged nodes.
    :rtype: networkx.DiGraph
    """
    new_graph = nx.DiGraph()
    queue = deque([start_node])
    lagged_node_mapping = {}  # Maps original nodes to their corresponding lagged nodes

    while queue:
        current_node = queue.popleft()

        for parent in graph.predecessors(current_node):
            edge_data = graph.get_edge_data(parent, current_node)
            if "time_lag" in edge_data:
                parent_time_lag = edge_data["time_lag"]

                # Ensure parent_time_lag is in tuple form
                if not isinstance(parent_time_lag, tuple):
                    parent_time_lag = (parent_time_lag,)

                for lag in parent_time_lag:
                    # Find or create the lagged node for the current node
                    if current_node in lagged_node_mapping:
                        lagged_nodes = lagged_node_mapping[current_node]
                    else:
                        lagged_nodes = set()
                        lagged_nodes.add(f"{current_node}_0")
                        new_graph.add_node(f"{current_node}_0")
                        lagged_node_mapping[current_node] = lagged_nodes

                    # For each lagged node, create new time-lagged parent nodes and add edges
                    new_lagged_nodes = set()
                    for lagged_node in lagged_nodes:
                        total_lag = -int(lagged_node.split("_")[-1]) + lag
                        new_lagged_parent_node = f"{parent}_{-total_lag}"
                        new_lagged_nodes.add(new_lagged_parent_node)

                        if not new_graph.has_node(new_lagged_parent_node):
                            new_graph.add_node(new_lagged_parent_node)

                        new_graph.add_edge(new_lagged_parent_node, lagged_node)

                        # Add the parent to the queue for further exploration
                        queue.append(parent)

                    # append the lagged nodes
                    if parent in lagged_node_mapping:
                        lagged_node_mapping[parent] = lagged_node_mapping[parent].union(new_lagged_nodes)
                    else:
                        lagged_node_mapping[parent] = new_lagged_nodes

    for original_node, lagged_nodes in lagged_node_mapping.items():
        sorted_lagged_nodes = sorted(lagged_nodes, key=lambda x: int(x.split("_")[-1]))
        for i in range(len(sorted_lagged_nodes) - 1):
            lesser_lagged_node = sorted_lagged_nodes[i]
            more_lagged_node = sorted_lagged_nodes[i + 1]
            new_graph.add_edge(lesser_lagged_node, more_lagged_node)

    return new_graph


def shift_columns_by_lag_using_unrolled_graph(df: pd.DataFrame, unrolled_graph: nx.DiGraph) -> pd.DataFrame:
    """
    Given a dataframe and an unrolled graph, this function shifts the columns in the dataframe by the corresponding time lags mentioned in the node names of the unrolled graph,
    creating a new unique column for each shifted version.

    :param df: The dataframe to shift.
    :type df: pandas.DataFrame
    :param unrolled_graph: The unrolled graph with nodes containing time lags in their names.
    :type unrolled_graph: networkx.DiGraph
    :return: The dataframe with the columns shifted by the corresponding time lags.
    :rtype: pandas.DataFrame
    """
    new_df = pd.DataFrame()
    for node in unrolled_graph.nodes:
        if "_" in node:
            base_node, lag_str = node.rsplit("_", 1)
            try:
                lag = -int(lag_str)
                if base_node in df.columns:
                    new_column_name = f"{base_node}_{-lag}"
                    new_df[new_column_name] = df[base_node].shift(lag, axis=0, fill_value=0)
            except ValueError:
                print(f"Warning: Cannot extract lag from node name {node}. Expected format 'baseNode_lag'")

    return new_df
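
To make the naming scheme of the unrolled graph concrete, here is a small standard-library sketch of the lag accumulation only: walking backwards from the target, each parent is named `<node>_<-total lag>`, where the total lag is the sum of the lags along the path. It deliberately does not build a networkx graph or add edges; `unrolled_node_names` is a hypothetical helper written for this note, and it assumes the input graph is acyclic.

```python
from collections import deque


def unrolled_node_names(edge_lags: dict, target: str) -> set:
    """Names of the lagged nodes reached by walking backwards from `target`.

    Hypothetical helper for illustration; assumes an acyclic graph.
    """
    # Index incoming edges by child so parents can be looked up directly.
    parents = {}
    for (src, dst), lags in edge_lags.items():
        parents.setdefault(dst, []).append((src, lags))

    names = {f"{target}_0"}
    queue = deque([(target, 0)])
    while queue:
        node, offset = queue.popleft()
        for src, lags in parents.get(node, []):
            for lag in lags:
                total = offset + lag  # lag accumulated along the path
                name = f"{src}_{-total}"
                if name not in names:
                    names.add(name)
                    queue.append((src, total))
    return names


# Edges from temporal_graph.csv that lead to V6, keyed as (node1, node2).
edge_lags = {
    ("V5", "V6"): (1,),
    ("V7", "V6"): (3, 5),
    ("V4", "V5"): (2,),
    ("V4", "V7"): (4,),
}
names = unrolled_node_names(edge_lags, "V6")
print(sorted(names))
```

For the target `V6`, `V4` appears three times: once through `V5` (2 + 1 = 3) and twice through `V7` (4 + 3 = 7 and 4 + 5 = 9). This is why the notebook can use `V5_-1` as a treatment column.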

@ -199,3 +199,17 @@ def _plot_as_pyplot_figure(pygraphviz_graph: Any, figure_size: Optional[Tuple[in
    if figure_size is not None:
        plt.rcParams["figure.figsize"] = org_fig_size


def pretty_print_graph(graph: nx.DiGraph) -> None:
    """
    Pretty print the graph edges with time lags.

    :param graph: The networkx directed graph whose edges carry a `time_lag` attribute.
    :type graph: networkx.DiGraph
    :return: None
    :rtype: None
    """
    print("\nGraph edges with time lags:")
    for edge in graph.edges(data=True):
        print(f"{edge[0]} -> {edge[1]} with time-lagged dependency {edge[2]['time_lag']}")

dowhy/utils/timeseries.py

@ -0,0 +1,233 @@
import networkx as nx
import numpy as np
import pandas as pd


def create_graph_from_user() -> nx.DiGraph:
    """
    Creates a directed graph based on user input from the console.

    The time_lag parameter of the networkx graph represents the exact causal lag of an edge between any 2 nodes in the graph.
    Each edge can contain multiple time lags, therefore each combination of (node1,node2,time_lag) must be input individually by the user.

    The user is prompted to enter edges one by one in the format 'node1 node2 time_lag',
    where 'node1' and 'node2' are the nodes connected by the edge, and 'time_lag' is a numerical
    value representing the weight of the edge. The user should enter 'done' to finish inputting edges.

    :return: A directed graph created from the user's input.
    :rtype: nx.DiGraph

    Example user input:
        Enter an edge: A B 4
        Enter an edge: B C 2
        Enter an edge: done
    """
    graph = nx.DiGraph()

    print("Enter the graph as a list of edges with time lags. Enter 'done' when you are finished.")
    print("Each edge should be entered in the format 'node1 node2 time_lag'. For example: 'A B 4'")

    while True:
        edge = input("Enter an edge: ")
        if edge.lower() == "done":
            break

        edge = edge.split()
        if len(edge) != 3:
            print("Invalid edge. Please enter an edge in the format 'node1 node2 time_lag'.")
            continue

        node1, node2, time_lag = edge
        try:
            time_lag = int(time_lag)
        except ValueError:
            print("Invalid weight. Please enter a numerical value for the time_lag.")
            continue

        # Check if the edge already exists
        if graph.has_edge(node1, node2):
            # If the edge exists, append the time_lag to the existing tuple
            current_time_lag = graph[node1][node2]["time_lag"]
            if isinstance(current_time_lag, tuple):
                graph[node1][node2]["time_lag"] = current_time_lag + (time_lag,)
            else:
                graph[node1][node2]["time_lag"] = (current_time_lag, time_lag)
        else:
            # If the edge does not exist, create a new edge with a tuple containing the time_lag
            graph.add_edge(node1, node2, time_lag=(time_lag,))

    return graph


def create_graph_from_csv(file_path: str) -> nx.DiGraph:
    """
    Creates a directed graph from a CSV file.

    The time_lag parameter of the networkx graph represents the exact causal lag of an edge between any 2 nodes in the graph.
    Each edge can contain multiple time lags, therefore each combination of (node1,node2,time_lag) must be input individually in the CSV file.

    The CSV file should have at least three columns: 'node1', 'node2', and 'time_lag'.
    Each row represents an edge from 'node1' to 'node2' with a 'time_lag' attribute.

    :param file_path: The path to the CSV file.
    :type file_path: str
    :return: A directed graph created from the CSV file.
    :rtype: nx.DiGraph

    Example:
        Example CSV content:

        .. code-block:: csv

            node1,node2,time_lag
            A,B,5
            B,C,2
            A,C,7
    """
    # Read the CSV file into a DataFrame
    df = pd.read_csv(file_path)

    # Initialize an empty directed graph
    graph = nx.DiGraph()

    # Add edges with time lag to the graph
    for index, row in df.iterrows():
        # Add validation for the time lag column to be a number
        try:
            time_lag = int(row["time_lag"])
        except ValueError:
            print(
                "Invalid weight. Please enter a numerical value for the time_lag for the edge between {} and {}.".format(
                    row["node1"], row["node2"]
                )
            )
            return None

        # Check if the edge already exists
        if graph.has_edge(row["node1"], row["node2"]):
            # If the edge exists, append the time_lag to the existing tuple
            current_time_lag = graph[row["node1"]][row["node2"]]["time_lag"]
            if isinstance(current_time_lag, tuple):
                graph[row["node1"]][row["node2"]]["time_lag"] = current_time_lag + (time_lag,)
            else:
                graph[row["node1"]][row["node2"]]["time_lag"] = (current_time_lag, time_lag)
        else:
            # If the edge does not exist, create a new edge with a tuple containing the time_lag
            graph.add_edge(row["node1"], row["node2"], time_lag=(time_lag,))

    return graph


def create_graph_from_dot_format(file_path: str) -> nx.DiGraph:
    """
    Creates a directed graph from a DOT file and ensures it is a DiGraph.

    The time_lag parameter of the networkx graph represents the exact causal lag of an edge between any 2 nodes in the graph.
    Each edge can contain multiple valid time lags.

    The DOT file should contain a graph in DOT format.

    :param file_path: The path to the DOT file.
    :type file_path: str
    :return: A directed graph (DiGraph) created from the DOT file.
    :rtype: nx.DiGraph
    """
    # Read the DOT file into a MultiDiGraph
    multi_graph = nx.drawing.nx_agraph.read_dot(file_path)

    # Initialize a new DiGraph
    graph = nx.DiGraph()

    # Iterate over edges of the MultiDiGraph
    for u, v, data in multi_graph.edges(data=True):
        if "label" in data:
            try:
                # Convert the label to a tuple of time lags
                time_lag_tuple = tuple(map(int, data["label"].strip("()").split(",")))

                if graph.has_edge(u, v):
                    existing_data = graph.get_edge_data(u, v)
                    if "time_lag" in existing_data:
                        # Merge the existing time lags with the new ones
                        existing_time_lags = existing_data["time_lag"]
                        new_time_lags = existing_time_lags + time_lag_tuple
                        # Remove duplicates by converting to a set and back to a tuple
                        graph[u][v]["time_lag"] = tuple(set(new_time_lags))
                    else:
                        graph[u][v]["time_lag"] = time_lag_tuple
                else:
                    graph.add_edge(u, v, time_lag=time_lag_tuple)

            except ValueError:
                print(f"Invalid weight for the edge between {u} and {v}.")
                return None

    return graph


def create_graph_from_networkx_array(array: np.ndarray, var_names: list) -> nx.DiGraph:
    """
    Create a NetworkX directed graph from a numpy array with time lag information.

    The time_lag parameter of the networkx graph represents the exact causal lag of an edge between any 2 nodes in the graph.
    Each edge can contain multiple valid time lags.

    The resulting graph will be a directed graph with edge attributes indicating
    the type of link based on the array values.

    :param array: A numpy array of shape (n, n, tau) representing the causal links.
    :type array: np.ndarray
    :param var_names: A list of variable names.
    :type var_names: list
    :return: A directed graph with edge attributes based on the array values.
    :rtype: nx.DiGraph
    """
    n = array.shape[0]  # Number of variables
    assert n == array.shape[1], "The array must be square."
    tau = array.shape[2]  # Number of time lags

    # Initialize a directed graph
    graph = nx.DiGraph()

    # Add nodes with names
    graph.add_nodes_from(var_names)

    # Iterate over all pairs of nodes
    for i in range(n):
        for j in range(n):
            if i == j:
                continue  # Skip self-loops

            for t in range(tau):
                # Check for directed links
                if array[i, j, t] == "-->":
                    if graph.has_edge(var_names[i], var_names[j]):
                        # Append the time lag to the existing tuple
                        current_time_lag = graph[var_names[i]][var_names[j]].get("time_lag", ())
                        graph[var_names[i]][var_names[j]]["time_lag"] = current_time_lag + (t,)
                    else:
                        # Create a new edge with a tuple containing the time lag
                        graph.add_edge(var_names[i], var_names[j], time_lag=(t,))

                elif array[i, j, t] == "<--":
                    if graph.has_edge(var_names[j], var_names[i]):
                        # Append the time lag to the existing tuple
                        current_time_lag = graph[var_names[j]][var_names[i]].get("time_lag", ())
                        graph[var_names[j]][var_names[i]]["time_lag"] = current_time_lag + (t,)
                    else:
                        # Create a new edge with a tuple containing the time lag
                        graph.add_edge(var_names[j], var_names[i], time_lag=(t,))

                elif array[i, j, t] == "o-o":
                    raise ValueError(
                        "Unsupported link type 'o-o' found between {} and {} at lag {}.".format(
                            var_names[i], var_names[j], t
                        )
                    )

                elif array[i, j, t] == "x-x":
                    raise ValueError(
                        "Unsupported link type 'x-x' found between {} and {} at lag {}.".format(
                            var_names[i], var_names[j], t
                        )
                    )

    return graph

@ -0,0 +1,227 @@
import unittest
from io import StringIO

import networkx as nx
import numpy as np
import pandas as pd

from dowhy.utils.timeseries import create_graph_from_csv, create_graph_from_dot_format, create_graph_from_networkx_array


class TestCreateGraphFromCSV(unittest.TestCase):
    def test_basic_functionality(self):
        csv_content = """node1,node2,time_lag
A,B,5
B,C,2
A,C,7"""
        df = pd.read_csv(StringIO(csv_content))
        df.to_csv("test.csv", index=False)

        graph = create_graph_from_csv("test.csv")

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 3)
        self.assertEqual(graph["A"]["B"]["time_lag"], (5,))
        self.assertEqual(graph["B"]["C"]["time_lag"], (2,))
        self.assertEqual(graph["A"]["C"]["time_lag"], (7,))

    def test_multiple_time_lags(self):
        csv_content = """node1,node2,time_lag
A,B,5
A,B,3
A,C,7"""
        df = pd.read_csv(StringIO(csv_content))
        df.to_csv("test.csv", index=False)

        graph = create_graph_from_csv("test.csv")

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 2)
        self.assertEqual(graph["A"]["B"]["time_lag"], (5, 3))
        self.assertEqual(graph["A"]["C"]["time_lag"], (7,))

    def test_invalid_time_lag(self):
        csv_content = """node1,node2,time_lag
A,B,five
B,C,2
A,C,7"""
        df = pd.read_csv(StringIO(csv_content))
        df.to_csv("test.csv", index=False)

        graph = create_graph_from_csv("test.csv")

        self.assertIsNone(graph)

    def test_empty_csv(self):
        csv_content = """node1,node2,time_lag"""
        df = pd.read_csv(StringIO(csv_content))
        df.to_csv("test.csv", index=False)

        graph = create_graph_from_csv("test.csv")

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 0)

    def test_self_loop(self):
        csv_content = """node1,node2,time_lag
A,A,5"""
        df = pd.read_csv(StringIO(csv_content))
        df.to_csv("test.csv", index=False)

        graph = create_graph_from_csv("test.csv")

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 1)
        self.assertEqual(graph["A"]["A"]["time_lag"], (5,))


class TestCreateGraphFromDotFormat(unittest.TestCase):
    def setUp(self):
        # Helper method to create a DOT file from string content
        def create_dot_file(dot_content, file_name="test.dot"):
            with open(file_name, "w") as f:
                f.write(dot_content)

        self.create_dot_file = create_dot_file

    def test_basic_functionality(self):
        dot_content = """digraph G {
A -> B [label="(5)"];
B -> C [label="(2)"];
A -> C [label="(7)"];
}"""
        self.create_dot_file(dot_content)

        graph = create_graph_from_dot_format("test.dot")

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 3)
        self.assertEqual(graph["A"]["B"]["time_lag"], (5,))
        self.assertEqual(graph["B"]["C"]["time_lag"], (2,))
        self.assertEqual(graph["A"]["C"]["time_lag"], (7,))

    def test_multiple_time_lags(self):
        dot_content = """digraph G {
A -> B [label="(5,3)"];
A -> C [label="(7)"];
}"""
        self.create_dot_file(dot_content)

        graph = create_graph_from_dot_format("test.dot")

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 2)
        self.assertEqual(graph["A"]["B"]["time_lag"], (5, 3))
        self.assertEqual(graph["A"]["C"]["time_lag"], (7,))

    def test_invalid_time_lag(self):
        dot_content = """digraph G {
A -> B [label="(five)"];
B -> C [label="(2)"];
A -> C [label="(7)"];
}"""
        self.create_dot_file(dot_content)

        graph = create_graph_from_dot_format("test.dot")

        self.assertIsNone(graph)

    def test_empty_dot(self):
        dot_content = """digraph G {}"""
        self.create_dot_file(dot_content)

        graph = create_graph_from_dot_format("test.dot")

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 0)

    def test_self_loop(self):
        dot_content = """digraph G {
A -> A [label="(5)"];
}"""
        self.create_dot_file(dot_content)

        graph = create_graph_from_dot_format("test.dot")

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 1)
        self.assertEqual(graph["A"]["A"]["time_lag"], (5,))


class TestCreateGraphFromNetworkxArray(unittest.TestCase):
    def test_basic_functionality(self):
        array = np.zeros((3, 3, 2), dtype=object)
        array[0, 1, 0] = "-->"
        array[1, 2, 0] = "-->"
        array[0, 2, 1] = "-->"
        var_names = ["X1", "X2", "X3"]

        graph = create_graph_from_networkx_array(array, var_names)

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 3)
        self.assertTrue(graph.has_edge("X1", "X2"))
        self.assertEqual(graph["X1"]["X2"]["time_lag"], (0,))
        self.assertTrue(graph.has_edge("X1", "X3"))
        self.assertEqual(graph["X1"]["X3"]["time_lag"], (1,))
        self.assertTrue(graph.has_edge("X2", "X3"))
        self.assertEqual(graph["X2"]["X3"]["time_lag"], (0,))

    def test_multiple_time_lags(self):
        array = np.zeros((3, 3, 3), dtype=object)
        array[0, 1, 0] = "-->"
        array[0, 1, 1] = "-->"
        array[1, 2, 2] = "-->"
        var_names = ["X1", "X2", "X3"]

        graph = create_graph_from_networkx_array(array, var_names)

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 2)
        self.assertTrue(graph.has_edge("X1", "X2"))
        self.assertEqual(graph["X1"]["X2"]["time_lag"], (0, 1))
        self.assertTrue(graph.has_edge("X2", "X3"))
        self.assertEqual(graph["X2"]["X3"]["time_lag"], (2,))

    def test_invalid_link_type_oo(self):
        array = np.zeros((2, 2, 1), dtype=object)
        array[0, 1, 0] = "o-o"
        var_names = ["X1", "X2"]

        with self.assertRaises(ValueError):
            create_graph_from_networkx_array(array, var_names)

    def test_invalid_link_type_xx(self):
        array = np.zeros((2, 2, 1), dtype=object)
        array[0, 1, 0] = "x-x"
        var_names = ["X1", "X2"]

        with self.assertRaises(ValueError):
            create_graph_from_networkx_array(array, var_names)

    def test_empty_array(self):
        array = np.zeros((0, 0, 0), dtype=object)
        var_names = []

        graph = create_graph_from_networkx_array(array, var_names)

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.nodes()), 0)
        self.assertEqual(len(graph.edges()), 0)

    def test_self_loop(self):
        array = np.zeros((2, 2, 1), dtype=object)
        array[0, 0, 0] = "-->"
        var_names = ["X1", "X2"]

        graph = create_graph_from_networkx_array(array, var_names)

        self.assertIsInstance(graph, nx.DiGraph)
        self.assertEqual(len(graph.edges()), 0)


if __name__ == "__main__":
    unittest.main()

@ -0,0 +1,142 @@
import unittest
from typing import List, Optional

import networkx as nx
import pandas as pd
from pandas.testing import assert_frame_equal

from dowhy.timeseries.temporal_shift import add_lagged_edges, shift_columns_by_lag_using_unrolled_graph


class TestAddLaggedEdges(unittest.TestCase):
    def test_basic_functionality(self):
        graph = nx.DiGraph()
        graph.add_edge("A", "B", time_lag=(1,))
        graph.add_edge("B", "C", time_lag=(2,))

        new_graph = add_lagged_edges(graph, "C")

        self.assertIsInstance(new_graph, nx.DiGraph)
        self.assertTrue(new_graph.has_node("C_0"))
        self.assertTrue(new_graph.has_node("B_-2"))
        self.assertTrue(new_graph.has_node("A_-3"))
        self.assertTrue(new_graph.has_edge("B_-2", "C_0"))
        self.assertTrue(new_graph.has_edge("A_-3", "B_-2"))

    def test_multiple_time_lags(self):
        graph = nx.DiGraph()
        graph.add_edge("A", "B", time_lag=(1, 2))
        graph.add_edge("B", "C", time_lag=(1, 3))

        new_graph = add_lagged_edges(graph, "C")

        self.assertIsInstance(new_graph, nx.DiGraph)
        self.assertTrue(new_graph.has_node("C_0"))
        self.assertTrue(new_graph.has_node("B_-1"))
        self.assertTrue(new_graph.has_node("B_-3"))
        self.assertTrue(new_graph.has_node("A_-2"))
        self.assertTrue(new_graph.has_node("A_-3"))
        self.assertTrue(new_graph.has_node("A_-4"))
        self.assertTrue(new_graph.has_node("A_-5"))
        self.assertTrue(new_graph.has_edge("B_-1", "C_0"))
        self.assertTrue(new_graph.has_edge("B_-3", "C_0"))
        self.assertTrue(new_graph.has_edge("A_-2", "B_-1"))
        self.assertTrue(new_graph.has_edge("A_-4", "B_-3"))
        self.assertTrue(new_graph.has_edge("B_-3", "B_-1"))
        self.assertTrue(new_graph.has_edge("A_-5", "A_-4"))
        self.assertTrue(new_graph.has_edge("A_-4", "A_-3"))
        self.assertTrue(new_graph.has_edge("A_-3", "A_-2"))

    def test_complex_graph_structure(self):
        graph = nx.DiGraph()
        graph.add_edge("A", "B", time_lag=(1,))
        graph.add_edge("B", "C", time_lag=(2,))
        graph.add_edge("A", "C", time_lag=(3,))

        new_graph = add_lagged_edges(graph, "C")

        self.assertIsInstance(new_graph, nx.DiGraph)
        self.assertTrue(new_graph.has_node("C_0"))
        self.assertTrue(new_graph.has_node("B_-2"))
        self.assertTrue(new_graph.has_node("A_-3"))
        self.assertTrue(new_graph.has_edge("B_-2", "C_0"))
        self.assertTrue(new_graph.has_edge("A_-3", "B_-2"))
        self.assertTrue(new_graph.has_edge("A_-3", "C_0"))

    def test_no_time_lag(self):
        graph = nx.DiGraph()
        graph.add_edge("A", "B")
        graph.add_edge("B", "C")

        new_graph = add_lagged_edges(graph, "C")

        self.assertIsInstance(new_graph, nx.DiGraph)
        self.assertEqual(len(new_graph.nodes()), 0)
        self.assertEqual(len(new_graph.edges()), 0)


class TestShiftColumnsByLagUsingUnrolledGraph(unittest.TestCase):
    def test_basic_functionality(self):
        df = pd.DataFrame({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1]})

        unrolled_graph = nx.DiGraph()
        unrolled_graph.add_nodes_from(["A_0", "A_-1", "B_0", "B_-2"])

        expected_df = pd.DataFrame(
            {"A_0": [1, 2, 3, 4, 5], "A_-1": [0, 1, 2, 3, 4], "B_0": [5, 4, 3, 2, 1], "B_-2": [0, 0, 5, 4, 3]}
        )

        result_df = shift_columns_by_lag_using_unrolled_graph(df, unrolled_graph)

        assert_frame_equal(result_df, expected_df)

    def test_complex_graph_structure(self):
        df = pd.DataFrame({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1], "C": [1, 3, 5, 7, 9]})

        unrolled_graph = nx.DiGraph()
        unrolled_graph.add_nodes_from(["A_0", "A_-1", "B_0", "B_-2", "C_-1", "C_-3"])

        expected_df = pd.DataFrame(
            {
                "A_0": [1, 2, 3, 4, 5],
                "A_-1": [0, 1, 2, 3, 4],
                "B_0": [5, 4, 3, 2, 1],
                "B_-2": [0, 0, 5, 4, 3],
                "C_-1": [0, 1, 3, 5, 7],
                "C_-3": [0, 0, 0, 1, 3],
            }
        )

        result_df = shift_columns_by_lag_using_unrolled_graph(df, unrolled_graph)

        assert_frame_equal(result_df, expected_df)

    def test_invalid_node_format(self):
        df = pd.DataFrame({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1]})

        unrolled_graph = nx.DiGraph()
        unrolled_graph.add_nodes_from(["A_0", "B_invalid"])

        expected_df = pd.DataFrame({"A_0": [1, 2, 3, 4, 5]})

        result_df = shift_columns_by_lag_using_unrolled_graph(df, unrolled_graph)

        assert_frame_equal(result_df, expected_df)

    def test_non_matching_columns(self):
        df = pd.DataFrame({"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1]})

        unrolled_graph = nx.DiGraph()
        unrolled_graph.add_nodes_from(["C_0", "C_-1"])

        expected_df = pd.DataFrame()

        result_df = shift_columns_by_lag_using_unrolled_graph(df, unrolled_graph)

        assert_frame_equal(result_df, expected_df)


if __name__ == "__main__":
    unittest.main()
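
The expected frames in the shifting tests follow one rule: a node named `B_-2` means "column `B`, moved down two rows, with the top padded by 0". The pure-Python sketch below reproduces that rule on plain lists so the expected values can be checked by hand. `split_lagged_name` and `shift_down` are hypothetical helpers for this note; the library itself relies on `pandas.Series.shift` with `fill_value=0`.

```python
def split_lagged_name(name: str) -> tuple:
    """Split "B_-2" into ("B", 2): the base column and a non-negative lag."""
    base, lag_str = name.rsplit("_", 1)
    return base, -int(lag_str)


def shift_down(values: list, lag: int, fill=0) -> list:
    """Shift a column down by `lag` rows, padding the top with `fill`."""
    if lag == 0:
        return list(values)
    # values[:-lag] drops the last `lag` rows; it is empty once lag >= len(values).
    return [fill] * min(lag, len(values)) + list(values[:-lag])


columns = {"A": [1, 2, 3, 4, 5], "B": [5, 4, 3, 2, 1], "C": [1, 3, 5, 7, 9]}
shifted = {}
for node in ["A_0", "A_-1", "B_-2", "C_-3"]:
    base, lag = split_lagged_name(node)
    shifted[node] = shift_down(columns[base], lag)

print(shifted["B_-2"])
```

Row `t` of a shifted column holds the value the base column had at `t - lag`, which is what lets an estimator read a lagged treatment and the outcome from the same row.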