Minimal working for Journey using gboost and sim_predictor.py

This commit is contained in:
Journey McDowell 2021-01-03 18:02:36 -08:00
Родитель 73cfd9915d
Коммит 306afa0f34
5 изменённых файлов: 11 добавлений и 50009 удалений

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -7,6 +7,7 @@ dependencies:
- pip=19.1.1
- pytorch=1.7.0
- torchvision=0.8
- cryptography=3.1.1
- pip:
- ray==1.0.0
- black==19.10b0

Просмотреть файл

@ -137,4 +137,4 @@ if __name__ == "__main__":
xgm.fit(X, y, fit_separate=False)
yhat = xgm.predict(X)
# xgm.save_model(dir_path="models/xgbm_pole_multi.pkl")
xgm.save_model(dir_path="models/xgbm_pole_multi.pkl")

Просмотреть файл

@ -26,9 +26,11 @@ formater = logging.Formatter("%(name)-13s: %(levelname)-8s %(message)s")
console.setFormatter(formater)
logging.getLogger(__name__).addHandler(console)
save_path = os.path.join("models", "gbm_pole")
#save_path = os.path.join("models", "gbm_pole")
save_path = './models/xgbm_pole_multi.pkl'
ddm_model = GBoostModel()
ddm_model.load_model(dir_path=save_path, model_type="lightgbm")
#ddm_model.load_model(dir_path=save_path, model_type="lightgbm")
ddm_model.load_model(dir_path=save_path)
feature_cols = [
"x_position",
@ -82,7 +84,7 @@ class Simulator(BaseModel):
self.last_position.update(action)
X = np.array(list(self.last_position.values())).reshape(1, -1)
preds = self.ddm.predict(X)
self.state = dict(zip(self.label_cols, preds))
self.state = dict(zip(self.label_cols, preds.reshape(preds.shape[1]).tolist()))
return self.state
def get_state(self):
@ -154,7 +156,7 @@ def test_random_policy(
return sim
def main(config_setup: bool = False, env_name: str = "DDM-Repsol"):
def main(config_setup: bool = False, env_name: str = "ddm-sim-generic"):
"""Main entrypoint for running simulator connections
Parameters
@ -274,7 +276,7 @@ def main(config_setup: bool = False, env_name: str = "DDM-Repsol"):
print("Episode Finishing...")
iteration = 0
elif event.type == "Unregister":
print("Simulator Session unregistered by platform, Registering again!")
print("Simulator Session unregistered by platform because '{}', Registering again!".format(event.unregister.details))
registered_session, sequence_id = CreateSession(
registration_info, config_client
)

Просмотреть файл

@ -176,7 +176,7 @@ def main():
# Create simulator session and init sequence id
registration_info = SimulatorInterface(
name="PetroSimDDM",
name="YourSimDDM",
timeout=60,
simulator_context=config_client.simulator_context,
)
@ -264,7 +264,7 @@ def main():
elif event.type == "EpisodeFinish":
print("Episode Finishing...")
elif event.type == "Unregister":
print("Simulator Session unregistered by platform, Registering again!")
print("Simulator Session unregistered by platform because '{}', Registering again!".format(event.unregister.details))
registered_session, sequence_id = CreateSession(
registration_info, config_client
)