Minimal working for Journey using gboost and sim_predictor.py
This commit is contained in:
Родитель
73cfd9915d
Коммит
306afa0f34
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -7,6 +7,7 @@ dependencies:
|
|||
- pip=19.1.1
|
||||
- pytorch=1.7.0
|
||||
- torchvision=0.8
|
||||
- cryptography=3.1.1
|
||||
- pip:
|
||||
- ray==1.0.0
|
||||
- black==19.10b0
|
||||
|
|
|
@ -137,4 +137,4 @@ if __name__ == "__main__":
|
|||
xgm.fit(X, y, fit_separate=False)
|
||||
yhat = xgm.predict(X)
|
||||
|
||||
# xgm.save_model(dir_path="models/xgbm_pole_multi.pkl")
|
||||
xgm.save_model(dir_path="models/xgbm_pole_multi.pkl")
|
||||
|
|
|
@ -26,9 +26,11 @@ formater = logging.Formatter("%(name)-13s: %(levelname)-8s %(message)s")
|
|||
console.setFormatter(formater)
|
||||
logging.getLogger(__name__).addHandler(console)
|
||||
|
||||
save_path = os.path.join("models", "gbm_pole")
|
||||
#save_path = os.path.join("models", "gbm_pole")
|
||||
save_path = './models/xgbm_pole_multi.pkl'
|
||||
ddm_model = GBoostModel()
|
||||
ddm_model.load_model(dir_path=save_path, model_type="lightgbm")
|
||||
#ddm_model.load_model(dir_path=save_path, model_type="lightgbm")
|
||||
ddm_model.load_model(dir_path=save_path)
|
||||
|
||||
feature_cols = [
|
||||
"x_position",
|
||||
|
@ -82,7 +84,7 @@ class Simulator(BaseModel):
|
|||
self.last_position.update(action)
|
||||
X = np.array(list(self.last_position.values())).reshape(1, -1)
|
||||
preds = self.ddm.predict(X)
|
||||
self.state = dict(zip(self.label_cols, preds))
|
||||
self.state = dict(zip(self.label_cols, preds.reshape(preds.shape[1]).tolist()))
|
||||
return self.state
|
||||
|
||||
def get_state(self):
|
||||
|
@ -154,7 +156,7 @@ def test_random_policy(
|
|||
return sim
|
||||
|
||||
|
||||
def main(config_setup: bool = False, env_name: str = "DDM-Repsol"):
|
||||
def main(config_setup: bool = False, env_name: str = "ddm-sim-generic"):
|
||||
"""Main entrypoint for running simulator connections
|
||||
|
||||
Parameters
|
||||
|
@ -274,7 +276,7 @@ def main(config_setup: bool = False, env_name: str = "DDM-Repsol"):
|
|||
print("Episode Finishing...")
|
||||
iteration = 0
|
||||
elif event.type == "Unregister":
|
||||
print("Simulator Session unregistered by platform, Registering again!")
|
||||
print("Simulator Session unregistered by platform because '{}', Registering again!".format(event.unregister.details))
|
||||
registered_session, sequence_id = CreateSession(
|
||||
registration_info, config_client
|
||||
)
|
||||
|
|
|
@ -176,7 +176,7 @@ def main():
|
|||
|
||||
# Create simulator session and init sequence id
|
||||
registration_info = SimulatorInterface(
|
||||
name="PetroSimDDM",
|
||||
name="YourSimDDM",
|
||||
timeout=60,
|
||||
simulator_context=config_client.simulator_context,
|
||||
)
|
||||
|
@ -264,7 +264,7 @@ def main():
|
|||
elif event.type == "EpisodeFinish":
|
||||
print("Episode Finishing...")
|
||||
elif event.type == "Unregister":
|
||||
print("Simulator Session unregistered by platform, Registering again!")
|
||||
print("Simulator Session unregistered by platform because '{}', Registering again!".format(event.unregister.details))
|
||||
registered_session, sequence_id = CreateSession(
|
||||
registration_info, config_client
|
||||
)
|
||||
|
|
Загрузка…
Ссылка в новой задаче