зеркало из https://github.com/microsoft/landcover.git
153 строки
4.5 KiB
Python
153 строки
4.5 KiB
Python
import os
|
|
|
|
import pickle
|
|
import joblib
|
|
import numpy as np
|
|
|
|
import sklearn.base
|
|
from sklearn.ensemble import RandomForestClassifier
|
|
|
|
import tensorflow as tf
|
|
import tensorflow.keras as keras
|
|
|
|
import logging
|
|
# Module-level logger; records are emitted through the logger named "server".
LOGGER = logging.getLogger("server")
|
|
|
|
from . import ROOT_DIR
|
|
from .ModelSessionAbstract import ModelSession
|
|
|
|
class ModelSessionRandomForest(ModelSession):
    """Model session backed by a random forest fitted on user-picked pixels.

    Until the forest has been fitted, ``run`` hands back the scaled input
    tile unchanged; once ``retrain`` succeeds, ``run`` returns the forest's
    per-pixel ``predict_proba`` output instead.

    Every public action method returns a dict with a human-readable
    ``"message"`` and a boolean ``"success"``.
    """

    # Unfitted prototype. Each session works on its own clone, so fitting one
    # session never touches this shared instance.
    AUGMENT_MODEL = RandomForestClassifier()

    def __init__(self, **kwargs):
        """Create a session with no samples, no tile and an unfitted model.

        Args:
            **kwargs: Accepted for signature compatibility; not used here.
        """
        self.augment_x_train = []
        self.augment_y_train = []
        self.augment_model = sklearn.base.clone(ModelSessionRandomForest.AUGMENT_MODEL)
        self.augment_model_trained = False

        self._last_tile = None

    @property
    def last_tile(self):
        """The scaled tile stored by the most recent non-inference ``run``."""
        return self._last_tile

    def run(self, tile, inference_mode=False):
        """Run the session's model over ``tile``.

        Args:
            tile: Array indexed as ``[row, col, channel]``. It is divided by
                256.0 before use (presumably 8-bit imagery -- confirm with
                the caller).
            inference_mode: When false (the default) the scaled tile is
                remembered so ``add_sample_point`` can sample from it.

        Returns:
            If the forest has been fitted, an array of shape
            ``(rows, cols, k)`` where the last axis is the ``predict_proba``
            output for each pixel. Otherwise a copy of the scaled tile.
        """
        tile = tile / 256.0

        if self.augment_model_trained:
            rows, cols = tile.shape[0], tile.shape[1]
            # The classifier wants one row per pixel, one column per channel.
            pixels = tile.reshape(-1, tile.shape[2])
            probabilities = self.augment_model.predict_proba(pixels)
            output = probabilities.reshape(rows, cols, -1)
        else:
            output = tile.copy()

        if not inference_mode:
            self._last_tile = tile

        return output

    def retrain(self, **kwargs):
        """Fit the forest on every sample collected so far.

        Args:
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            A status dict. ``"success"`` is false when there are no samples
            or when fitting/scoring raised; the exception text is included in
            ``"message"`` in the latter case.
        """
        x_train = np.array(self.augment_x_train)
        y_train = np.array(self.augment_y_train)

        if x_train.shape[0] == 0:
            return {
                "message": "Need to add training samples in order to train",
                "success": False
            }

        try:
            self.augment_model.fit(x_train, y_train)
            # Accuracy on the training samples themselves, not a held-out set.
            score = self.augment_model.score(x_train, y_train)
        except Exception as e:
            # This is the boundary that turns failures into a status dict, so
            # keep the broad catch but leave a traceback in the log.
            LOGGER.exception("Error in 'retrain()'")
            return {
                "message": "Error in 'retrain()': %s" % (e),
                "success": False
            }

        LOGGER.debug("Fine-tuning accuracy: %0.4f", score)
        self.augment_model_trained = True

        return {
            "message": "Fine-tuning accuracy on data: %0.2f" % (score),
            "success": True
        }

    def add_sample_point(self, row, col, class_idx):
        """Record the pixel at ``(row, col)`` of the last tile as a sample.

        Args:
            row: Row index into the last tile.
            col: Column index into the last tile.
            class_idx: Integer label to associate with that pixel.

        Returns:
            A status dict; ``"success"`` is false if no tile has been stored
            by ``run`` yet.
        """
        if self._last_tile is None:
            return {
                "message": "Must run model before adding a training sample",
                "success": False
            }

        # Copy so the stored sample does not keep a view into the whole tile.
        self.augment_x_train.append(self._last_tile[row, col, :].copy())
        self.augment_y_train.append(class_idx)
        return {
            "message": "Training sample for class %d added" % (class_idx),
            "success": True
        }

    def undo(self):
        """Drop the most recently added training sample, if any.

        The fitted model is left as it is; call ``retrain`` to refit.

        Returns:
            A status dict; ``"success"`` is false when there was nothing to
            remove.
        """
        if not self.augment_y_train:
            return {
                "message": "Nothing to undo",
                "success": False
            }

        self.augment_x_train.pop()
        self.augment_y_train.pop()
        return {
            "message": "Undid training sample",
            "success": True
        }

    def reset(self):
        """Forget the last tile and all samples, and swap in an unfitted model.

        Returns:
            A status dict with ``"success"`` set to true.
        """
        self._last_tile = None
        self.augment_x_train = []
        self.augment_y_train = []
        self.augment_model = sklearn.base.clone(ModelSessionRandomForest.AUGMENT_MODEL)
        self.augment_model_trained = False

        return {
            "message": "Model reset successfully",
            "success": True
        }

    def save_state_to(self, directory):
        """Write the samples, the model and the trained flag into ``directory``.

        Files written: ``augment_x_train.npy``, ``augment_y_train.npy``,
        ``augment_model.p`` and, only when the model has been fitted, an
        empty ``trained.txt`` marker.

        Args:
            directory: Existing directory to write into.

        Returns:
            A status dict with ``"success"`` set to true.
        """
        np.save(os.path.join(directory, "augment_x_train.npy"), np.array(self.augment_x_train))
        np.save(os.path.join(directory, "augment_y_train.npy"), np.array(self.augment_y_train))

        joblib.dump(self.augment_model, os.path.join(directory, "augment_model.p"))

        marker_path = os.path.join(directory, "trained.txt")
        if self.augment_model_trained:
            with open(marker_path, "w") as f:
                f.write("")
        else:
            # A marker left behind by an earlier save into this directory
            # would make load_state_from() treat the unfitted model as
            # trained, so make sure it is gone.
            try:
                os.remove(marker_path)
            except FileNotFoundError:
                pass

        return {
            "message": "Saved model state",
            "success": True
        }

    def load_state_from(self, directory):
        """Restore samples, model and trained flag saved by ``save_state_to``.

        Everything is read before any attribute is assigned, so a missing or
        unreadable file raises without leaving the session half-updated.

        Args:
            directory: Directory previously populated by ``save_state_to``.

        Returns:
            A status dict with ``"success"`` set to true.
        """
        x_samples = list(np.load(os.path.join(directory, "augment_x_train.npy")))
        y_samples = list(np.load(os.path.join(directory, "augment_y_train.npy")))

        # NOTE(security): joblib.load unpickles arbitrary objects. Only point
        # this at directories written by this application.
        model = joblib.load(os.path.join(directory, "augment_model.p"))
        trained = os.path.exists(os.path.join(directory, "trained.txt"))

        self.augment_x_train = x_samples
        self.augment_y_train = y_samples
        self.augment_model = model
        self.augment_model_trained = trained

        return {
            "message": "Loaded model state",
            "success": True
        }
|