# landcover/web_tool/ModelSessionRandomForest.py
#
# 153 lines, 4.5 KiB, Python
# (repository viewer: Source | Plain view | History)
#
# Last modified: 2020-10-16 06:22:37 +03:00
import os
import pickle
import joblib
import numpy as np
import sklearn.base
from sklearn.ensemble import RandomForestClassifier
import tensorflow as tf
import tensorflow.keras as keras
import logging
# Module logger. It uses the fixed name "server" rather than __name__,
# presumably so records go to the web server's shared logger — confirm.
LOGGER = logging.getLogger("server")
from . import ROOT_DIR
from .ModelSessionAbstract import ModelSession
class ModelSessionRandomForest(ModelSession):
    """Model session that trains a random forest on user-labelled pixels.

    Each call to :meth:`add_sample_point` copies the feature vector of one
    pixel of the most recently run tile into the training set, together with
    its class index. :meth:`retrain` fits ``augment_model`` on that set. Once
    the model has been fitted, :meth:`run` returns per-pixel class
    probabilities instead of the scaled input tile.

    All methods except :meth:`run` and :attr:`last_tile` return a dict with a
    human-readable ``"message"`` and a boolean ``"success"`` flag.
    """

    # Unfitted prototype estimator. Every session works on its own clone of
    # it, so sessions never share fitted state.
    AUGMENT_MODEL = RandomForestClassifier()

    def __init__(self, **kwargs):
        """Create a session with no samples and an unfitted model.

        Keyword arguments are accepted but ignored.
        """
        # Feature vectors (one per labelled pixel) and their class indices.
        # add_sample_point/undo/reset append and pop them together.
        self.augment_x_train = []
        self.augment_y_train = []
        self.augment_model = sklearn.base.clone(ModelSessionRandomForest.AUGMENT_MODEL)
        self.augment_model_trained = False
        # Scaled tile from the last non-inference run(); samples are taken from it.
        self._last_tile = None

    @property
    def last_tile(self):
        """Scaled tile from the most recent non-inference :meth:`run`, or None."""
        return self._last_tile

    def run(self, tile, inference_mode=False):
        """Scale ``tile`` and return the per-pixel output.

        Args:
            tile: array of shape ``(rows, cols, channels)``. It is divided by
                256.0 before use.
            inference_mode: when False, the scaled tile is remembered as
                :attr:`last_tile` so :meth:`add_sample_point` can sample it.

        Returns:
            If the forest has been fitted, the ``predict_proba`` output
            reshaped to ``(rows, cols, n_columns)``. Otherwise a copy of the
            scaled tile.
        """
        tile = tile / 256.0
        if self.augment_model_trained:
            original_shape = tile.shape
            # Classify every pixel independently: flatten to (pixels, channels).
            pixels = tile.reshape(-1, tile.shape[2])
            probabilities = self.augment_model.predict_proba(pixels)
            # NOTE(review): the columns presumably follow augment_model.classes_
            # (only classes that have samples), not raw class indices. Confirm
            # that callers expect this.
            output = probabilities.reshape(original_shape[0], original_shape[1], -1)
        else:
            output = tile.copy()
        if not inference_mode:
            self._last_tile = tile
        return output

    def retrain(self, **kwargs):
        """Fit the forest on every collected sample.

        The fit runs on a fresh clone of the current estimator, which is
        swapped in only if fitting and scoring both succeed. A failed fit
        therefore never leaves a previously trained model half-fitted while
        ``augment_model_trained`` is still True. Keyword arguments are ignored.
        """
        x_train = np.array(self.augment_x_train)
        y_train = np.array(self.augment_y_train)

        if x_train.shape[0] == 0:
            return {
                "message": "Need to add training samples in order to train",
                "success": False
            }

        try:
            candidate = sklearn.base.clone(self.augment_model)
            candidate.fit(x_train, y_train)
            # Accuracy on the training samples themselves, not held-out data.
            score = candidate.score(x_train, y_train)
        except Exception as e:
            # Boundary handler: report failure to the caller, but keep the
            # traceback in the log instead of discarding it.
            LOGGER.exception("Error in 'retrain()'")
            return {
                "message": "Error in 'retrain()': %s" % (e),
                "success": False
            }

        self.augment_model = candidate
        self.augment_model_trained = True
        LOGGER.debug("Fine-tuning accuracy: %0.4f", score)
        return {
            "message": "Fine-tuning accuracy on data: %0.2f" % (score),
            "success": True
        }

    def add_sample_point(self, row, col, class_idx):
        """Record pixel ``(row, col)`` of the last tile as a sample of ``class_idx``.

        Coordinates outside the last tile are rejected. Without this check,
        negative indices would silently wrap around to the opposite edge of
        the tile and record the wrong pixel.
        """
        if self._last_tile is None:
            return {
                "message": "Must run model before adding a training sample",
                "success": False
            }

        height, width = self._last_tile.shape[0], self._last_tile.shape[1]
        if not (0 <= row < height and 0 <= col < width):
            return {
                "message": "Training sample at (%s, %s) is outside the last tile" % (row, col),
                "success": False
            }

        # Indexing returns a view into the tile; copy() detaches the sample.
        self.augment_x_train.append(self._last_tile[row, col, :].copy())
        self.augment_y_train.append(class_idx)
        return {
            "message": "Training sample for class %d added" % (class_idx),
            "success": True
        }

    def undo(self):
        """Remove the most recently added training sample, if there is one."""
        if not self.augment_y_train:
            return {
                "message": "Nothing to undo",
                "success": False
            }
        self.augment_x_train.pop()
        self.augment_y_train.pop()
        return {
            "message": "Undid training sample",
            "success": True
        }

    def reset(self):
        """Drop all samples, the remembered tile and the fitted model."""
        self._last_tile = None
        self.augment_x_train = []
        self.augment_y_train = []
        self.augment_model = sklearn.base.clone(ModelSessionRandomForest.AUGMENT_MODEL)
        self.augment_model_trained = False
        return {
            "message": "Model reset successfully",
            "success": True
        }

    def save_state_to(self, directory):
        """Write samples, model and trained flag into ``directory``.

        Files written: ``augment_x_train.npy``, ``augment_y_train.npy`` and
        ``augment_model.p``. When the model is trained, an empty
        ``trained.txt`` marker is also written. Otherwise any stale marker is
        removed. Nothing here creates ``directory``; it must already exist.
        """
        np.save(os.path.join(directory, "augment_x_train.npy"), np.array(self.augment_x_train))
        np.save(os.path.join(directory, "augment_y_train.npy"), np.array(self.augment_y_train))
        joblib.dump(self.augment_model, os.path.join(directory, "augment_model.p"))

        trained_marker = os.path.join(directory, "trained.txt")
        if self.augment_model_trained:
            with open(trained_marker, "w") as f:
                f.write("")
        else:
            # A marker left by an earlier save into the same directory would
            # make load_state_from() treat this unfitted model as trained.
            try:
                os.remove(trained_marker)
            except FileNotFoundError:
                pass
        return {
            "message": "Saved model state",
            "success": True
        }

    def load_state_from(self, directory):
        """Restore state previously written by :meth:`save_state_to`.

        Every file is read before any attribute is assigned. A failed load
        therefore leaves the session's current state untouched, instead of
        leaving mismatched sample lists behind.
        """
        x_samples = np.load(os.path.join(directory, "augment_x_train.npy"))
        y_samples = np.load(os.path.join(directory, "augment_y_train.npy"))
        # SECURITY: joblib.load unpickles arbitrary objects. Only load state
        # directories this application wrote itself, never untrusted input.
        model = joblib.load(os.path.join(directory, "augment_model.p"))
        trained = os.path.exists(os.path.join(directory, "trained.txt"))

        self.augment_x_train = list(x_samples)
        self.augment_y_train = list(y_samples)
        self.augment_model = model
        self.augment_model_trained = trained
        return {
            "message": "Loaded model state",
            "success": True
        }