Add save/load to containers (#399)
* Add save/load for PyTorch/TVM models; add related tests
* Add code and tests for saving torch.jit and ONNX models
* Add dill
* Address PR comments
Parent: bec5a9868a
Commit: b8ec670b3c
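For reference, here is a minimal round trip through the save/load API this commit introduces. It is a sketch distilled from the tests added below; the "pt-tmp.pkl" path and the GradientBoostingClassifier setup are just the placeholders the tests use.

    import numpy as np
    from sklearn.ensemble import GradientBoostingClassifier
    import hummingbird.ml

    # Train a small scikit-learn model to convert.
    X = np.random.rand(100, 200).astype(np.float32)
    y = np.random.randint(2, size=100)
    model = GradientBoostingClassifier(n_estimators=10, max_depth=10).fit(X, y)

    # Convert to a supported backend, then persist the container.
    hb_model = hummingbird.ml.convert(model, "torch")
    hb_model.save("pt-tmp.pkl")

    # Later, or in another process: load through the matching container class.
    loaded = hummingbird.ml.PyTorchSklearnContainer.load("pt-tmp.pkl")
    np.testing.assert_allclose(loaded.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)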
@@ -24,3 +24,8 @@ from .convert import convert, convert_batch  # noqa: F401, E402
 # Add the supported backends in scope.
 from .supported import backends  # noqa: F401, E402
+
+# Add load capabilities.
+from ._container import PyTorchSklearnContainer  # noqa: F401, E402
+from ._container import TVMSklearnContainer  # noqa: F401, E402
+from ._container import ONNXSklearnContainer  # noqa: F401, E402
@@ -12,6 +12,7 @@ In Hummingbird we use two types of containers:
 """

 from abc import ABC, abstractmethod
+import dill
 import os
 import numpy as np
 from onnxconverter_common.container import CommonSklearnModelContainer
@@ -69,6 +70,16 @@ class SklearnContainer(ABC):
     def model(self):
         return self._model

+    @abstractmethod
+    def save(self, location):
+        """
+        Method used to save the container for future use.
+
+        Args:
+            location: The location on the file system where to save the model.
+        """
+        return
+
     def _run(self, function, *inputs):
         """
         This function scores the full dataset at once. See BatchContainer below for batched scoring.
@@ -319,12 +330,70 @@ class SklearnContainerAnomalyDetection(SklearnContainerRegression):


 # PyTorch containers.
-class PyTorchSklearnContainer(ABC):
+class PyTorchSklearnContainer(SklearnContainer):
     """
     Base container for PyTorch models.
     We use this container to surface PyTorch-specific functionalities in the containers.
     """

+    def save(self, location):
+        assert self.model is not None, "Saving a None model is undefined."
+
+        if constants.TEST_INPUT in self._extra_config:
+            self._extra_config[constants.TEST_INPUT] = None
+
+        if "torch.jit" in str(type(self.model)):
+            # This is a TorchScript model.
+            assert not os.path.exists(location), "Directory {} already exists.".format(location)
+            os.makedirs(location)
+            self.model.save(os.path.join(location, constants.SAVE_LOAD_TORCH_JIT_PATH))
+            model = self.model
+            self._model = None
+            with open(os.path.join(location, constants.SAVE_LOAD_CONTAINER_PATH), "wb") as file:
+                dill.dump(self, file)
+            self._model = model
+        elif "PyTorchBackendModel" in str(type(self.model)):
+            # This is a plain PyTorch model.
+            if not location.endswith("pkl"):
+                location += ".pkl"
+            assert not os.path.exists(location), "File {} already exists.".format(location)
+            with open(location, "wb") as file:
+                dill.dump(self, file)
+        else:
+            raise RuntimeError("Model type {} not recognized.".format(type(self.model)))
+
+    @staticmethod
+    def load(location):
+        """
+        Method used to load a container from the file system.
+
+        Args:
+            location: The location on the file system from which to load the model.
+
+        Returns:
+            The loaded model.
+        """
+        assert os.path.exists(location), "Model location {} does not exist.".format(location)
+
+        container = None
+        if os.path.isdir(location):
+            # This is a torch.jit (TorchScript) model.
+            model = torch.jit.load(os.path.join(location, constants.SAVE_LOAD_TORCH_JIT_PATH))
+            with open(os.path.join(location, constants.SAVE_LOAD_CONTAINER_PATH), "rb") as file:
+                container = dill.load(file)
+            container._model = model
+        else:
+            # This is a plain PyTorch model.
+            with open(location, "rb") as file:
+                container = dill.load(file)
+
+        # Need to set the number of threads to use as set in the original container.
+        if container._n_threads is not None:
+            if torch.get_num_interop_threads() != 1:
+                torch.set_num_interop_threads(1)
+            torch.set_num_threads(container._n_threads)
+
+        return container
+
     def to(self, device):
         self.model.to(device)
         return self
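The save path above branches on the model type: a plain "torch" container is dill-pickled into a single .pkl file, while a "torch.jit" container becomes a directory holding the TorchScript model next to the pickled container, which is why load dispatches on os.path.isdir. A sketch of the two layouts, continuing the example above; the temp names are placeholders:

    # Plain torch backend: one pickle file; save appends ".pkl" when missing.
    hb_torch = hummingbird.ml.convert(model, "torch")
    hb_torch.save("pt-tmp")  # writes pt-tmp.pkl
    hummingbird.ml.PyTorchSklearnContainer.load("pt-tmp.pkl")

    # torch.jit backend: a directory with deploy_model.zip plus container.pkl.
    hb_jit = hummingbird.ml.convert(model, "torch.jit", X)  # tracing needs a test input
    hb_jit.save("ts-tmp")
    hummingbird.ml.PyTorchSklearnContainer.load("ts-tmp")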
@@ -353,7 +422,7 @@ class PyTorchSklearnContainerRegression(SklearnContainerRegression, PyTorchSklea
         return self.model.forward(*inputs)[0].cpu().numpy().ravel()


-class PyTorchSklearnContainerClassification(PyTorchSklearnContainerRegression, SklearnContainerClassification):
+class PyTorchSklearnContainerClassification(SklearnContainerClassification, PyTorchSklearnContainerRegression):
     """
     Container for PyTorch models mirroring Sklearn classifiers API.
     """
@@ -511,6 +580,58 @@ class ONNXSklearnContainer(SklearnContainer):
         else:
             raise RuntimeError("ONNX Container requires ONNX runtime installed.")

+    def save(self, location):
+        assert self.model is not None, "Saving a None model is undefined."
+        import onnx
+
+        if constants.TEST_INPUT in self._extra_config:
+            self._extra_config[constants.TEST_INPUT] = None
+
+        assert not os.path.exists(location), "Directory {} already exists.".format(location)
+        os.makedirs(location)
+        onnx.save(self.model, os.path.join(location, constants.SAVE_LOAD_ONNX_PATH))
+        model = self.model
+        session = self._session
+        self._model = None
+        self._session = None
+        with open(os.path.join(location, constants.SAVE_LOAD_CONTAINER_PATH), "wb") as file:
+            dill.dump(self, file)
+        self._model = model
+        self._session = session
+
+    @staticmethod
+    def load(location):
+        """
+        Method used to load a container from the file system.
+
+        Args:
+            location: The location on the file system from which to load the model.
+
+        Returns:
+            The loaded model.
+        """
+        assert os.path.exists(location), "Model location {} does not exist.".format(location)
+        assert onnx_runtime_installed()
+        import onnx
+        import onnxruntime as ort
+
+        container = None
+        model = onnx.load(os.path.join(location, constants.SAVE_LOAD_ONNX_PATH))
+        with open(os.path.join(location, constants.SAVE_LOAD_CONTAINER_PATH), "rb") as file:
+            container = dill.load(file)
+        container._model = model
+
+        sess_options = ort.SessionOptions()
+        if container._n_threads is not None:
+            # Need to set the number of threads to use as set in the original container.
+            sess_options.intra_op_num_threads = container._n_threads
+            sess_options.inter_op_num_threads = 1
+            sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+        container._session = ort.InferenceSession(container._model.SerializeToString(), sess_options=sess_options)
+
+        return container
+
     def _get_named_inputs(self, inputs):
         """
         Retrieve the inputs names from the session object.
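Note that the ONNX container cannot pickle its onnxruntime InferenceSession, so save temporarily drops both the model and the session before dumping the container, and load rebuilds the session with the thread settings recorded in the container. Usage mirrors the other backends (a sketch continuing the example above; "onnx-tmp" is a placeholder):

    hb_onnx = hummingbird.ml.convert(model, "onnx", X)
    hb_onnx.save("onnx-tmp")  # creates onnx-tmp/deploy_model.onnx and onnx-tmp/container.pkl

    loaded = hummingbird.ml.ONNXSklearnContainer.load("onnx-tmp")  # rebuilds the InferenceSession
    loaded.predict_proba(X)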
@@ -609,6 +730,90 @@ class TVMSklearnContainer(SklearnContainer):

         os.environ["TVM_NUM_THREADS"] = str(self._n_threads)

+    def save(self, location):
+        assert self.model is not None, "Saving a None model is undefined."
+        from tvm import relay
+
+        assert not os.path.exists(location), "Directory {} already exists.".format(location)
+        os.makedirs(location)
+        path_lib = os.path.join(location, constants.SAVE_LOAD_TVM_LIB_PATH)
+        self._extra_config[constants.TVM_LIB].export_library(path_lib)
+        with open(os.path.join(location, constants.SAVE_LOAD_TVM_GRAPH_PATH), "w") as fo:
+            fo.write(self._extra_config[constants.TVM_GRAPH])
+        with open(os.path.join(location, constants.SAVE_LOAD_TVM_PARAMS_PATH), "wb") as fo:
+            fo.write(relay.save_param_dict(self._extra_config[constants.TVM_PARAMS]))
+
+        # Remove all information that cannot be pickled.
+        if constants.TEST_INPUT in self._extra_config:
+            self._extra_config[constants.TEST_INPUT] = None
+        lib = self._extra_config[constants.TVM_LIB]
+        graph = self._extra_config[constants.TVM_GRAPH]
+        params = self._extra_config[constants.TVM_PARAMS]
+        ctx = self._extra_config[constants.TVM_CONTEXT]
+        model = self._model
+        self._extra_config[constants.TVM_LIB] = None
+        self._extra_config[constants.TVM_GRAPH] = None
+        self._extra_config[constants.TVM_PARAMS] = None
+        self._extra_config[constants.TVM_CONTEXT] = None
+        self._ctx = "cpu" if self._ctx.device_type == 1 else "cuda"
+        self._model = None
+
+        with open(os.path.join(location, constants.SAVE_LOAD_CONTAINER_PATH), "wb") as file:
+            dill.dump(self, file)
+
+        # Restore the information.
+        self._extra_config[constants.TVM_LIB] = lib
+        self._extra_config[constants.TVM_GRAPH] = graph
+        self._extra_config[constants.TVM_PARAMS] = params
+        self._extra_config[constants.TVM_CONTEXT] = ctx
+        self._ctx = ctx
+        self._model = model
+
+    @staticmethod
+    def load(location):
+        """
+        Method used to load a container from the file system.
+
+        Args:
+            location: The location on the file system from which to load the model.
+
+        Returns:
+            The loaded model.
+        """
+        assert tvm_installed()
+        import tvm
+        from tvm import relay
+        from tvm.contrib import graph_runtime
+
+        container = None
+        assert os.path.exists(location), "Directory {} not found.".format(location)
+        path_lib = os.path.join(location, constants.SAVE_LOAD_TVM_LIB_PATH)
+        with open(os.path.join(location, constants.SAVE_LOAD_TVM_GRAPH_PATH)) as file:
+            graph = file.read()
+        lib = tvm.runtime.module.load_module(path_lib)
+        with open(os.path.join(location, constants.SAVE_LOAD_TVM_PARAMS_PATH), "rb") as file:
+            params = relay.load_param_dict(file.read())
+        with open(os.path.join(location, constants.SAVE_LOAD_CONTAINER_PATH), "rb") as file:
+            container = dill.load(file)
+
+        assert container is not None, "Failed to load the model container."
+
+        ctx = tvm.cpu() if container._ctx == "cpu" else tvm.gpu()
+        container._model = graph_runtime.create(graph, lib, ctx)
+        container._model.set_input(**params)
+
+        container._extra_config[constants.TVM_GRAPH] = graph
+        container._extra_config[constants.TVM_LIB] = lib
+        container._extra_config[constants.TVM_PARAMS] = params
+        container._extra_config[constants.TVM_CONTEXT] = ctx
+        container._ctx = ctx
+
+        # Need to set the number of threads to use as set in the original container.
+        os.environ["TVM_NUM_THREADS"] = str(container._n_threads)
+
+        return container
+
     def _to_tvm_tensor(self, *inputs):
         tvm_tensors = {}
         msg = "The number of input rows {} is different from the batch size {} the TVM model is compiled for."
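The TVM container follows the same pattern but carries the most non-picklable state: the compiled lib, graph, params, and context are exported to their own files, blanked out before pickling, and reattached on load, where graph_runtime.create rebuilds the runtime module. A usage sketch continuing the example above ("tvm-tmp" is a placeholder; the TVM model is compiled for a fixed batch size, hence the test input at conversion time):

    hb_tvm = hummingbird.ml.convert(model, "tvm", X)
    hb_tvm.save("tvm-tmp")

    loaded = hummingbird.ml.TVMSklearnContainer.load("tvm-tmp")
    loaded.predict_proba(X)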
@@ -115,6 +115,10 @@ def _compile_tvm_model(topology, torch_model, trace_input, target, ctx, config,
     tvm_model = graph_runtime.create(graph, lib, ctx)
     tvm_model.set_input(**params)

+    extra_config[constants.TVM_GRAPH] = graph
+    extra_config[constants.TVM_LIB] = lib
+    extra_config[constants.TVM_PARAMS] = params
+
     return tvm_model
@@ -38,9 +38,36 @@ ONNX_INITIALIZERS = "onnx_initializers"
 TVM_CONTEXT = "tvm_context"
 """The context for TVM containing information on the target."""

+TVM_GRAPH = "tvm_graph"
+"""The graph defining the TVM model. This parameter is used for saving and loading a TVM model."""
+
+TVM_LIB = "tvm_lib"
+"""The lib for the TVM model. This parameter is used for saving and loading a TVM model."""
+
+TVM_PARAMS = "tvm_params"
+"""The params for the TVM model. This parameter is used for saving and loading a TVM model."""
+
 TVM_INPUT_NAMES = "tvm_input_names"
 """TVM expects named inputs. This is used to set the names for the inputs."""

+SAVE_LOAD_CONTAINER_PATH = "container.pkl"
+"""Path where to find the container when saving or loading."""
+
+SAVE_LOAD_TVM_LIB_PATH = "deploy_lib.tar"
+"""Path where to find the TVM lib when saving or loading."""
+
+SAVE_LOAD_TVM_GRAPH_PATH = "deploy_graph.json"
+"""Path where to find the TVM graph when saving or loading."""
+
+SAVE_LOAD_TVM_PARAMS_PATH = "deploy_param.params"
+"""Path where to find the TVM params when saving or loading."""
+
+SAVE_LOAD_TORCH_JIT_PATH = "deploy_model.zip"
+"""Path where to find the TorchScript model when saving or loading."""
+
+SAVE_LOAD_ONNX_PATH = "deploy_model.onnx"
+"""Path where to find the ONNX model when saving or loading."""
+
 TEST_INPUT = "test_input"
 """The test input data for models that need to be traced."""
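Taken together, the SAVE_LOAD_* constants fix the on-disk layout of a saved model. For example, a TVM container saved to "tvm-tmp" ends up as (a sketch inferred from the constants above):

    tvm-tmp/
        container.pkl        # SAVE_LOAD_CONTAINER_PATH: the dill-pickled container
        deploy_lib.tar       # SAVE_LOAD_TVM_LIB_PATH: the exported TVM lib
        deploy_graph.json    # SAVE_LOAD_TVM_GRAPH_PATH: the TVM graph
        deploy_param.params  # SAVE_LOAD_TVM_PARAMS_PATH: the serialized TVM params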
setup.py
@@ -27,6 +27,7 @@ install_requires = [
     "scikit-learn>=0.21.3,<=0.23.2",
     "torch>=1.4.*,<=1.7.0",
     "psutil",
+    "dill",
 ]
 onnx_requires = [
     "onnxruntime>=1.0.0",
@@ -3,6 +3,8 @@ Tests Hummingbird's backends.
 """
 import unittest
 import warnings
+import os
+import shutil
 import numpy as np

 from sklearn.ensemble import GradientBoostingClassifier
@@ -64,6 +66,50 @@ class TestBackends(unittest.TestCase):
         self.assertIsNotNone(hb_model)
         np.testing.assert_allclose(model.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)

+    # Test pytorch save and load
+    def test_pytorch_save_load(self):
+        warnings.filterwarnings("ignore")
+        max_depth = 10
+        num_classes = 2
+        model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
+        np.random.seed(0)
+        X = np.random.rand(100, 200)
+        X = np.array(X, dtype=np.float32)
+        y = np.random.randint(num_classes, size=100)
+
+        model.fit(X, y)
+
+        hb_model = hummingbird.ml.convert(model, "torch")
+        self.assertIsNotNone(hb_model)
+        hb_model.save("pt-tmp.pkl")
+
+        hb_model_loaded = hummingbird.ml.PyTorchSklearnContainer.load("pt-tmp.pkl")
+        np.testing.assert_allclose(hb_model_loaded.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)
+
+        os.remove("pt-tmp.pkl")
+
+    # Test torchscript save and load
+    def test_torchscript_save_load(self):
+        warnings.filterwarnings("ignore")
+        max_depth = 10
+        num_classes = 2
+        model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
+        np.random.seed(0)
+        X = np.random.rand(100, 200)
+        X = np.array(X, dtype=np.float32)
+        y = np.random.randint(num_classes, size=100)
+
+        model.fit(X, y)
+
+        hb_model = hummingbird.ml.convert(model, "torch.jit", X)
+        self.assertIsNotNone(hb_model)
+        hb_model.save("ts-tmp")
+
+        hb_model_loaded = hummingbird.ml.PyTorchSklearnContainer.load("ts-tmp")
+        np.testing.assert_allclose(hb_model_loaded.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)
+
+        shutil.rmtree("ts-tmp")
+
     # Test not supported backends
     def test_unsupported_backend(self):
         warnings.filterwarnings("ignore")
@@ -113,6 +159,29 @@ class TestBackends(unittest.TestCase):
         # Test tvm requires test_input
         self.assertRaises(RuntimeError, hummingbird.ml.convert, model, "tvm")

+    # Test TVM save and load
+    @unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
+    def test_tvm_save_load(self):
+        warnings.filterwarnings("ignore")
+        max_depth = 10
+        num_classes = 2
+        model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
+        np.random.seed(0)
+        X = np.random.rand(100, 200)
+        X = np.array(X, dtype=np.float32)
+        y = np.random.randint(num_classes, size=100)
+
+        model.fit(X, y)
+
+        hb_model = hummingbird.ml.convert(model, "tvm", X)
+        self.assertIsNotNone(hb_model)
+        hb_model.save("tvm-tmp")
+
+        hb_model_loaded = hummingbird.ml.TVMSklearnContainer.load("tvm-tmp")
+        np.testing.assert_allclose(hb_model_loaded.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)
+
+        shutil.rmtree("tvm-tmp")
+
     # Test onnx requires test_data or initial_types
     @unittest.skipIf(
         not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
@@ -243,6 +312,31 @@ class TestBackends(unittest.TestCase):
         # Test backends are not case sensitive
         self.assertRaises(RuntimeError, hummingbird.ml.convert, onnx_ml_model, "onnx")

+    # Test ONNX save and load
+    @unittest.skipIf(
+        not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test requires ONNX, ORT and ONNXMLTOOLS"
+    )
+    def test_onnx_save_load(self):
+        warnings.filterwarnings("ignore")
+        max_depth = 10
+        num_classes = 2
+        model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
+        np.random.seed(0)
+        X = np.random.rand(100, 200)
+        X = np.array(X, dtype=np.float32)
+        y = np.random.randint(num_classes, size=100)
+
+        model.fit(X, y)
+
+        hb_model = hummingbird.ml.convert(model, "onnx", X)
+        self.assertIsNotNone(hb_model)
+        hb_model.save("onnx-tmp")
+
+        hb_model_loaded = hummingbird.ml.ONNXSklearnContainer.load("onnx-tmp")
+        np.testing.assert_allclose(hb_model_loaded.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)
+
+        shutil.rmtree("onnx-tmp")
+
     # Test for when the user forgets to add a target (ex: convert(model, output) rather than convert(model, 'torch')) due to API change
     def test_forgotten_backend_string(self):
         from sklearn.preprocessing import LabelEncoder