UPDATE: tests + updates to modeler scripts for separate algos

This commit is contained in:
Ali Zaidi 2021-01-11 16:20:36 -08:00
Родитель 2485284f05
Коммит 692a0851c6
6 изменённых файлов: 32 добавления и 17 удалений

Просмотреть файл

@@ -79,7 +79,10 @@ class BaseModel(abc.ABC):
max_rows=max_rows,
)
features = csv_reader.feature_cols + augm_cols
output_cols = [col for col in df if col.startswith(output_col)]
if type(output_col) == str:
output_cols = [col for col in df if col.startswith(output_col)]
else:
output_cols = output_col
X = df[features].values
y = df[output_cols].values
@@ -117,7 +120,6 @@ class BaseModel(abc.ABC):
def build_model(self, scale_data: bool = False):
self.scale_data = scale_data
raise NotImplementedError
def fit(self, X, y):

Просмотреть файл

@@ -84,18 +84,18 @@ class GBoostModel(BaseModel):
return preds
def save_model(self, dir_path):
def save_model(self, filename):
if self.separate_models:
if not pathlib.Path(dir_path).exists():
pathlib.Path(dir_path).mkdir(parents=True, exist_ok=True)
if not pathlib.Path(filename).exists():
pathlib.Path(filename).mkdir(parents=True, exist_ok=True)
# pickle.dump(self.models, open(filename, "wb"))
for i in range(len(self.models)):
pickle.dump(
self.models[i], open(os.path.join(dir_path, f"model{i}.pkl"), "wb")
self.models[i], open(os.path.join(filename, f"model{i}.pkl"), "wb")
)
else:
pickle.dump(self.model, open(dir_path, "wb"))
pickle.dump(self.model, open(filename, "wb"))
def load_model(
self, dir_path: str, scale_data: bool = False, separate_models: bool = False
@@ -145,4 +145,4 @@ if __name__ == "__main__":
xgm.fit(X, y, fit_separate=False)
yhat = xgm.predict(X)
xgm.save_model(dir_path="models/xgbm_pole_multi.pkl")
xgm.save_model(filename="models/xgbm_pole_multi.pkl")

Просмотреть файл

@@ -3,11 +3,8 @@ import pandas as pd
from typing import List, Tuple, Union
import logging
FORMAT = "%(message)s"
logging.basicConfig(level="INFO", format=FORMAT, datefmt="[%X]")
logger = logging.getLogger("data_loader")
data_dir = "csv_data"
logger.info(f"Using data saved in directory {data_dir}")
logger = logging.getLogger("data_loaders")
logger.setLevel(logging.INFO)
class CsvReader(object):
@@ -122,6 +119,9 @@ class CsvReader(object):
if __name__ == "__main__":
data_dir = "csv_data"
logger.info(f"Using data saved in directory {data_dir}")
csv_reader = CsvReader()
df = csv_reader.read(
os.path.join(data_dir, "cartpole-log.csv"), iteration_order=-1, max_rows=1000

Просмотреть файл

@@ -26,7 +26,7 @@ def test_lgm_train():
lgbm.build_model(model_type="lightgbm")
lgbm.fit(X, y)
lgbm.save_model(dir_path="tmp/gbm_pole.pkl")
lgbm.save_model(filename="tmp/gbm_pole.pkl")
lgbm2 = GBoostModel()
lgbm2.load_model(dir_path="tmp/gbm_pole.pkl")
@@ -40,7 +40,7 @@ def test_xgb_train():
xgboost_model.build_model(model_type="xgboost")
xgboost_model.fit(X, y)
xgboost_model.save_model(dir_path="tmp/gbm_pole.pkl")
xgboost_model.save_model(filename="tmp/gbm_pole.pkl")
xgm2 = GBoostModel()
xgm2.load_model(dir_path="tmp/gbm_pole.pkl")

Просмотреть файл

@@ -1,8 +1,10 @@
import os
import pytest
from loaders import CsvReader, data_dir
from loaders import CsvReader
from base import BaseModel
data_dir = "csv_data"
@pytest.fixture
def csv_reader():
@@ -39,7 +41,7 @@ def test_cartpole_at_st(csv_reader):
def test_base_reader():
base_model = BaseModel()
X,y = base_model.load_csv(
X, y = base_model.load_csv(
dataset_path=os.path.join(data_dir, "cartpole-log.csv"),
max_rows=1000,
augm_cols=["action_command", "config_length", "config_masspole"],
@@ -49,3 +51,12 @@ def test_base_reader():
assert X.shape[1] == 7
assert y.shape[1] == 4
def test_diff_names():
base_model = BaseModel()
X, y = base_model.load_csv(
dataset_path=os.path.join(data_dir, "off_names.csv"), max_rows=1000
)
assert X.shape[0] == 980 == y.shape[0]

Просмотреть файл

@@ -4,6 +4,8 @@ import os
torch_model = PyTorchModel()
X, y = torch_model.load_csv(
input_cols_read="state",
output_col="state",
dataset_path="csv_data/cartpole-log.csv",
max_rows=1000,
augm_cols=["action_command", "config_length", "config_masspole"],