fixup: Format Python code with Black

This commit is contained in:
Ali Zaidi 2021-07-12 23:55:27 +00:00
Родитель e06e6b940c
Коммит 9211031e4e
5 изменённых файлов: 37 добавлений и 16 удалений

Просмотреть файл

@ -618,7 +618,11 @@ class BaseModel(abc.ABC):
)
elif search_algorithm == "grid":
search = GridSearchCV(
self.model, param_grid=params, refit=True, cv=cv, scoring=scoring_func,
self.model,
param_grid=params,
refit=True,
cv=cv,
scoring=scoring_func,
)
elif search_algorithm == "random":
search = RandomizedSearchCV(

Просмотреть файл

@ -163,7 +163,9 @@ def env_setup():
def test_random_policy(
num_episodes: int = 500, num_iterations: int = 250, sim: Simulator = None,
num_episodes: int = 500,
num_iterations: int = 250,
sim: Simulator = None,
):
"""Test a policy using random actions over a fixed number of episodes
@ -310,7 +312,9 @@ def main(cfg: DictConfig):
while True:
# Advance by the new state depending on the event type
sim_state = SimulatorState(
sequence_id=sequence_id, state=sim.get_state(), halted=sim.halted(),
sequence_id=sequence_id,
state=sim.get_state(),
halted=sim.halted(),
)
try:
event = client.session.advance(

Просмотреть файл

@ -20,14 +20,14 @@ from all_models import available_models
## Add a local simulator in a `sim` folder to validate data-driven model
## Example: Quanser from a Microsoft Bonsai
'''
"""
ddm_test_validate.py
main.py
sim
quanser
sim
| qube_simulator.py
'''
"""
# TODO: from main import TemplateSimulatorSession, env_setup
dir_path = os.path.dirname(os.path.realpath(__file__))
@ -52,7 +52,9 @@ class Simulator(BaseModel):
self.config_keys = configs
self.state_keys = states
self.action_keys = actions
self.sim_orig = sim_orig() # include simulator function if comparing to simulator
self.sim_orig = (
sim_orig()
) # include simulator function if comparing to simulator
self.diff_state = diff_state
if log_file == "enable":
current_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
@ -205,10 +207,10 @@ def test_sim_model(
for episode in range(num_episodes):
iteration = 0
terminal = False
'''
"""
TODO: Add episode_start(config) so sim works properly and not initializing
with unrealistic initial conditions.
'''
"""
sim.episode_start()
ddm_state = sim.get_state()
sim_state = sim.get_sim_state()
@ -267,17 +269,16 @@ def main(cfg: DictConfig):
input_cols = input_cols + augmented_cols
ddModel = available_models[model_name]
model = ddModel()
#model.build_model(**cfg["model"]["build_params"])
# model.build_model(**cfg["model"]["build_params"])
if model_name.lower() == "pytorch":
model.load_model(
input_dim=len(input_cols),
output_dim=len(output_cols),
filename=save_path,
scale_data=scale_data
scale_data=scale_data,
)
else:
model.load_model(filename=save_path, scale_data=scale_data)

Просмотреть файл

@ -73,12 +73,20 @@ def main(cfg: DictConfig) -> None:
)
train_id_end = floor(X.shape[0] * (1 - test_perc))
X_train, y_train = (
X[:train_id_end,],
y[:train_id_end,],
X[
:train_id_end,
],
y[
:train_id_end,
],
)
X_test, y_test = (
X[train_id_end:,],
y[train_id_end:,],
X[
train_id_end:,
],
y[
train_id_end:,
],
)
# save training and test sets

Просмотреть файл

@ -120,7 +120,11 @@ class PyTorchModel(BaseModel):
self.model.fit(X, y, **fit_params)
def load_model(
self, input_dim: str, output_dim: str, filename: str, scale_data: bool = False,
self,
input_dim: str,
output_dim: str,
filename: str,
scale_data: bool = False,
):
self.scale_data = scale_data