UPDATE: gboost models when separate regressors
This commit is contained in:
Родитель
9e191a1e46
Коммит
6cd5395e42
|
@ -63,12 +63,16 @@ class GBoostModel(BaseModel):
|
|||
else:
|
||||
self.model.partial_fit(X, y)
|
||||
|
||||
def predict(self, X):
|
||||
def predict(self, X: np.ndarray):
|
||||
|
||||
if len(X.shape) == 1:
|
||||
X = X.reshape(1, -1)
|
||||
|
||||
if self.scale_data:
|
||||
X = self.xscalar.transform(X)
|
||||
|
||||
if self.separate_models:
|
||||
pred = []
|
||||
if self.scale_data:
|
||||
X = self.xscalar.transform(X)
|
||||
for i in range(len(self.models)):
|
||||
logger.debug(f"Predicting model {i} of {len(self.models)}")
|
||||
pred.append(self.models[i].predict(X))
|
||||
|
@ -86,6 +90,18 @@ class GBoostModel(BaseModel):
|
|||
|
||||
def save_model(self, filename):
|
||||
|
||||
if self.scale_data:
|
||||
if not self.separate_models:
|
||||
path_name = str(pathlib.Path(filename).parent)
|
||||
else:
|
||||
path_name = filename
|
||||
pickle.dump(
|
||||
self.xscalar, open(os.path.join(path_name, "xscalar.pkl"), "wb")
|
||||
)
|
||||
pickle.dump(
|
||||
self.yscalar, open(os.path.join(path_name, "yscalar.pkl"), "wb")
|
||||
)
|
||||
|
||||
if self.separate_models:
|
||||
if not pathlib.Path(filename).exists():
|
||||
pathlib.Path(filename).mkdir(parents=True, exist_ok=True)
|
||||
|
@ -98,24 +114,37 @@ class GBoostModel(BaseModel):
|
|||
pickle.dump(self.model, open(filename, "wb"))
|
||||
|
||||
def load_model(
|
||||
self, dir_path: str, scale_data: bool = False, separate_models: bool = False
|
||||
self, filename: str, scale_data: bool = False, separate_models: bool = False
|
||||
):
|
||||
|
||||
self.separate_models = separate_models
|
||||
if self.separate_models:
|
||||
all_models = os.listdir(dir_path)
|
||||
all_models = os.listdir(filename)
|
||||
all_models.sort()
|
||||
if self.scale_data:
|
||||
all_models = all_models[:-2]
|
||||
num_models = len(all_models)
|
||||
models = []
|
||||
for i in range(num_models):
|
||||
models.append(
|
||||
pickle.load(open(os.path.join(dir_path, all_models[i]), "rb"))
|
||||
pickle.load(open(os.path.join(filename, all_models[i]), "rb"))
|
||||
)
|
||||
self.models = models
|
||||
else:
|
||||
self.model = pickle.load(open(dir_path, "rb"))
|
||||
self.model = pickle.load(open(filename, "rb"))
|
||||
|
||||
self.scale_data = scale_data
|
||||
if scale_data:
|
||||
if not separate_models:
|
||||
path_name = str(pathlib.Path(filename).parent)
|
||||
else:
|
||||
path_name = filename
|
||||
self.xscalar = pickle.load(
|
||||
open(os.path.join(path_name, "xscalar.pkl"), "rb")
|
||||
)
|
||||
self.yscalar = pickle.load(
|
||||
open(os.path.join(path_name, "yscalar.pkl"), "rb")
|
||||
)
|
||||
|
||||
def sweep(self, params: Dict, X, y):
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче