UPDATE: add tests and update modeler scripts to support separate algorithms
This commit is contained in:
Родитель
2485284f05
Коммит
692a0851c6
6
base.py
6
base.py
|
@ -79,7 +79,10 @@ class BaseModel(abc.ABC):
|
|||
max_rows=max_rows,
|
||||
)
|
||||
features = csv_reader.feature_cols + augm_cols
|
||||
output_cols = [col for col in df if col.startswith(output_col)]
|
||||
if type(output_col) == str:
|
||||
output_cols = [col for col in df if col.startswith(output_col)]
|
||||
else:
|
||||
output_cols = output_col
|
||||
X = df[features].values
|
||||
y = df[output_cols].values
|
||||
|
||||
|
@ -117,7 +120,6 @@ class BaseModel(abc.ABC):
|
|||
def build_model(self, scale_data: bool = False):
|
||||
|
||||
self.scale_data = scale_data
|
||||
raise NotImplementedError
|
||||
|
||||
def fit(self, X, y):
|
||||
|
||||
|
|
|
@ -84,18 +84,18 @@ class GBoostModel(BaseModel):
|
|||
|
||||
return preds
|
||||
|
||||
def save_model(self, dir_path):
|
||||
def save_model(self, filename):
|
||||
|
||||
if self.separate_models:
|
||||
if not pathlib.Path(dir_path).exists():
|
||||
pathlib.Path(dir_path).mkdir(parents=True, exist_ok=True)
|
||||
if not pathlib.Path(filename).exists():
|
||||
pathlib.Path(filename).mkdir(parents=True, exist_ok=True)
|
||||
# pickle.dump(self.models, open(filename, "wb"))
|
||||
for i in range(len(self.models)):
|
||||
pickle.dump(
|
||||
self.models[i], open(os.path.join(dir_path, f"model{i}.pkl"), "wb")
|
||||
self.models[i], open(os.path.join(filename, f"model{i}.pkl"), "wb")
|
||||
)
|
||||
else:
|
||||
pickle.dump(self.model, open(dir_path, "wb"))
|
||||
pickle.dump(self.model, open(filename, "wb"))
|
||||
|
||||
def load_model(
|
||||
self, dir_path: str, scale_data: bool = False, separate_models: bool = False
|
||||
|
@ -145,4 +145,4 @@ if __name__ == "__main__":
|
|||
xgm.fit(X, y, fit_separate=False)
|
||||
yhat = xgm.predict(X)
|
||||
|
||||
xgm.save_model(dir_path="models/xgbm_pole_multi.pkl")
|
||||
xgm.save_model(filename="models/xgbm_pole_multi.pkl")
|
||||
|
|
10
loaders.py
10
loaders.py
|
@ -3,11 +3,8 @@ import pandas as pd
|
|||
from typing import List, Tuple, Union
|
||||
import logging
|
||||
|
||||
FORMAT = "%(message)s"
|
||||
logging.basicConfig(level="INFO", format=FORMAT, datefmt="[%X]")
|
||||
logger = logging.getLogger("data_loader")
|
||||
data_dir = "csv_data"
|
||||
logger.info(f"Using data saved in directory {data_dir}")
|
||||
logger = logging.getLogger("data_loaders")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class CsvReader(object):
|
||||
|
@ -122,6 +119,9 @@ class CsvReader(object):
|
|||
|
||||
if __name__ == "__main__":
|
||||
|
||||
data_dir = "csv_data"
|
||||
logger.info(f"Using data saved in directory {data_dir}")
|
||||
|
||||
csv_reader = CsvReader()
|
||||
df = csv_reader.read(
|
||||
os.path.join(data_dir, "cartpole-log.csv"), iteration_order=-1, max_rows=1000
|
||||
|
|
|
@ -26,7 +26,7 @@ def test_lgm_train():
|
|||
|
||||
lgbm.build_model(model_type="lightgbm")
|
||||
lgbm.fit(X, y)
|
||||
lgbm.save_model(dir_path="tmp/gbm_pole.pkl")
|
||||
lgbm.save_model(filename="tmp/gbm_pole.pkl")
|
||||
|
||||
lgbm2 = GBoostModel()
|
||||
lgbm2.load_model(dir_path="tmp/gbm_pole.pkl")
|
||||
|
@ -40,7 +40,7 @@ def test_xgb_train():
|
|||
|
||||
xgboost_model.build_model(model_type="xgboost")
|
||||
xgboost_model.fit(X, y)
|
||||
xgboost_model.save_model(dir_path="tmp/gbm_pole.pkl")
|
||||
xgboost_model.save_model(filename="tmp/gbm_pole.pkl")
|
||||
|
||||
xgm2 = GBoostModel()
|
||||
xgm2.load_model(dir_path="tmp/gbm_pole.pkl")
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import os
|
||||
import pytest
|
||||
from loaders import CsvReader, data_dir
|
||||
from loaders import CsvReader
|
||||
from base import BaseModel
|
||||
|
||||
data_dir = "csv_data"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def csv_reader():
|
||||
|
@ -39,7 +41,7 @@ def test_cartpole_at_st(csv_reader):
|
|||
def test_base_reader():
|
||||
|
||||
base_model = BaseModel()
|
||||
X,y = base_model.load_csv(
|
||||
X, y = base_model.load_csv(
|
||||
dataset_path=os.path.join(data_dir, "cartpole-log.csv"),
|
||||
max_rows=1000,
|
||||
augm_cols=["action_command", "config_length", "config_masspole"],
|
||||
|
@ -49,3 +51,12 @@ def test_base_reader():
|
|||
assert X.shape[1] == 7
|
||||
assert y.shape[1] == 4
|
||||
|
||||
|
||||
def test_diff_names():
|
||||
|
||||
base_model = BaseModel()
|
||||
X, y = base_model.load_csv(
|
||||
dataset_path=os.path.join(data_dir, "off_names.csv"), max_rows=1000
|
||||
)
|
||||
|
||||
assert X.shape[0] == 980 == y.shape[0]
|
||||
|
|
|
@ -4,6 +4,8 @@ import os
|
|||
|
||||
torch_model = PyTorchModel()
|
||||
X, y = torch_model.load_csv(
|
||||
input_cols_read="state",
|
||||
output_col="state",
|
||||
dataset_path="csv_data/cartpole-log.csv",
|
||||
max_rows=1000,
|
||||
augm_cols=["action_command", "config_length", "config_masspole"],
|
||||
|
|
Загрузка…
Ссылка в новой задаче