import os

import pandas as pd

from typing import List, Tuple, Union

import logging

logger = logging.getLogger("data_loaders")
logger.setLevel(logging.INFO)

class CsvReader(object):
    """Loader for episodic CSV logs of state/action iterations.

    Each row of the raw CSV holds one iteration of one episode. ``read``
    joins each row with the feature columns of a neighboring iteration of
    the same episode, so every returned row contains a complete
    transition (current features plus ``prev_``/``next_`` features).
    """

    def order_by_time(self, time_col: str = "timestamp"):
        """Sort rows by a timestamp column.

        Parameters
        ----------
        time_col : str, optional
            Name of the timestamp column, by default "timestamp"
        """
        # TODO: not implemented yet; kept as a stub so the public
        # interface stays stable for callers.
        pass

    def _lagged_features(
        self,
        df: pd.DataFrame,
        episode_col: str,
        iteration_col: str,
        feature_cols: Union[List, str],
        periods: int,
        prefix: str,
    ) -> pd.DataFrame:
        """Shift the feature columns of ``df`` within each episode.

        Side effect: stores the prefixed feature-column names in
        ``self.feature_cols``.

        Parameters
        ----------
        df : pd.DataFrame
            Episodic data, sorted by [episode_col, iteration_col].
        episode_col : str
            Column identifying which episode a row belongs to.
        iteration_col : str
            Iteration column; dropped from the shifted frame since a
            shifted iteration number is meaningless.
        feature_cols : Union[List, str]
            Either an explicit list of column names, or a prefix string
            selecting every column that starts with it.
        periods : int
            Rows to shift within each episode (positive moves features
            down, i.e. row i receives the features of row i - periods).
        prefix : str
            Prefix ("prev_" or "next_") applied to the shifted columns.

        Returns
        -------
        pd.DataFrame
            Shifted, renamed feature columns, index-aligned with ``df``.
        """
        # groupby(...).shift keeps the original row index, so the result
        # can later be joined back onto df row-by-row.
        lagged_df = df.groupby(by=episode_col, as_index=False).shift(periods)
        lagged_df = lagged_df.drop([iteration_col], axis=1)
        if isinstance(feature_cols, list):
            selected = feature_cols
        else:
            # feature_cols is a common prefix, e.g. "state_"
            selected = [col for col in lagged_df if col.startswith(feature_cols)]
        lagged_df = lagged_df[selected]
        lagged_df = lagged_df.rename(columns=lambda col: prefix + col)
        self.feature_cols = list(lagged_df.columns.values)
        return lagged_df

    def read(
        self,
        filename: str,
        iteration_order: int = -1,
        episode_col: str = "episode",
        iteration_col: str = "iteration",
        feature_cols: Union[List, str] = "state_",
        max_rows: Union[int, None] = None,
    ) -> pd.DataFrame:
        """Read episodic data where each row contains either inputs and its
        preceding output, or the causal inputs/outputs relationship.

        Parameters
        ----------
        filename : str
            Path of the CSV file to read.
        iteration_order : int, optional
            Negative: rows are {st+1, at}; previous-row features are
            joined in as ``prev_*`` columns. Positive: rows are {st, at};
            next-row features are joined in as ``next_*`` columns.
            By default -1.
        episode_col : str, optional
            Episode-id column, by default "episode"
        iteration_col : str, optional
            Iteration-id column, by default "iteration"
        feature_cols : Union[List, str], optional
            Explicit list of feature columns, or their common prefix,
            by default "state_"
        max_rows : Union[int, None], optional
            Maximum number of CSV rows to load (None = all), by default None

        Returns
        -------
        pd.DataFrame
            Joined transitions; rows without a complete neighbor (episode
            boundaries) are dropped. If episode/iteration columns are
            falsy or iteration_order == 0, the raw frame is returned
            unchanged.
        """
        # Same instance as the module-level "data_loaders" logger; fetched
        # locally so the class does not depend on a module global.
        logger = logging.getLogger("data_loaders")
        # Fixed: log the actual filename (placeholder had been lost).
        logger.info(f"Reading data from {filename}")
        df = pd.read_csv(filename, nrows=max_rows)

        # CASE 1: rows are of the form {st+1, at}.
        # Drop the first |iteration_order| rows of each episode and append
        # the previous state columns into each row: {st+1, at} -> {st, at, st+1}
        if all([episode_col, iteration_col, iteration_order < 0]):
            logger.info(
                f"Iteration order set to {iteration_order} so using inputs from previous {iteration_order} row"
            )
            df = df.sort_values(by=[episode_col, iteration_col])
            lagged_df = self._lagged_features(
                df, episode_col, iteration_col, feature_cols, iteration_order * -1, "prev_"
            )
            logger.info("Previous states are being added to same row with prefix prev_")
            logger.info(f"Feature columns are: {self.feature_cols}")
            joined_df = df.join(lagged_df)
            # Skip the first row(s) of each episode since they have no
            # preceding state. reset_index(drop=True) replaces the
            # version-fragile reset_index()+drop(["level_0","level_1"]).
            joined_df = (
                joined_df.groupby(by=episode_col, as_index=False)
                .apply(lambda x: x.iloc[iteration_order * -1 :])
                .reset_index(drop=True)
            )
            return joined_df

        # CASE 2: rows are of the form {st, at}.
        # Append st+1 from the next row into the current row: {st, at, st+1}
        elif all([episode_col, iteration_col, iteration_order > 0]):
            logger.info(
                f"Iteration order set to {iteration_order} so using outputs from next {iteration_order} row"
            )
            df = df.sort_values(by=[episode_col, iteration_col])
            lagged_df = self._lagged_features(
                df, episode_col, iteration_col, feature_cols, iteration_order * -1, "next_"
            )
            logger.info("Next states are being added to same row with prefix next_")
            logger.info(f"Feature columns are: {self.feature_cols}")
            joined_df = df.join(lagged_df)
            # Truncate the tail of each episode so only complete
            # observations (with a known next state) remain.
            joined_df = (
                joined_df.groupby(by=episode_col, as_index=False)
                .apply(lambda x: x.iloc[: iteration_order * -1])
                .reset_index(drop=True)
            )
            return joined_df
        else:
            # No episode/iteration structure requested: return raw frame.
            return df
if __name__ == "__main__":
    # Sample episodic logs live under this directory.
    data_dir = "csv_data"
    logger.info(f"Using data saved in directory {data_dir}")

    reader = CsvReader()

    # Rows of the form {st+1, at}: join in state from the previous row.
    df = reader.read(
        os.path.join(data_dir, "cartpole-log.csv"),
        iteration_order=-1,
        max_rows=1000,
    )

    # Rows of the form {st, at}: join in state from the next row.
    df2 = reader.read(
        os.path.join(data_dir, "cartpole_at_st.csv"),
        iteration_order=1,
        max_rows=1000,
    )
|