πŸ› fix - don't lag actions,configs when iteration_order = -1

This commit is contained in:
Ali Zaidi 2021-07-16 15:21:23 -07:00
Π ΠΎΠ΄ΠΈΡ‚Π΅Π»ΡŒ 7639de604f
ΠšΠΎΠΌΠΌΠΈΡ‚ 897e9b3d0b
2 ΠΈΠ·ΠΌΠ΅Π½Ρ‘Π½Π½Ρ‹Ρ… Ρ„Π°ΠΉΠ»ΠΎΠ²: 13 Π΄ΠΎΠ±Π°Π²Π»Π΅Π½ΠΈΠΉ ΠΈ 6 ΡƒΠ΄Π°Π»Π΅Π½ΠΈΠΉ

ΠŸΡ€ΠΎΡΠΌΠΎΡ‚Ρ€Π΅Ρ‚ΡŒ Ρ„Π°ΠΉΠ»

@ -99,6 +99,7 @@ class BaseModel(abc.ABC):
)
if not augm_cols:
logging.debug(f"No augmented columns...")
augm_features = []
elif type(augm_cols) == str:
augm_features = [str(col) for col in df if col.startswith(augm_cols)]
elif isinstance(augm_cols, (list, ListConfig)):
@ -133,6 +134,7 @@ class BaseModel(abc.ABC):
label_cols=labels,
episode_col=episode_col,
iteration_col=iteration_col,
augmented_cols=augm_features,
)
X = df[csv_reader.feature_cols].values
y = df[csv_reader.label_cols].values
@ -618,11 +620,7 @@ class BaseModel(abc.ABC):
)
elif search_algorithm == "grid":
search = GridSearchCV(
self.model,
param_grid=params,
refit=True,
cv=cv,
scoring=scoring_func,
self.model, param_grid=params, refit=True, cv=cv, scoring=scoring_func,
)
elif search_algorithm == "random":
search = RandomizedSearchCV(

ΠŸΡ€ΠΎΡΠΌΠΎΡ‚Ρ€Π΅Ρ‚ΡŒ Ρ„Π°ΠΉΠ»

@ -18,6 +18,7 @@ class CsvReader(object):
current_row,
feature_cols,
label_cols,
augmented_cols,
):
"""Split the dataset by features and labels
@ -58,6 +59,11 @@ class CsvReader(object):
lagged_df = lagged_df.drop([iteration_col], axis=1)
features_df = lagged_df[feature_cols]
# if iteration order is less than 1
# then the actions, configs should not be lagged
# only states should be lagged
if iteration_order < 0:
features_df[augmented_cols] = df[augmented_cols]
# eventually we will join the labels_df with the features_df
# if any columns are matching then rename them
@ -83,8 +89,9 @@ class CsvReader(object):
iteration_order: int = -1,
episode_col: str = "episode",
iteration_col: str = "iteration",
feature_cols: List[str] = ["state_x_position", "action_command"],
feature_cols: List[str] = ["state_x_position"],
label_cols: List[str] = ["state_x_position"],
augmented_cols: List[str] = ["action_command"],
):
"""Read episodic data where each row contains either inputs and its preceding output output or the causal inputs/outputs relationship
@ -124,6 +131,7 @@ class CsvReader(object):
current_row,
feature_cols,
label_cols,
augmented_cols,
)
# skip the first row of each episode since we do not have its st
@ -149,6 +157,7 @@ class CsvReader(object):
current_row,
feature_cols,
label_cols,
augmented_cols,
)
# truncate before the end of iteration_order for complete observations only
joined_df = (