refactor: loaders class now only deals with features and labels directly
Parent: 80978f3acb
Commit: f0c0b07309

base.py (56 lines changed)
@@ -33,9 +33,9 @@ class BaseModel(abc.ABC):
     def load_csv(
         self,
         dataset_path: str,
-        input_cols_read: Union[str, List[str]] = "state",
+        input_cols: Union[str, List[str]] = "state",
         augm_cols: Union[str, List[str]] = ["action_command"],
-        output_col: Union[str, List[str]] = "state",
+        output_cols: Union[str, List[str]] = "state",
         iteration_order: int = -1,
         max_rows: Union[int, None] = None,
     ) -> Tuple[np.ndarray, np.ndarray]:
@@ -45,7 +45,7 @@ class BaseModel(abc.ABC):
         ----------
         dataset_path : str
             path to csv dataset
-        input_cols_read : Union[str, List[str]], optional
+        input_cols : Union[str, List[str]], optional
            list of columns represent the inputs to the dynamical system in the raw dataset. Can either be a string which is then matched for all columns in the dataset, or a list of strings with exact matches, by default "state"
         augm_cols : Union[str, List[str]], optional
            Exact match of additional columns to use for modeling, such as the actions of the current iteration and any scenario/config parameters, by default ["action_command"]
@@ -72,19 +72,43 @@ class BaseModel(abc.ABC):
         if not os.path.exists(dataset_path):
             raise ValueError(f"No data found at {dataset_path}")
         else:
-            df = csv_reader.read(
-                dataset_path,
-                iteration_order=iteration_order,
-                feature_cols=input_cols_read,
-                max_rows=max_rows,
-            )
-            features = csv_reader.feature_cols + augm_cols
-            if type(output_col) == str:
-                output_cols = [col for col in df if col.startswith(output_col)]
-            else:
-                output_cols = output_col
-            X = df[features].values
-            y = df[output_cols].values
+            df = pd.read_csv(dataset_path, nrows=max_rows)
+            if type(input_cols) == str:
+                base_features = [col for col in df if col.startswith(input_cols)]
+            elif type(input_cols) == list:
+                base_features = input_cols
+            else:
+                raise TypeError(
+                    f"input_cols expected type List[str] or str but received type {type(input_cols)}"
+                )
+            if type(augm_cols) == str:
+                augm_features = [col for col in df if col.startswith(augm_cols)]
+            elif type(augm_cols) == list:
+                augm_features = augm_cols
+            else:
+                raise TypeError(
+                    f"augm_cols expected type List[str] or str but received type {type(augm_cols)}"
+                )
+
+            features = base_features + augm_features
+
+            if type(output_cols) == str:
+                labels = [col for col in df if col.startswith(output_cols)]
+            elif type(output_cols) == list:
+                labels = output_cols
+            else:
+                raise TypeError(
+                    f"output_cols expected type List[str] but received type {type(output_cols)}"
+                )
+
+            df = csv_reader.read(
+                df,
+                iteration_order=iteration_order,
+                feature_cols=features,
+                label_cols=labels,
+            )
+            X = df[csv_reader.feature_cols].values
+            y = df[csv_reader.label_cols].values
 
         self.input_dim = X.shape[1]
         self.output_dim = y.shape[1]
@@ -168,7 +192,7 @@ class BaseModel(abc.ABC):
 if __name__ == "__main__":
 
     base_model = BaseModel()
-    base_model.load_csv(
+    x, y = base_model.load_csv(
         dataset_path="csv_data/cartpole-log.csv",
         max_rows=1000,
         augm_cols=["action_command", "config_length", "config_masspole"],
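
A minimal standalone sketch of the column resolution that the refactored load_csv now performs (the two-row frame and its column names are made up for illustration; only the prefix-vs-exact matching mirrors the diff above):

    import pandas as pd

    df = pd.DataFrame(
        {
            "episode": [1, 1],
            "iteration": [1, 2],
            "state_x_position": [0.0, 0.1],
            "state_angle_position": [0.0, -0.2],
            "action_command": [1, 0],
        }
    )

    input_cols = "state"            # a str is treated as a prefix match
    base_features = [c for c in df if c.startswith(input_cols)]
    augm_cols = ["action_command"]  # a list is used as exact column names
    features = base_features + augm_cols
    # features -> ['state_x_position', 'state_angle_position', 'action_command']
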
(The diff for one additional file is not shown because it is too large.)
loaders.py (160 lines changed)
@@ -8,24 +8,89 @@ logger.setLevel(logging.INFO)
 
 
 class CsvReader(object):
-    def order_by_time(self, time_col: str = "timestamp"):
-
-        pass
-
-    def read(
+    def split(
         self,
-        filename: str,
+        df,
+        iteration_col,
+        episode_col,
+        iteration_order,
+        lagger_str,
+        current_row,
+        feature_cols,
+        label_cols,
+    ):
+        """Split the dataset by features and labels
+
+        Parameters
+        ----------
+        df : [type]
+            [description]
+        iteration_col : [type]
+            [description]
+        episode_col : [type]
+            [description]
+        iteration_order : [type]
+            [description]
+        lagger_str : [type]
+            [description]
+        current_row : [type]
+            [description]
+        feature_cols : [type]
+            [description]
+        label_cols : [type]
+            [description]
+
+        Returns
+        -------
+        [type]
+            [description]
+        """
+        logger.info(
+            f"Iteration order set to {iteration_order} so using {current_row} from {lagger_str} {iteration_order} row"
+        )
+
+        # We group by episode and iteration indices to make dataset episodic
+        df = df.sort_values(by=[episode_col, iteration_col])
+        # Create a lagged dataframe for capturing inputs and outputs
+        lagged_df = df.groupby(by=episode_col, as_index=False).shift(
+            iteration_order * -1
+        )
+        lagged_df = lagged_df.drop([iteration_col], axis=1)
+
+        features_df = lagged_df[feature_cols]
+
+        # eventually we will join the labels_df with the features_df
+        # if any columns are matching then rename them
+        if bool(set(feature_cols) & set(label_cols)):
+            features_df = features_df.rename(columns=lambda x: lagger_str[:4] + "_" + x)
+
+            logger.info(
+                f"{lagger_str.title()} states are being added to same row with prefix {lagger_str[:4]}"
+            )
+        self.feature_cols = list(features_df.columns.values)
+        self.label_cols = list(label_cols)
+        logger.info(f"Feature columns are: {self.feature_cols}")
+        logger.info(f"Label columns are: {self.label_cols}")
+        # joined_df = df.join(features_df)
+        vars_to_keep = (
+            [episode_col, iteration_col] + self.feature_cols + self.label_cols
+        )
+        return df.join(features_df)[vars_to_keep]
+
+    def read(
+        self,
+        df: pd.DataFrame,
         iteration_order: int = -1,
         episode_col: str = "episode",
         iteration_col: str = "iteration",
-        feature_cols: Union[List, str] = "state_",
-        max_rows: Union[int, None] = None,
+        feature_cols: List[str] = ["state_x_position", "action_command"],
+        label_cols: List[str] = ["state_x_position"],
     ):
         """Read episodic data where each row contains either inputs and its preceding output output or the causal inputs/outputs relationship
 
         Parameters
         ----------
-        filename : str
+        df : pdf.DataFrame
             [description]
         iteration_order : int, optional
             [description], by default -1
@@ -35,8 +100,6 @@ class CsvReader(object):
             [description], by default "iteration"
         feature_cols : Union[List, str], optional
             [description], by default "state_"
-        max_rows : Union[int, None], optional
-            [description], by default None
 
         Returns
         -------
@@ -44,37 +107,25 @@ class CsvReader(object):
             [description]
         """
 
-        logger.info(f"Reading data from {filename}")
-        df = pd.read_csv(filename, nrows=max_rows)
-
         # CASE 1: rows are of the form {st+1, at}
         # Append st into next row
         # if iteration_order < 0 then drop the iteration - iteration_order iteration from each episode
         # and append previous state columns into each row: {st+1, at} -> {st, at, st+1}
         if all([episode_col, iteration_col, iteration_order < 0]):
-            logger.info(
-                f"Iteration order set to {iteration_order} so using inputs from previous {iteration_order} row"
-            )
-            df = df.sort_values(by=[episode_col, iteration_col])
-            lagged_df = df.groupby(by=episode_col, as_index=False).shift(
-                iteration_order * -1
-            )
-            lagged_df = lagged_df.drop([iteration_col], axis=1)
-            if type(feature_cols) == list:
-                self.feature_cols = feature_cols
-                lagged_df = lagged_df[feature_cols]
-            else:
-                self.feature_cols = [
-                    col for col in lagged_df if col.startswith(feature_cols)
-                ]
-                lagged_df = lagged_df[self.feature_cols]
-            lagged_df = lagged_df.rename(columns=lambda x: "prev_" + x)
-            logger.info(
-                f"Previous states are being added to same row with prefix prev_"
-            )
-            self.feature_cols = list(lagged_df.columns.values)
-            logger.info(f"Feature columns are: {self.feature_cols}")
-            joined_df = df.join(lagged_df)
+            lagger_str = "previous"
+            current_row = "inputs"
+
+            joined_df = self.split(
+                df,
+                iteration_col,
+                episode_col,
+                iteration_order,
+                lagger_str,
+                current_row,
+                feature_cols,
+                label_cols,
+            )
             # skip the first row of each episode since we do not have its st
             joined_df = (
                 joined_df.groupby(by=episode_col, as_index=False)
@@ -82,30 +133,23 @@ class CsvReader(object):
                 .reset_index()
             )
             return joined_df.drop(["level_0", "level_1"], axis=1)
 
         # CASE 2: rows of the form {st, at}
         # Append st+1 from next row into current row {st, at, st+1}
         elif all([episode_col, iteration_col, iteration_order > 0]):
-            logger.info(
-                f"Iteration order set to {iteration_order} so using outputs from next {iteration_order} row"
-            )
-            df = df.sort_values(by=[episode_col, iteration_col])
-            lagged_df = df.groupby(by=episode_col, as_index=False).shift(
-                iteration_order * -1
-            )
-            lagged_df = lagged_df.drop([iteration_col], axis=1)
-            if type(feature_cols) == list:
-                lagged_df = lagged_df[feature_cols]
-            else:
-                self.feature_cols = [
-                    col for col in lagged_df if col.startswith(feature_cols)
-                ]
-                lagged_df = lagged_df[self.feature_cols]
-            lagged_df = lagged_df.rename(columns=lambda x: "next_" + x)
-            self.feature_cols = list(lagged_df.columns.values)
-            logger.info(
-                f"Next states are being added to same row with prefix next_"
-            )
-            joined_df = df.join(lagged_df)
+            lagger_str = "next"
+            current_row = "outputs"
+
+            joined_df = self.split(
+                df,
+                iteration_col,
+                episode_col,
+                iteration_order,
+                lagger_str,
+                current_row,
+                feature_cols,
+                label_cols,
+            )
             # truncate before the end of iteration_order for complete observations only
             joined_df = (
                 joined_df.groupby(by=episode_col, as_index=False)
@@ -123,9 +167,7 @@ if __name__ == "__main__":
     logger.info(f"Using data saved in directory {data_dir}")
 
     csv_reader = CsvReader()
-    df = csv_reader.read(
-        os.path.join(data_dir, "cartpole-log.csv"), iteration_order=-1, max_rows=1000
-    )
-    df2 = csv_reader.read(
-        os.path.join(data_dir, "cartpole_at_st.csv"), iteration_order=1, max_rows=1000
-    )
+    df = pd.read_csv(os.path.join(data_dir, "cartpole-log.csv"), nrows=1000)
+    df = csv_reader.read(df, iteration_order=-1)
+    df2 = pd.read_csv(os.path.join(data_dir, "cartpole_at_st.csv"), nrows=1000)
+    df2 = csv_reader.read(df2, iteration_order=1)
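
A standalone sketch of the lagging trick that split() is built around, using a made-up three-row episode (illustrative column names; the real method additionally renames overlapping feature columns and records feature_cols/label_cols on the reader):

    import pandas as pd

    df = pd.DataFrame(
        {
            "episode":   [1, 1, 1],
            "iteration": [1, 2, 3],
            "state_x":   [0.0, 0.1, 0.3],
            "action":    [1, 0, 1],
        }
    )

    # iteration_order = -1 means shift(+1): pull the previous row's values forward
    lagged = df.groupby("episode").shift(1).drop(columns=["iteration"])
    lagged = lagged.rename(columns=lambda c: "prev_" + c)
    joined = df.join(lagged)
    # Each row now pairs prev_state_x and prev_action with the current state_x,
    # i.e. the {st, at} -> st+1 arrangement; read() then drops the first row of
    # each episode, whose prev_* values are NaN.
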

(The hunks below are from a pytest test module; its file header was not captured in this view.)

@@ -1,9 +1,28 @@
 import os
 import pytest
+import pandas as pd
 
 from loaders import CsvReader
 from base import BaseModel
 
 data_dir = "csv_data"
+df = pd.read_csv(os.path.join(data_dir, "cartpole-log.csv"), nrows=1000)
+df2 = pd.read_csv(os.path.join(data_dir, "cartpole_at_st.csv"), nrows=1000)
+features = [
+    "state_x_position",
+    "state_x_velocity",
+    "state_angle_position",
+    "state_angle_velocity",
+    "action_command",
+    "config_length",
+    "config_masspole",
+]
+
+labels = [
+    "state_x_position",
+    "state_x_velocity",
+    "state_angle_position",
+    "state_angle_velocity",
+]
 
 
 @pytest.fixture
@@ -15,7 +34,7 @@ def csv_reader():
 def test_cartpole_at_st1(csv_reader):
 
     cp_df = csv_reader.read(
-        os.path.join(data_dir, "cartpole-log.csv"), max_rows=1000, iteration_order=-1
+        df, iteration_order=-1, feature_cols=features, label_cols=labels
     )
     assert cp_df.shape[0] == 980
     assert cp_df.shape[1] == 13
@@ -27,7 +46,7 @@ def test_cartpole_at_st1(csv_reader):
 def test_cartpole_at_st(csv_reader):
 
     cp2_df = csv_reader.read(
-        os.path.join(data_dir, "cartpole_at_st.csv"), iteration_order=1, max_rows=1000
+        df2, feature_cols=features, label_cols=labels, iteration_order=1
     )
 
     assert cp2_df.shape[0] == 980
@@ -52,11 +71,15 @@ def test_base_reader():
     assert y.shape[1] == 4
 
 
-# def test_diff_names():
-
-#     base_model = BaseModel()
-#     X, y = base_model.load_csv(
-#         dataset_path=os.path.join(data_dir, "off_names.csv"), max_rows=1000
-#     )
-
-#     assert X.shape[0] == 980 == y.shape[0]
+def test_diff_names():
+
+    base_model = BaseModel()
+    X, y = base_model.load_csv(
+        dataset_path=os.path.join(data_dir, "off_names.csv"),
+        input_cols=["x_position", "x_velocity", "angle_position", "angle_velocity",],
+        output_cols=["angle_position", "angle_velocity"],
+        augm_cols=["command", "length", "masspole"],
+        max_rows=1000,
+    )
+
+    assert X.shape[0] == 980 == y.shape[0]
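
A quick sanity check on the shapes these tests expect: read() keeps [episode, iteration] plus the feature and label columns, so 2 + 7 features + 4 labels = 13 columns. The drop from 1000 input rows to 980 is consistent with one incomplete row being removed per episode (the first row for iteration_order=-1, the last for iteration_order=1), assuming the first 1000 rows of each log span 20 episodes.
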

(The remaining hunks, whose file headers were also not captured, update the PyTorch-based model code to the renamed keyword arguments and add a type hint.)

@@ -4,8 +4,8 @@ import os
 
 torch_model = PyTorchModel()
 X, y = torch_model.load_csv(
-    input_cols_read="state",
-    output_col="state",
+    input_cols="state",
+    output_cols="state",
     dataset_path="csv_data/cartpole-log.csv",
     max_rows=1000,
     augm_cols=["action_command", "config_length", "config_masspole"],
@@ -13,7 +13,7 @@ from base import BaseModel
 class MVRegressor(nn.Module):
     def __init__(
         self,
-        num_units=50,
+        num_units: int = 50,
         input_dim: int = 28,
         output_dim: int = 18,
         p_dropout: float = 0.5,