* Add save\load for pytorch\tvm models
Add related tests

* add code and test for saving torch.jit and onnx models

* add dill

* Addressing PR comments
This commit is contained in:
Matteo Interlandi 2020-12-23 15:50:35 -08:00 коммит произвёл GitHub
Родитель bec5a9868a
Коммит b8ec670b3c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 339 добавлений и 3 удалений

Просмотреть файл

@ -24,3 +24,8 @@ from .convert import convert, convert_batch # noqa: F401, E402
# Add the supported backends in scope.
from .supported import backends # noqa: F401, E402
# Add load capabilities
from ._container import PyTorchSklearnContainer # noqa: F401, E402
from ._container import TVMSklearnContainer # noqa: F401, E402
from ._container import ONNXSklearnContainer # noqa: F401, E402

Просмотреть файл

@ -12,6 +12,7 @@ In Hummingbird we use two types of containers:
"""
from abc import ABC, abstractmethod
import dill
import os
import numpy as np
from onnxconverter_common.container import CommonSklearnModelContainer
@ -69,6 +70,16 @@ class SklearnContainer(ABC):
def model(self):
return self._model
@abstractmethod
def save(self, location):
"""
Method used to save the container for future use.
Args:
location: The location on the file system where to save the model.
"""
return
def _run(self, function, *inputs):
"""
This function scores the full dataset at once. See BatchContainer below for batched scoring.
@ -319,12 +330,70 @@ class SklearnContainerAnomalyDetection(SklearnContainerRegression):
# PyTorch containers.
class PyTorchSklearnContainer(ABC):
class PyTorchSklearnContainer(SklearnContainer):
"""
Base container for PyTorch models.
We used this container to surface PyTorch-specific functionalities in the containers.
"""
def save(self, location):
assert self.model is not None, "Saving a None model is undefined."
if constants.TEST_INPUT in self._extra_config:
self._extra_config[constants.TEST_INPUT] = None
if "torch.jit" in str(type(self.model)):
# This is a torchscript model.
assert not os.path.exists(location), "Directory {} already exists.".format(location)
os.makedirs(location)
self.model.save(os.path.join(location, constants.SAVE_LOAD_TORCH_JIT_PATH))
model = self.model
self._model = None
with open(os.path.join(location, "container.pkl"), "wb") as file:
dill.dump(self, file)
self._model = model
elif "PyTorchBackendModel" in str(type(self.model)):
# This is a pytorch model.
if not location.endswith("pkl"):
location += "pkl"
assert not os.path.exists(location), "File {} already exists.".format(location)
with open(location, "wb") as file:
dill.dump(self, file)
else:
raise RuntimeError("Model type {} not recognized.".format(type(self.model)))
@staticmethod
def load(location):
"""
Method used to load a container from the file system.
Args:
location: The location on the file system where to load the model.
Returns:
The loaded model.
"""
assert os.path.exists(location), "Model location {} does not exist.".format(location)
container = None
if os.path.isdir(location):
# This is a torch.jit model
model = torch.jit.load(os.path.join(location, constants.SAVE_LOAD_TORCH_JIT_PATH))
with open(os.path.join(location, "container.pkl"), "rb") as file:
container = dill.load(file)
container._model = model
else:
# This is a pytorch model
with open(location, "rb") as file:
container = dill.load(file)
# Need to set the number of threads to use as set in the original container.
if container._n_threads is not None:
if torch.get_num_interop_threads() != 1:
torch.set_num_interop_threads(1)
torch.set_num_threads(container._n_threads)
return container
def to(self, device):
self.model.to(device)
return self
@ -353,7 +422,7 @@ class PyTorchSklearnContainerRegression(SklearnContainerRegression, PyTorchSklea
return self.model.forward(*inputs)[0].cpu().numpy().ravel()
class PyTorchSklearnContainerClassification(PyTorchSklearnContainerRegression, SklearnContainerClassification):
class PyTorchSklearnContainerClassification(SklearnContainerClassification, PyTorchSklearnContainerRegression):
"""
Container for PyTorch models mirroring Sklearn classifiers API.
"""
@ -511,6 +580,58 @@ class ONNXSklearnContainer(SklearnContainer):
else:
raise RuntimeError("ONNX Container requires ONNX runtime installed.")
def save(self, location):
assert self.model is not None, "Saving a None model is undefined."
import onnx
if constants.TEST_INPUT in self._extra_config:
self._extra_config[constants.TEST_INPUT] = None
assert not os.path.exists(location), "Directory {} already exists.".format(location)
os.makedirs(location)
onnx.save(self.model, os.path.join(location, constants.SAVE_LOAD_ONNX_PATH))
model = self.model
session = self._session
self._model = None
self._session = None
with open(os.path.join(location, constants.SAVE_LOAD_CONTAINER_PATH), "wb") as file:
dill.dump(self, file)
self._model = model
self._session = session
@staticmethod
def load(location):
"""
Method used to load a container from the file system.
Args:
location: The location on the file system where to load the model.
Returns:
The loaded model.
"""
assert os.path.exists(location), "Model location {} does not exist.".format(location)
assert onnx_runtime_installed
import onnx
import onnxruntime as ort
container = None
model = onnx.load(os.path.join(location, constants.SAVE_LOAD_ONNX_PATH))
with open(os.path.join(location, constants.SAVE_LOAD_CONTAINER_PATH), "rb") as file:
container = dill.load(file)
container._model = model
sess_options = ort.SessionOptions()
if container._n_threads is not None:
# Need to set the number of threads to use as set in the original container.
sess_options.intra_op_num_threads = container._n_threads
sess_options.inter_op_num_threads = 1
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
container._session = ort.InferenceSession(container._model.SerializeToString(), sess_options=sess_options)
return container
def _get_named_inputs(self, inputs):
"""
Retrieve the inputs names from the session object.
@ -609,6 +730,90 @@ class TVMSklearnContainer(SklearnContainer):
os.environ["TVM_NUM_THREADS"] = str(self._n_threads)
def save(self, location):
assert self.model is not None, "Saving a None model is undefined."
from tvm.contrib import util
from tvm import relay
assert not os.path.exists(location), "Directory {} already exists.".format(location)
os.makedirs(location)
path_lib = os.path.join(location, constants.SAVE_LOAD_TVM_LIB_PATH)
self._extra_config[constants.TVM_LIB].export_library(path_lib)
with open(os.path.join(location, constants.SAVE_LOAD_TVM_GRAPH_PATH), "w") as fo:
fo.write(self._extra_config[constants.TVM_GRAPH])
with open(os.path.join(location, constants.SAVE_LOAD_TVM_PARAMS_PATH), "wb") as fo:
fo.write(relay.save_param_dict(self._extra_config[constants.TVM_PARAMS]))
# Remove all information that cannot be pickled
if constants.TEST_INPUT in self._extra_config:
self._extra_config[constants.TEST_INPUT] = None
lib = self._extra_config[constants.TVM_LIB]
graph = self._extra_config[constants.TVM_GRAPH]
params = self._extra_config[constants.TVM_PARAMS]
ctx = self._extra_config[constants.TVM_CONTEXT]
model = self._model
self._extra_config[constants.TVM_LIB] = None
self._extra_config[constants.TVM_GRAPH] = None
self._extra_config[constants.TVM_PARAMS] = None
self._extra_config[constants.TVM_CONTEXT] = None
self._ctx = "cpu" if self._ctx.device_type == 1 else "cuda"
self._model = None
with open(os.path.join(location, constants.SAVE_LOAD_CONTAINER_PATH), "wb") as file:
dill.dump(self, file)
# Restore the information
self._extra_config[constants.TVM_LIB] = lib
self._extra_config[constants.TVM_GRAPH] = graph
self._extra_config[constants.TVM_PARAMS] = params
self._extra_config[constants.TVM_CONTEXT] = ctx
self._ctx = ctx
self._model = model
@staticmethod
def load(location):
"""
Method used to load a container from the file system.
Args:
location: The location on the file system where to load the model.
Returns:
The loaded model.
"""
assert tvm_installed()
import tvm
from tvm.contrib import util, graph_runtime
from tvm import relay
container = None
assert os.path.exists(location), "Directory {} not found.".format(location)
path_lib = os.path.join(location, constants.SAVE_LOAD_TVM_LIB_PATH)
graph = open(os.path.join(location, constants.SAVE_LOAD_TVM_GRAPH_PATH)).read()
lib = tvm.runtime.module.load_module(path_lib)
params = relay.load_param_dict(open(os.path.join(location, constants.SAVE_LOAD_TVM_PARAMS_PATH), "rb").read())
# params = bytearray(open(os.path.join(location, "deploy_param.params"), "rb").read())
with open(os.path.join(location, constants.SAVE_LOAD_CONTAINER_PATH), "rb") as file:
container = dill.load(file)
assert container is not None, "Failed to load the model container."
ctx = tvm.cpu() if container._ctx == "cpu" else tvm.gpu
container._model = graph_runtime.create(graph, lib, ctx)
container._model.set_input(**params)
container._extra_config[constants.TVM_GRAPH] = graph
container._extra_config[constants.TVM_LIB] = lib
container._extra_config[constants.TVM_PARAMS] = params
container._extra_config[constants.TVM_CONTEXT] = ctx
container._ctx = ctx
# Need to set the number of threads to use as set in the original container.
os.environ["TVM_NUM_THREADS"] = str(container._n_threads)
return container
def _to_tvm_tensor(self, *inputs):
tvm_tensors = {}
msg = "The number of input rows {} is different from the batch size {} the TVM model is compiled for."

Просмотреть файл

@ -115,6 +115,10 @@ def _compile_tvm_model(topology, torch_model, trace_input, target, ctx, config,
tvm_model = graph_runtime.create(graph, lib, ctx)
tvm_model.set_input(**params)
extra_config[constants.TVM_GRAPH] = graph
extra_config[constants.TVM_LIB] = lib
extra_config[constants.TVM_PARAMS] = params
return tvm_model

Просмотреть файл

@ -38,9 +38,36 @@ ONNX_INITIALIZERS = "onnx_initializers"
TVM_CONTEXT = "tvm_context"
"""The context for TVM containing information on the target."""
TVM_GRAPH = "tvm_graph"
"""The graph defining the TVM model. This parameter is used for saving and loading a TVM model."""
TVM_LIB = "tvm_lib"
"""The lib for the TVM model. This parameter is used for saving and loading a TVM model."""
TVM_PARAMS = "tvm_params"
"""The params for the TVM model. This parameter is used for saving and loading a TVM model."""
TVM_INPUT_NAMES = "tvm_input_names"
"""TVM expects named inputs. This is used to set the names for the inputs."""
SAVE_LOAD_CONTAINER_PATH = "container.pkl"
"""Path where to find the container when saving or loading."""
SAVE_LOAD_TVM_LIB_PATH = "deploy_lib.tar"
"""Path where to find the TVM lib when saving or loading."""
SAVE_LOAD_TVM_GRAPH_PATH = "deploy_graph.json"
"""Path where to find the TVM graph when saving or loading."""
SAVE_LOAD_TVM_PARAMS_PATH = "deploy_param.params"
"""Path where to find the TVM params when saving or loading."""
SAVE_LOAD_TORCH_JIT_PATH = "deploy_model.zip"
"""Path where to find the torchscript model when saving or loading."""
SAVE_LOAD_ONNX_PATH = "deploy_model.onnx"
"""Path where to find the onnx model when saving or loading."""
TEST_INPUT = "test_input"
"""The test input data for models that need to be traced."""

Просмотреть файл

@ -27,6 +27,7 @@ install_requires = [
"scikit-learn>=0.21.3,<=0.23.2",
"torch>=1.4.*,<=1.7.0",
"psutil",
"dill",
]
onnx_requires = [
"onnxruntime>=1.0.0",

Просмотреть файл

@ -3,6 +3,8 @@ Tests Hummingbird's backends.
"""
import unittest
import warnings
import os
import shutil
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
@ -64,6 +66,50 @@ class TestBackends(unittest.TestCase):
self.assertIsNotNone(hb_model)
np.testing.assert_allclose(model.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)
# Test pytorch save and load
def test_pytorch_save_load(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "torch")
self.assertIsNotNone(hb_model)
hb_model.save("pt-tmp.pkl")
hb_model_loaded = hummingbird.ml.PyTorchSklearnContainer.load("pt-tmp.pkl")
np.testing.assert_allclose(hb_model_loaded.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)
os.remove("pt-tmp.pkl")
# Test torchscript save and load
def test_torchscript_save_load(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "torch.jit", X)
self.assertIsNotNone(hb_model)
hb_model.save("ts-tmp")
hb_model_loaded = hummingbird.ml.PyTorchSklearnContainer.load("ts-tmp")
np.testing.assert_allclose(hb_model_loaded.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)
shutil.rmtree("ts-tmp")
# Test not supported backends
def test_unsupported_backend(self):
warnings.filterwarnings("ignore")
@ -113,6 +159,29 @@ class TestBackends(unittest.TestCase):
# Test tvm requires test_input
self.assertRaises(RuntimeError, hummingbird.ml.convert, model, "tvm")
# Test pytorch save and load
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_tvm_save_load(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "tvm", X)
self.assertIsNotNone(hb_model)
hb_model.save("tvm-tmp")
hb_model_loaded = hummingbird.ml.TVMSklearnContainer.load("tvm-tmp")
np.testing.assert_allclose(hb_model_loaded.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)
shutil.rmtree("tvm-tmp")
# Test onnx requires test_data or initial_types
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test require ONNX, ORT and ONNXMLTOOLS"
@ -243,6 +312,31 @@ class TestBackends(unittest.TestCase):
# Test backends are not case sensitive
self.assertRaises(RuntimeError, hummingbird.ml.convert, onnx_ml_model, "onnx")
# Test ONNX save and load
@unittest.skipIf(
not (onnx_ml_tools_installed() and onnx_runtime_installed()), reason="ONNXML test require ONNX, ORT and ONNXMLTOOLS"
)
def test_onnx_save_load(self):
warnings.filterwarnings("ignore")
max_depth = 10
num_classes = 2
model = GradientBoostingClassifier(n_estimators=10, max_depth=max_depth)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(num_classes, size=100)
model.fit(X, y)
hb_model = hummingbird.ml.convert(model, "onnx", X)
self.assertIsNotNone(hb_model)
hb_model.save("onnx-tmp")
hb_model_loaded = hummingbird.ml.ONNXSklearnContainer.load("onnx-tmp")
np.testing.assert_allclose(hb_model_loaded.predict_proba(X), hb_model.predict_proba(X), rtol=1e-06, atol=1e-06)
shutil.rmtree("onnx-tmp")
# Test for when the user forgets to add a target (ex: convert(model, output) rather than convert(model, 'torch')) due to API change
def test_forgotten_backend_string(self):
from sklearn.preprocessing import LabelEncoder