add episode and iter cols to base reader

This commit is contained in:
Ali Zaidi 2021-01-15 00:14:23 -08:00
Родитель 5982153dc1
Коммит 67c2d671cc
1 изменённых файлов: 8 добавлений и 2 удалений

10
base.py
Просмотреть файл

@ -37,6 +37,8 @@ class BaseModel(abc.ABC):
augm_cols: Union[str, List[str]] = ["action_command"],
output_cols: Union[str, List[str]] = "state",
iteration_order: int = -1,
episode_col: str = "episode",
iteration_col: str = "iteration",
max_rows: Union[int, None] = None,
) -> Tuple[np.ndarray, np.ndarray]:
"""Read CSV data into two datasets for modeling
@ -74,7 +76,7 @@ class BaseModel(abc.ABC):
else:
df = pd.read_csv(dataset_path, nrows=max_rows)
if type(input_cols) == str:
base_features = [col for col in df if col.startswith(input_cols)]
base_features = [str(col) for col in df if col.startswith(input_cols)]
elif type(input_cols) == list:
base_features = input_cols
else:
@ -82,7 +84,7 @@ class BaseModel(abc.ABC):
f"input_cols expected type List[str] or str but received type {type(input_cols)}"
)
if type(augm_cols) == str:
augm_features = [col for col in df if col.startswith(augm_cols)]
augm_features = [str(col) for col in df if col.startswith(augm_cols)]
elif type(augm_cols) == list:
augm_features = augm_cols
else:
@ -91,6 +93,7 @@ class BaseModel(abc.ABC):
)
features = base_features + augm_features
self.features = features
if type(output_cols) == str:
labels = [col for col in df if col.startswith(output_cols)]
@ -100,12 +103,15 @@ class BaseModel(abc.ABC):
raise TypeError(
f"output_cols expected type List[str] but received type {type(output_cols)}"
)
self.labels = labels
df = csv_reader.read(
df,
iteration_order=iteration_order,
feature_cols=features,
label_cols=labels,
episode_col=episode_col,
iteration_col=iteration_col,
)
X = df[csv_reader.feature_cols].values
y = df[csv_reader.label_cols].values