fixup: Format Python code with Black
This commit is contained in:
Родитель
e06e6b940c
Коммит
9211031e4e
6
base.py
6
base.py
|
@ -618,7 +618,11 @@ class BaseModel(abc.ABC):
|
|||
)
|
||||
elif search_algorithm == "grid":
|
||||
search = GridSearchCV(
|
||||
self.model, param_grid=params, refit=True, cv=cv, scoring=scoring_func,
|
||||
self.model,
|
||||
param_grid=params,
|
||||
refit=True,
|
||||
cv=cv,
|
||||
scoring=scoring_func,
|
||||
)
|
||||
elif search_algorithm == "random":
|
||||
search = RandomizedSearchCV(
|
||||
|
|
|
@ -163,7 +163,9 @@ def env_setup():
|
|||
|
||||
|
||||
def test_random_policy(
|
||||
num_episodes: int = 500, num_iterations: int = 250, sim: Simulator = None,
|
||||
num_episodes: int = 500,
|
||||
num_iterations: int = 250,
|
||||
sim: Simulator = None,
|
||||
):
|
||||
"""Test a policy using random actions over a fixed number of episodes
|
||||
|
||||
|
@ -310,7 +312,9 @@ def main(cfg: DictConfig):
|
|||
while True:
|
||||
# Advance by the new state depending on the event type
|
||||
sim_state = SimulatorState(
|
||||
sequence_id=sequence_id, state=sim.get_state(), halted=sim.halted(),
|
||||
sequence_id=sequence_id,
|
||||
state=sim.get_state(),
|
||||
halted=sim.halted(),
|
||||
)
|
||||
try:
|
||||
event = client.session.advance(
|
||||
|
|
|
@ -20,14 +20,14 @@ from all_models import available_models
|
|||
|
||||
## Add a local simulator in a `sim` folder to validate data-driven model
|
||||
## Example: Quanser from a Microsoft Bonsai
|
||||
'''
|
||||
"""
|
||||
├───ddm_test_validate.py
|
||||
├───main.py
|
||||
├───sim
|
||||
│ ├───quanser
|
||||
│ │ ├───sim
|
||||
│ │ | ├───qube_simulator.py
|
||||
'''
|
||||
"""
|
||||
# TODO: from main import TemplateSimulatorSession, env_setup
|
||||
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
|
@ -52,7 +52,9 @@ class Simulator(BaseModel):
|
|||
self.config_keys = configs
|
||||
self.state_keys = states
|
||||
self.action_keys = actions
|
||||
self.sim_orig = sim_orig() # include simulator function if comparing to simulator
|
||||
self.sim_orig = (
|
||||
sim_orig()
|
||||
) # include simulator function if comparing to simulator
|
||||
self.diff_state = diff_state
|
||||
if log_file == "enable":
|
||||
current_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
||||
|
@ -205,10 +207,10 @@ def test_sim_model(
|
|||
for episode in range(num_episodes):
|
||||
iteration = 0
|
||||
terminal = False
|
||||
'''
|
||||
"""
|
||||
TODO: Add episode_start(config) so sim works properly and not initializing
|
||||
with unrealistic initial conditions.
|
||||
'''
|
||||
"""
|
||||
sim.episode_start()
|
||||
ddm_state = sim.get_state()
|
||||
sim_state = sim.get_sim_state()
|
||||
|
@ -267,17 +269,16 @@ def main(cfg: DictConfig):
|
|||
|
||||
input_cols = input_cols + augmented_cols
|
||||
|
||||
|
||||
ddModel = available_models[model_name]
|
||||
model = ddModel()
|
||||
|
||||
#model.build_model(**cfg["model"]["build_params"])
|
||||
# model.build_model(**cfg["model"]["build_params"])
|
||||
if model_name.lower() == "pytorch":
|
||||
model.load_model(
|
||||
input_dim=len(input_cols),
|
||||
output_dim=len(output_cols),
|
||||
filename=save_path,
|
||||
scale_data=scale_data
|
||||
scale_data=scale_data,
|
||||
)
|
||||
else:
|
||||
model.load_model(filename=save_path, scale_data=scale_data)
|
||||
|
|
|
@ -73,12 +73,20 @@ def main(cfg: DictConfig) -> None:
|
|||
)
|
||||
train_id_end = floor(X.shape[0] * (1 - test_perc))
|
||||
X_train, y_train = (
|
||||
X[:train_id_end,],
|
||||
y[:train_id_end,],
|
||||
X[
|
||||
:train_id_end,
|
||||
],
|
||||
y[
|
||||
:train_id_end,
|
||||
],
|
||||
)
|
||||
X_test, y_test = (
|
||||
X[train_id_end:,],
|
||||
y[train_id_end:,],
|
||||
X[
|
||||
train_id_end:,
|
||||
],
|
||||
y[
|
||||
train_id_end:,
|
||||
],
|
||||
)
|
||||
|
||||
# save training and test sets
|
||||
|
|
|
@ -120,7 +120,11 @@ class PyTorchModel(BaseModel):
|
|||
self.model.fit(X, y, **fit_params)
|
||||
|
||||
def load_model(
|
||||
self, input_dim: str, output_dim: str, filename: str, scale_data: bool = False,
|
||||
self,
|
||||
input_dim: str,
|
||||
output_dim: str,
|
||||
filename: str,
|
||||
scale_data: bool = False,
|
||||
):
|
||||
|
||||
self.scale_data = scale_data
|
||||
|
|
Загрузка…
Ссылка в новой задаче