add episode and iter cols to base reader
This commit is contained in:
Родитель
5982153dc1
Коммит
67c2d671cc
10
base.py
10
base.py
|
@ -37,6 +37,8 @@ class BaseModel(abc.ABC):
|
|||
augm_cols: Union[str, List[str]] = ["action_command"],
|
||||
output_cols: Union[str, List[str]] = "state",
|
||||
iteration_order: int = -1,
|
||||
episode_col: str = "episode",
|
||||
iteration_col: str = "iteration",
|
||||
max_rows: Union[int, None] = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Read CSV data into two datasets for modeling
|
||||
|
@ -74,7 +76,7 @@ class BaseModel(abc.ABC):
|
|||
else:
|
||||
df = pd.read_csv(dataset_path, nrows=max_rows)
|
||||
if type(input_cols) == str:
|
||||
base_features = [col for col in df if col.startswith(input_cols)]
|
||||
base_features = [str(col) for col in df if col.startswith(input_cols)]
|
||||
elif type(input_cols) == list:
|
||||
base_features = input_cols
|
||||
else:
|
||||
|
@ -82,7 +84,7 @@ class BaseModel(abc.ABC):
|
|||
f"input_cols expected type List[str] or str but received type {type(input_cols)}"
|
||||
)
|
||||
if type(augm_cols) == str:
|
||||
augm_features = [col for col in df if col.startswith(augm_cols)]
|
||||
augm_features = [str(col) for col in df if col.startswith(augm_cols)]
|
||||
elif type(augm_cols) == list:
|
||||
augm_features = augm_cols
|
||||
else:
|
||||
|
@ -91,6 +93,7 @@ class BaseModel(abc.ABC):
|
|||
)
|
||||
|
||||
features = base_features + augm_features
|
||||
self.features = features
|
||||
|
||||
if type(output_cols) == str:
|
||||
labels = [col for col in df if col.startswith(output_cols)]
|
||||
|
@ -100,12 +103,15 @@ class BaseModel(abc.ABC):
|
|||
raise TypeError(
|
||||
f"output_cols expected type List[str] but received type {type(output_cols)}"
|
||||
)
|
||||
self.labels = labels
|
||||
|
||||
df = csv_reader.read(
|
||||
df,
|
||||
iteration_order=iteration_order,
|
||||
feature_cols=features,
|
||||
label_cols=labels,
|
||||
episode_col=episode_col,
|
||||
iteration_col=iteration_col,
|
||||
)
|
||||
X = df[csv_reader.feature_cols].values
|
||||
y = df[csv_reader.label_cols].values
|
||||
|
|
Загрузка…
Ссылка в новой задаче