onnxruntime-tvm/python/tvm/relay/build_module.py

299 строки
10 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Construct the necessary state for the TVM graph runtime
from a Relay expression.
"""
import warnings
import numpy as np
from tvm import expr as tvm_expr
from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import _build_module
from . import ty as _ty
from . import expr as _expr
from .module import Module as _Module
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor
def _update_target(target):
    """Normalize a user-supplied target into a device-type-to-target map.

    Parameters
    ----------
    target : str, tvm.target.Target, dict, or None
        Either a single build target, or a dict mapping a device/context
        name to a build target. A falsy value (e.g. None) falls back to
        ``_target.current_target()``.

    Returns
    -------
    tgts : dict
        Maps ``IntImm("int32", device_type)`` keys to the targets produced
        by ``_target.create``.

    Raises
    ------
    ValueError
        If `target` is falsy and ``_target.current_target()`` returns None.
    TypeError
        If `target` is neither a str, a tvm.target.Target, nor a dict.
    """
    target = target if target else _target.current_target()
    if target is None:
        raise ValueError("Target is not set in env or passed as argument.")

    tgts = {}
    if isinstance(target, (str, _target.Target)):
        # Homogeneous case: the device type is derived from the target itself.
        dev_type = tvm_expr.IntImm("int32", _nd.context(str(target)).device_type)
        tgts[dev_type] = _target.create(target)
    elif isinstance(target, dict):
        # Heterogeneous case: one entry per device/context name.
        for dev, tgt in target.items():
            dev_type = tvm_expr.IntImm("int32", _nd.context(dev).device_type)
            tgts[dev_type] = _target.create(tgt)
    else:
        # A dict is accepted above, so the message lists it as a valid type.
        raise TypeError("target is expected to be str, tvm.target.Target, "
                        "or dict of str to str/tvm.target.Target, "
                        "but received {}".format(type(target)))
    return tgts
class BuildModule(object):
    """Python-side handle for the `RelayBuildModule` APIs implemented in C++.

    It builds a Relay function into the artifacts needed to run on the TVM
    graph runtime.
    """

    def __init__(self):
        backend = _build_module._BuildModule()
        self.mod = backend
        # Look up the packed functions once and keep them as attributes.
        self._get_graph_json = backend["get_graph_json"]
        self._get_module = backend["get_module"]
        self._build = backend["build"]
        self._set_params_func = backend["set_params"]
        self._get_params_func = backend["get_params"]

    def build(self, func, target=None, target_host=None, params=None):
        """Build `func` and return the graph runtime artifacts.

        Parameters
        ----------
        func: relay.Function
            The function to build.

        target : str, :any:`tvm.target.Target`, or dict of str(i.e.
        device/context name) to str/tvm.target.Target, optional
            For heterogeneous compilation, it is a dictionary indicating context
            to target mapping. For homogeneous compilation, it is a build target.

        target_host : str or :any:`tvm.target.Target`, optional
            Host compilation target, if target is device.
            When TVM compiles device specific program such as CUDA,
            we also need host(CPU) side code to interact with the driver
            to setup the dimensions and parameters correctly.
            target_host is used to specify the host side codegen target.
            By default, llvm is used if it is enabled,
            otherwise a stackvm interpreter is used.

        params : dict of str to NDArray
            Input parameters to the graph that do not change
            during inference time. Used for constant folding.

        Returns
        -------
        graph_json : str
            The json string that can be accepted by graph runtime.

        mod : tvm.Module
            The module containing necessary libraries.

        params : dict
            The parameters of the final graph.
        """
        device_targets = _update_target(target)

        # Parameters are registered before building, and only when given.
        if params:
            self._set_params(params)

        self._build(func, device_targets, target_host)

        # Artifacts are fetched in the order json, module, params.
        return self.get_json(), self.get_module(), self.get_params()

    def _set_params(self, params):
        """Wrap every parameter as a Relay constant and pass them on.

        numpy arrays are first converted with ``_nd.array``.
        """
        def _as_const(value):
            if isinstance(value, np.ndarray):
                value = _nd.array(value)
            return _expr.const(value)

        self._set_params_func(
            {name: _as_const(value) for name, value in params.items()})

    def get_json(self):
        """Return the json file of the built program."""
        return self._get_graph_json()

    def get_module(self):
        """Return the built module."""
        return self._get_module()

    def get_params(self):
        """Return the updated weights as a dict of name to ``.data``."""
        return {name: const.data
                for name, const in self._get_params_func().items()}
def build(mod, target=None, target_host=None, params=None):
    """Helper function that builds a Relay function to run on TVM graph
    runtime.

    Parameters
    ----------
    mod : relay.Module
        The module to build. Using relay.Function is deprecated.

    target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
    name) to str/tvm.target.Target, optional
        For heterogeneous compilation, it is a dictionary indicating context to
        target mapping. For homogeneous compilation, it is a build target.

    target_host : str or :any:`tvm.target.Target`, optional
        Host compilation target, if target is device.
        When TVM compiles device specific program such as CUDA,
        we also need host(CPU) side code to interact with the driver
        to setup the dimensions and parameters correctly.
        target_host is used to specify the host side codegen target.
        By default, llvm is used if it is enabled,
        otherwise a stackvm interpreter is used.

    params : dict of str to NDArray
        Input parameters to the graph that do not change
        during inference time. Used for constant folding.

    Returns
    -------
    graph_json : str
        The json string that can be accepted by graph runtime.

    mod : tvm.Module
        The module containing necessary libraries.

    params : dict
        The parameters of the final graph.

    Raises
    ------
    ValueError
        If `mod` is neither a relay Module nor a relay Function, or if
        `target_host` is truthy but not a str or tvm.target.Target.
    """
    # Reject unsupported inputs up front.
    if not isinstance(mod, (_Module, _expr.Function)):
        raise ValueError("Type of input parameter mod must be tvm.relay.module.Module")

    if isinstance(mod, _Module):
        func = mod["main"]
    else:
        # A bare Function is still accepted, but flagged as deprecated.
        func = mod
        warnings.warn(
            "Please use input parameter mod (tvm.relay.module.Module) "
            "instead of deprecated parameter func (tvm.relay.expr.Function)",
            DeprecationWarning)

    device_targets = _update_target(target)

    if isinstance(target_host, (str, _target.Target)):
        target_host = _target.create(target_host)
    elif target_host:
        raise ValueError("target host must be the type of str, "
                         "tvm.target.Target, or None")

    # When the current dispatch context is the fallback context (the default
    # root context), pre-tuned parameters are loaded from TopHub; otherwise
    # an empty context is used.
    in_fallback = isinstance(autotvm.DispatchContext.current,
                             autotvm.FallbackContext)
    tophub_context = (autotvm.tophub.context(list(device_targets.values()))
                      if in_fallback else autotvm.util.EmptyContext())

    with tophub_context:
        graph_json, lib, final_params = BuildModule().build(
            func, device_targets, target_host, params)
    return graph_json, lib, final_params
class GraphExecutor(_interpreter.Executor):
    """Wrapper around Executor interface.

    This executor is used for debug and testing purposes.

    Parameters
    ----------
    mod : :py:class:`~tvm.relay.module.Module`
        The module to support the execution.

    ctx : :py:class:`TVMContext`
        The runtime context to run the code on.

    target : :py:class:`Target`
        The target option to build the function.
    """

    def __init__(self, mod, ctx, target):
        assert mod is not None
        self.mod = mod
        self.ctx = ctx
        self.target = target

    def _make_executor(self, expr=None):
        """Build the module's "main" function and return a callable for it.

        Parameters
        ----------
        expr : optional
            When given (i.e. not None), it is stored as ``self.mod["main"]``
            before building.

        Returns
        -------
        executor : callable
            A function that sets its arguments as graph inputs, runs the
            graph and returns the output. A single NDArray is returned when
            there is one output, otherwise a list of NDArrays.
        """
        # Test against None explicitly: an expression object that happens to
        # evaluate falsy must still replace "main" when it was supplied.
        if expr is not None:
            self.mod["main"] = expr

        ret_type = self.mod["main"].checked_type.ret_type
        # A tuple return type yields one graph output per field.
        num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1

        graph_json, mod, params = build(self.mod, target=self.target)
        gmodule = _graph_rt.create(graph_json, mod, self.ctx)
        if params:
            gmodule.set_input(**params)

        def _graph_wrapper(*args, **kwargs):
            # _convert_args is not defined in this class; presumably it is
            # inherited from the Executor base class.
            args = self._convert_args(self.mod["main"], args, kwargs)
            # Inputs are bound by position.
            for i, arg in enumerate(args):
                gmodule.set_input(i, arg)
            # Run the module, and fetch the output.
            gmodule.run()
            # Each output is copied to CPU so the caller gets its own array
            # (presumably the runtime reuses its output buffers between
            # invocations -- TODO confirm).
            outputs = [gmodule.get_output(i).copyto(_nd.cpu(0))
                       for i in range(num_outputs)]
            return outputs[0] if num_outputs == 1 else outputs

        return _graph_wrapper
def create_executor(kind="debug",
                    mod=None,
                    ctx=None,
                    target="llvm"):
    """Factory function to create an executor.

    Parameters
    ----------
    kind : str
        The type of executor: "debug", "graph" or "vm".

    mod : :py:class:`~tvm.relay.module.Module`
        The Relay module containing collection of functions. A new empty
        module is created when None.

    ctx : :py:class:`tvm.TVMContext`
        The context to execute the code. When None, it is derived from
        `target` with device id 0.

    target : :py:class:`tvm.Target`
        The build target. A str is converted with ``_target.create``.

    Returns
    -------
    executor
        An ``Interpreter``, ``GraphExecutor`` or ``VMExecutor`` instance,
        depending on `kind`.

    Raises
    ------
    RuntimeError
        If `kind` is not one of the supported strategies.
    """
    mod = _Module() if mod is None else mod

    if ctx is None:
        ctx = _nd.context(str(target), 0)
    else:
        # A caller-provided context must match the target's device type.
        assert ctx.device_type == _nd.context(str(target), 0).device_type

    if isinstance(target, str):
        target = _target.create(target)

    if kind == "debug":
        return _interpreter.Interpreter(mod, ctx, target)
    if kind == "graph":
        return GraphExecutor(mod, ctx, target)
    if kind == "vm":
        return VMExecutor(mod, ctx, target)
    raise RuntimeError("unknown execution strategy: {0}".format(kind))