import sys
import os
import time
import datetime
import collections
import subprocess
import shutil
import base64
import json
import uuid
import pickle
import numpy as np
import joblib
from .Utils import get_random_string, AtomicCounter
from .Checkpoints import Checkpoints
from .DataLoader import InMemoryRaster
import logging
LOGGER = logging.getLogger("server")
SESSION_BASE_PATH = './tmp/session'
SESSION_FOLDER = SESSION_BASE_PATH + "/" + datetime.datetime.now().strftime('%Y-%m-%d')

def manage_session_folders():
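    """Ensure the on-disk session directory layout exists.

    Creates SESSION_BASE_PATH if needed and, when today's SESSION_FOLDER is
    missing, clears out stale per-day session folders before recreating the
    path for the current date.
    """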
    if not os.path.exists(SESSION_BASE_PATH):
        os.makedirs(SESSION_BASE_PATH)
    if not os.path.exists(SESSION_FOLDER):
        # Today's folder doesn't exist yet; wipe any stale per-day folders and
        # recreate the full path (os.makedirs also recreates SESSION_BASE_PATH).
        shutil.rmtree(SESSION_BASE_PATH)
        os.makedirs(SESSION_FOLDER)


class Session():

    def __init__(self, session_id, model):
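        """Container for the per-session state used by the server.

        Args:
            session_id: Unique identifier for this session.
            model: A ModelSession-like object exposing run(), reset(), and
                save_state_to().
        """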
        LOGGER.info("Instantiating a new session object with id: %s" % (session_id))
        self.model = model
        self.data_loader = None
        self.latest_input_raster = None  # InMemoryRaster object from the most recent prediction
        self.tile_map = None  # A map recording the most recent prediction per pixel
        self.current_snapshot_string = get_random_string(8)
        self.current_snapshot_idx = 0
        self.current_request_counter = AtomicCounter()
        self.request_list = []
        self.session_id = session_id
        self.creation_time = time.time()
        self.last_interaction_time = self.creation_time

        # Optional request/model logging configuration; save() and add_entry()
        # are no-ops unless these are set (e.g. from the command line arguments
        # referenced in those methods).
        self.storage_type = None
        self.storage_path = None

    def reset(self):
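        """Discard the snapshot/request bookkeeping and reset the model.

        Returns whatever the underlying model's reset() returns.
        """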
        self.current_snapshot_string = get_random_string(8)
        self.current_snapshot_idx = 0
        self.current_request_counter = AtomicCounter()
        self.request_list = []
        return self.model.reset()

    def load(self, encoded_model_fn):
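        """Replace the current model with one loaded from a base64-encoded filename."""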
        model_fn = base64.b64decode(encoded_model_fn).decode('utf-8')
        del self.model
        self.model = joblib.load(model_fn)

    def save(self, model_name):
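        """Record a snapshot of this session's request list (when storage_type
        is 'file') under storage_path.

        Returns the base64-encoded model filename, or None when storage is not
        configured.
        """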
        if self.storage_type is not None:
            assert self.storage_path is not None  # we check for this when starting the program

            snapshot_id = "%s_%d" % (model_name, self.current_snapshot_idx)
            LOGGER.info("Saving state for %s" % (snapshot_id))

            base_dir = os.path.join(self.storage_path, self.current_snapshot_string)
            if not os.path.exists(base_dir):
                os.makedirs(base_dir, exist_ok=False)

            model_fn = os.path.join(base_dir, "%s_model.p" % (snapshot_id))
            #joblib.dump(self.model, model_fn, protocol=pickle.HIGHEST_PROTOCOL)

            if self.storage_type == "file":
                request_list_fn = os.path.join(base_dir, "%s_request_list.p" % (snapshot_id))
                joblib.dump(self.request_list, request_list_fn, protocol=pickle.HIGHEST_PROTOCOL)
            elif self.storage_type == "table":
                # We don't serialize the request list when saving to table storage
                pass

            self.current_snapshot_idx += 1

            return base64.b64encode(model_fn.encode('utf-8')).decode('utf-8')  # this is super dumb
        else:
            return None

    def create_checkpoint(self, dataset_name, model_name, checkpoint_name, classes):
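        """Save the current model state as a named checkpoint.

        Rejects checkpoint names that contain '-' or equal 'new', writes the
        class definitions to classes.json in the new checkpoint directory, and
        delegates model serialization to self.model.save_state_to().
        """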
if " - " in checkpoint_name :
return {
" message " : " Checkpoint name cannot contain ' - ' " ,
" success " : False
}
elif checkpoint_name == " new " :
return {
" message " : " Checkpoint name cannot be ' new ' " ,
" success " : False
}
try :
directory = Checkpoints . create_new_checkpoint_directory ( dataset_name , model_name , checkpoint_name )
except ValueError as e :
return {
" message " : e . args [ 0 ] ,
" success " : False
}
with open ( os . path . join ( directory , " classes.json " ) , " w " ) as f :
f . write ( json . dumps ( classes ) )
return self . model . save_state_to ( directory )

    def add_entry(self, data):
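        """Record a single client interaction.

        The original file/table storage logic is kept below, commented out;
        this method is currently a no-op.
        """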
        # data = data.copy()
        # data["time"] = datetime.datetime.now()
        # data["current_snapshot_index"] = self.current_snapshot_idx
        # current_request_counter = self.current_request_counter.increment()
        # data["current_request_index"] = current_request_counter

        # assert "experiment" in data

        # if self.storage_type == "file":
        #     self.request_list.append(data)

        # elif self.storage_type == "table":

        #     data["PartitionKey"] = self.current_snapshot_string
        #     data["RowKey"] = "%s_%d" % (data["experiment"], current_request_counter)

        #     for k in data.keys():
        #         if isinstance(data[k], dict) or isinstance(data[k], list):
        #             data[k] = json.dumps(data[k])

        #     try:
        #         self.table_service.insert_entity("webtoolinteractions", data)
        #     except Exception as e:
        #         LOGGER.error(e)
        # else:
        #     # The storage_type / --storage_path command line args were not set
        #     pass
        pass

    def pred_patch(self, input_raster):
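        """Run the model over a single patch.

        Args:
            input_raster: InMemoryRaster holding the patch data to predict on.

        Returns:
            An InMemoryRaster with the model output, carrying over the input's
            crs, transform, and bounds.
        """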
        output = self.model.run(input_raster.data, False)
        assert input_raster.shape[0] == output.shape[0] and input_raster.shape[1] == output.shape[1], \
            "ModelSession must return an np.ndarray with the same height and width as the input"
        return InMemoryRaster(output, input_raster.crs, input_raster.transform, input_raster.bounds)

    def pred_tile(self, input_raster):
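        """Run the model over a whole tile.

        Identical to pred_patch() except that the model's run() is called with
        its second argument set to True (the tile-level path).
        """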
        output = self.model.run(input_raster.data, True)
        assert input_raster.shape[0] == output.shape[0] and input_raster.shape[1] == output.shape[1], \
            "ModelSession must return an np.ndarray with the same height and width as the input"
        return InMemoryRaster(output, input_raster.crs, input_raster.transform, input_raster.bounds)

    def download_all(self):
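        """Placeholder for a download-all operation; not implemented."""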
        pass
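

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, kept as comments so it does not run
# on import). It assumes a hypothetical ModelSession-style wrapper exposing
# run(), reset(), and save_state_to(), plus an InMemoryRaster produced
# elsewhere by the DataLoader; neither object is defined in this module.
#
#   manage_session_folders()
#   model = SomeModelSession()                        # hypothetical model wrapper
#   session = Session(session_id=str(uuid.uuid4()), model=model)
#   raster = ...                                      # an InMemoryRaster from the DataLoader
#   prediction = session.pred_patch(raster)           # InMemoryRaster of model output
#   session.create_checkpoint("dataset", "model", "first_try", classes=[...])
#   session.reset()
# ---------------------------------------------------------------------------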