[RELAY][RUNTIME] Add Relay interpreter and compiler for TVM runtime system. (#1954)
This commit is contained in:
Родитель
07399e0239
Коммит
10ea05e645
|
@ -22,8 +22,15 @@ namespace tvm {
|
|||
* You can find more about Relay by reading the language reference.
|
||||
*/
|
||||
namespace relay {
|
||||
|
||||
#define RELAY_DEBUG(...) \
|
||||
{ auto fdebug = runtime::Registry::Get("relay.debug"); \
|
||||
CHECK(fdebug) << "Could not find Relay Python debugger function."; \
|
||||
(*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief we always used NodeRef for referencing nodes.
|
||||
* \brief We always used NodeRef for referencing nodes.
|
||||
*
|
||||
* By default, NodeRef is a std::shared_ptr of node
|
||||
*/
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file tvm/relay/build_module.h
|
||||
* \brief The passes and data structures needed to build a
|
||||
* tvm::Module from a Relay program.
|
||||
*/
|
||||
#ifndef TVM_RELAY_BUILD_MODULE_H_
|
||||
#define TVM_RELAY_BUILD_MODULE_H_
|
||||
|
||||
#include <tvm/lowered_func.h>
|
||||
#include <tvm/relay/environment.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <string>
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
/*! \brief A lowered Relay operation.
|
||||
*
|
||||
* A lowered operation is a pair containing the "primitive" function used
|
||||
* to produce the lowered function as well as the lowered function itself.
|
||||
*/
|
||||
class LoweredOp;
|
||||
/*! \brief Call container. */
|
||||
class LoweredOpNode : public Node {
|
||||
public:
|
||||
/*!
|
||||
* \brief The primitive function to be lowered.
|
||||
*
|
||||
* A primitive function consists only of calls to relay::Op which
|
||||
* can be fused.
|
||||
*/
|
||||
Function func;
|
||||
|
||||
/*!
|
||||
* \brief The lowered function.
|
||||
*/
|
||||
LoweredFunc lowered_func;
|
||||
|
||||
void VisitAttrs(tvm::AttrVisitor* v) final {
|
||||
v->Visit("func", &func);
|
||||
v->Visit("lowered_func", &lowered_func);
|
||||
}
|
||||
|
||||
TVM_DLL static LoweredOp make(
|
||||
Function func,
|
||||
LoweredFunc lowered_func);
|
||||
|
||||
static constexpr const char* _type_key = "relay.LoweredOp";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(LoweredOpNode, Node);
|
||||
};
|
||||
|
||||
RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef);
|
||||
|
||||
/*!
|
||||
* \brief Lower the operations contained in a Relay expression.
|
||||
*
|
||||
* The lowering pass will only lower functions marked as primitive,
|
||||
* the FuseOps pass will provide this behavior, if run before LowerOps.
|
||||
*
|
||||
* \note This will do a reachability analysis and lower all definitions
|
||||
* reachable from the provided expression.
|
||||
*
|
||||
* \param env The environment.
|
||||
* \param expr The expression with operations to be lowered.
|
||||
* \param target The target to lower the functions to.
|
||||
*
|
||||
* \return The set of lowered operations.
|
||||
*/
|
||||
Array<LoweredOp> LowerOps(const Environment& env, const Expr& expr,
|
||||
const std::string& target = "llvm");
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_RELAY_BUILD_MODULE_H_
|
|
@ -213,12 +213,18 @@ class FunctionNode : public ExprNode {
|
|||
*/
|
||||
tvm::Array<TypeVar> type_params;
|
||||
|
||||
/*!
|
||||
* \brief The attributes which store metadata about functions.
|
||||
*/
|
||||
tvm::Attrs attrs;
|
||||
|
||||
void VisitAttrs(tvm::AttrVisitor* v) final {
|
||||
v->Visit("params", ¶ms);
|
||||
v->Visit("body", &body);
|
||||
v->Visit("ret_type", &ret_type);
|
||||
v->Visit("type_params", &type_params);
|
||||
v->Visit("span", &span);
|
||||
v->Visit("attrs", &attrs);
|
||||
v->Visit("_checked_type_", &checked_type_);
|
||||
}
|
||||
|
||||
|
@ -233,7 +239,8 @@ class FunctionNode : public ExprNode {
|
|||
TVM_DLL static Function make(tvm::Array<Var> params,
|
||||
Expr body,
|
||||
Type ret_type,
|
||||
tvm::Array<TypeVar> ty_params);
|
||||
tvm::Array<TypeVar> ty_params,
|
||||
tvm::Attrs attrs = Attrs());
|
||||
|
||||
static constexpr const char* _type_key = "relay.Function";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
|
||||
|
@ -241,6 +248,11 @@ class FunctionNode : public ExprNode {
|
|||
|
||||
RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr);
|
||||
|
||||
|
||||
TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key);
|
||||
TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data);
|
||||
|
||||
|
||||
/*!
|
||||
* \brief Call corresponds to operator invocation.
|
||||
* Corresponds to the operator in computational graph terminology.
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file tvm/relay/interpreter.h
|
||||
* \brief An interpreter for Relay.
|
||||
*
|
||||
* This file implements a simple reference interpreter for Relay programs.
|
||||
* Given a Relay environment, and a Relay expression it produces a value.
|
||||
*
|
||||
* The interpreter's values are a naive representation of the values that
|
||||
* can be produced by a Relay program and are exposed via tvm::Node's
|
||||
* system to Python for introspection and debugging.
|
||||
*
|
||||
* The interpreter's intent is to serve as a reference semantics for the Relay IR,
|
||||
* as well as for debugging and testing.
|
||||
*/
|
||||
#ifndef TVM_RELAY_INTERPRETER_H_
|
||||
#define TVM_RELAY_INTERPRETER_H_
|
||||
|
||||
#include <tvm/relay/environment.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
/*!
|
||||
* \brief A Relay value.
|
||||
*/
|
||||
class Value;
|
||||
|
||||
/*! \brief Evaluate an expression using the interpreter producing a value.
|
||||
*
|
||||
* The resulting value can be passed to Python, making it easy to use
|
||||
* for testing and debugging.
|
||||
*
|
||||
* The interpreter interprets the program fragments not supported by the
|
||||
* TVM runtime, although the interpreter is naively implemented it uses
|
||||
* TVM operators for evaluating all operators.
|
||||
*
|
||||
* Our intent is that this will never be the most efficient implementation of
|
||||
* Relay's semantics, but a readable and clear one.
|
||||
*/
|
||||
Value Evaluate(Environment env, Expr e);
|
||||
|
||||
/*! \brief The base container type of Relay values. */
|
||||
class ValueNode : public RelayNode {
|
||||
public:
|
||||
static constexpr const char* _type_key = "relay.Value";
|
||||
TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode);
|
||||
};
|
||||
|
||||
class Value : public NodeRef {
|
||||
public:
|
||||
Value() {}
|
||||
explicit Value(NodePtr<Node> n) : NodeRef(n) {}
|
||||
const ValueNode* operator->() const {
|
||||
return static_cast<const ValueNode*>(node_.get());
|
||||
}
|
||||
|
||||
using ContainerType = ValueNode;
|
||||
};
|
||||
|
||||
/*! \brief A Relay closure, i.e a scope and a function. */
|
||||
class Closure;
|
||||
|
||||
/*! \brief The container type of Closures. */
|
||||
class ClosureNode : public ValueNode {
|
||||
public:
|
||||
/*! \brief The set of free variables in the closure.
|
||||
*
|
||||
* These are the captured variables which are required for
|
||||
* evaluation when we call the closure.
|
||||
*/
|
||||
tvm::Map<Var, Value> env;
|
||||
/*! \brief The function which implements the closure.
|
||||
*
|
||||
* \note May reference the variables contained in the env.
|
||||
*/
|
||||
Function func;
|
||||
|
||||
ClosureNode() {}
|
||||
|
||||
void VisitAttrs(tvm::AttrVisitor* v) final {
|
||||
v->Visit("env", &env);
|
||||
v->Visit("func", &func);
|
||||
}
|
||||
|
||||
TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func);
|
||||
|
||||
static constexpr const char* _type_key = "relay.Closure";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(ClosureNode, ValueNode);
|
||||
};
|
||||
|
||||
RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value);
|
||||
|
||||
/*! \brief A tuple value. */
|
||||
class TupleValue;
|
||||
|
||||
/*! \brief Tuple (x, ... y). */
|
||||
struct TupleValueNode : ValueNode {
|
||||
tvm::Array<Value> fields;
|
||||
|
||||
TupleValueNode() {}
|
||||
|
||||
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }
|
||||
|
||||
TVM_DLL static TupleValue make(tvm::Array<Value> value);
|
||||
|
||||
static constexpr const char* _type_key = "relay.TupleValue";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(TupleValueNode, ValueNode);
|
||||
};
|
||||
|
||||
RELAY_DEFINE_NODE_REF(TupleValue, TupleValueNode, Value);
|
||||
|
||||
/*! \brief A tensor value. */
|
||||
class TensorValue;
|
||||
|
||||
/*! \brief The tensor value container, wrapping an NDArray. */
|
||||
struct TensorValueNode : ValueNode {
|
||||
runtime::NDArray data;
|
||||
|
||||
TensorValueNode() {}
|
||||
|
||||
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); }
|
||||
|
||||
/*! \brief Build a value from an NDArray. */
|
||||
TVM_DLL static TensorValue make(runtime::NDArray data);
|
||||
|
||||
/*! \brief Construct an empty tensor value from t. */
|
||||
TVM_DLL static TensorValue FromType(const Type& t);
|
||||
|
||||
static constexpr const char* _type_key = "relay.TensorValue";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode);
|
||||
};
|
||||
|
||||
RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value);
|
||||
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
#endif // TVM_RELAY_INTERPRETER_H_
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include <tvm/relay/environment.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <string>
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
@ -20,7 +21,8 @@ namespace relay {
|
|||
* populated with the result type.
|
||||
*
|
||||
* \param expr The expression to type check.
|
||||
* \param env The environment used for referencing global functions, can be None.
|
||||
* \param env The environment used for referencing global functions, can be
|
||||
* None.
|
||||
*
|
||||
* \return A type checked expression with its checked_type field populated.
|
||||
*/
|
||||
|
@ -35,7 +37,8 @@ Expr InferType(const Expr& expr, const Environment& env);
|
|||
* \return A type checked Function with its checked_type field populated.
|
||||
* \note this function mutates env and is not thread-safe.
|
||||
*/
|
||||
Function InferType(const Function& f, const Environment& env, const GlobalVar& var);
|
||||
Function InferType(const Function& f, const Environment& env,
|
||||
const GlobalVar& var);
|
||||
|
||||
/*!
|
||||
* \brief Check that types are well kinded by applying "kinding rules".
|
||||
|
@ -94,28 +97,30 @@ bool AlphaEqual(const Type& t1, const Type& t2);
|
|||
*
|
||||
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
|
||||
*
|
||||
* `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice, although x is not shadowed.
|
||||
* `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice,
|
||||
* although x is not shadowed.
|
||||
*
|
||||
* \param e the expression to check.
|
||||
* \param expr the expression to check.
|
||||
*
|
||||
* \return true iff all Var in e is bound at most once.
|
||||
* \return true iff all Var in expr is bound at most once.
|
||||
*/
|
||||
bool WellFormed(const Expr& e);
|
||||
bool WellFormed(const Expr& expr);
|
||||
|
||||
/*! \brief Get free Vars from expr in PostDFS order.
|
||||
/*! \brief Get free type parameters from expression expr.
|
||||
*
|
||||
* Free variables are variables that are not bound by a
|
||||
* let or a function parameter in the context.
|
||||
*
|
||||
* \param expr the expression.
|
||||
*
|
||||
* \return List of free vars, in the PostDFS order visited by expr.
|
||||
* \return List of free vars, in the PostDFS order in the expression.
|
||||
*/
|
||||
tvm::Array<Var> FreeVars(const Expr& expr);
|
||||
|
||||
/*! \brief Get free TypeVars from expression expr.
|
||||
*
|
||||
* Free type parameters are type parameters that are not bound by a function type in the context.
|
||||
* Free type parameters are type parameters that are not bound by a function
|
||||
* type in the context.
|
||||
*
|
||||
* \param expr the expression.
|
||||
*
|
||||
|
@ -125,10 +130,12 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
|
|||
|
||||
/*! \brief Remove expressions which does not effect the program result.
|
||||
*
|
||||
* It will remove let binding that are not referenced, and if branch that are not entered.
|
||||
* It will remove let bindings which are not referenced, and branches that will
|
||||
* not be entered.
|
||||
*
|
||||
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of the expression does not depend on a.
|
||||
* Another example is `if (true) then 1 else 2` will be optimized into 1.
|
||||
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of
|
||||
* the expression does not depend on a. Another example is `if (true) then 1
|
||||
* else 2` will be optimized into 1.
|
||||
*
|
||||
* \param e the expression to optimize.
|
||||
*
|
||||
|
@ -136,6 +143,8 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
|
|||
*/
|
||||
Expr DeadCodeElimination(const Expr& e);
|
||||
|
||||
/*! \brief A hashing structure in the style of std::hash. */
|
||||
struct StructuralHash {
|
||||
/*! \brief Hash a Relay type.
|
||||
*
|
||||
* Implements structural hashing of a Relay type.
|
||||
|
@ -144,7 +153,7 @@ Expr DeadCodeElimination(const Expr& e);
|
|||
*
|
||||
* \return the hash value.
|
||||
*/
|
||||
size_t StructuralHash(const Type& type);
|
||||
size_t operator()(const Type& type) const;
|
||||
|
||||
/*! \brief Hash a Relay expression.
|
||||
*
|
||||
|
@ -154,9 +163,10 @@ size_t StructuralHash(const Type& type);
|
|||
*
|
||||
* \return the hash value.
|
||||
*/
|
||||
size_t StructuralHash(const Expr& expr);
|
||||
|
||||
size_t operator()(const Expr& expr) const;
|
||||
};
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_RELAY_PASS_H_
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
|
||||
"""The Relay IR namespace containing the IR definition and compiler."""
|
||||
from __future__ import absolute_import
|
||||
from ..api import register_func
|
||||
from . import base
|
||||
from . import ty
|
||||
from . import expr
|
||||
|
@ -15,6 +17,7 @@ from . import nn
|
|||
from . import vision
|
||||
from . import image
|
||||
|
||||
|
||||
from .scope_builder import ScopeBuilder
|
||||
|
||||
# Span
|
||||
|
@ -46,6 +49,21 @@ Let = expr.Let
|
|||
If = expr.If
|
||||
TupleGetItem = expr.TupleGetItem
|
||||
|
||||
|
||||
# helper functions
|
||||
var = expr.var
|
||||
const = expr.const
|
||||
|
||||
@register_func("relay._tensor_value_repr")
|
||||
def _tensor_value_repr(tv):
|
||||
return str(tv.data.asnumpy())
|
||||
|
||||
@register_func("relay._constant_repr")
|
||||
def _tensor_constant_repr(tv):
|
||||
return str(tv.data.asnumpy())
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@register_func("relay.debug")
|
||||
def _debug(*args):
|
||||
import pdb
|
||||
pdb.set_trace()
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
"""The interface to the Evaluator exposed from C++."""
|
||||
from tvm._ffi.function import _init_api
|
||||
|
||||
_init_api("relay._interpreter", __name__)
|
|
@ -45,9 +45,12 @@ class Environment(RelayNode):
|
|||
func: Function
|
||||
The function.
|
||||
"""
|
||||
return self._add(var, func)
|
||||
|
||||
def _add(self, var, func, update=False):
|
||||
if isinstance(var, _base.string_types):
|
||||
var = _expr.GlobalVar(var)
|
||||
_env.Environment_Add(self, var, func)
|
||||
return _env.Environment_Add(self, var, func, update)
|
||||
|
||||
def __getitem__(self, var):
|
||||
"""Lookup a global function by name or by variable.
|
||||
|
|
|
@ -0,0 +1,551 @@
|
|||
"""
|
||||
A compiler from a Relay expression to TVM's graph runtime.
|
||||
|
||||
The compiler is built from a few pieces.
|
||||
|
||||
First we define a compiler from a single Relay expression to the
|
||||
graph langauge. We require the expression to be a function.
|
||||
The function's parameters correpond to the placeholder/inputs
|
||||
and model parameters found in the computation graph representation.
|
||||
The body of the function represents the computation graph.
|
||||
|
||||
The compiler's output is a program in the graph language, which is composed of
|
||||
graph langauge is composed of Node, NodeRef, InputNode, OpNode.
|
||||
This "little language" represents programs in TVM's graph format.
|
||||
|
||||
To connect to the graph runtime, we use a printer that converts our graph format
|
||||
into TVM's JSON format. The resulting string can be loaded by
|
||||
contrib.graph_runtime or any other TVM runtime comptatible system.
|
||||
|
||||
We expose this functionality in compile_to_tvm.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
import json
|
||||
import attr
|
||||
from . import ir_pass
|
||||
from .op import Op
|
||||
from .expr import Var, Function, Call, If, GlobalVar, Constant, Let, Tuple
|
||||
from ..build_module import build as tvm_build_module
|
||||
from .. contrib import graph_runtime
|
||||
from .ir_pass import infer_type
|
||||
from .. import cpu
|
||||
|
||||
class AbstractExprVisitor(object):
|
||||
"""A visitor over Expr in Python."""
|
||||
|
||||
def __init__(self):
|
||||
self.memo_map = {}
|
||||
|
||||
# pylint: disable=no-else-return
|
||||
def visit(self, expr):
|
||||
"""Apply the visitor to an expression."""
|
||||
found = self.memo_map.get(expr)
|
||||
if found:
|
||||
return found
|
||||
|
||||
if isinstance(expr, Function):
|
||||
res = self.visit_function(expr)
|
||||
elif isinstance(expr, Call):
|
||||
res = self.visit_call(expr)
|
||||
elif isinstance(expr, Let):
|
||||
res = self.visit_let(expr)
|
||||
elif isinstance(expr, Var):
|
||||
res = self.visit_var(expr)
|
||||
elif isinstance(expr, GlobalVar):
|
||||
res = self.visit_global_var(expr)
|
||||
elif isinstance(expr, If):
|
||||
res = self.visit_if(expr)
|
||||
elif isinstance(expr, Tuple):
|
||||
res = self.visit_tuple(expr)
|
||||
elif isinstance(expr, Constant):
|
||||
res = self.visit_constant(expr)
|
||||
else:
|
||||
raise Exception("warning unhandled case: {0}".format(type(expr)))
|
||||
|
||||
self.memo_map[expr] = res
|
||||
return res
|
||||
|
||||
def visit_function(self, _):
|
||||
raise Exception("Abstract method please implement me.")
|
||||
|
||||
def visit_let(self, _):
|
||||
raise Exception("Abstract method please implement me.")
|
||||
|
||||
def visit_call(self, _):
|
||||
raise Exception("Abstract method please implement me.")
|
||||
|
||||
def visit_var(self, _):
|
||||
raise Exception("Abstract method please implement me.")
|
||||
|
||||
def visit_type(self, typ):
|
||||
return typ
|
||||
|
||||
def visit_if(self, _):
|
||||
raise Exception("Abstract method please implement me.")
|
||||
|
||||
def visit_tuple(self, _):
|
||||
raise Exception("Abstract method please implement me.")
|
||||
|
||||
def visit_constant(self, _):
|
||||
raise Exception("Abstract method please implement me.")
|
||||
|
||||
def visit_global_var(self, _):
|
||||
raise Exception("Abstract method please implement me.")
|
||||
|
||||
|
||||
class ExprMutator(AbstractExprVisitor):
|
||||
"""A functional visitor over Expr in Python."""
|
||||
|
||||
def visit_function(self, fn):
|
||||
new_body = self.visit(fn.body)
|
||||
return Function(
|
||||
list(fn.params),
|
||||
fn.ret_type, new_body,
|
||||
fn.type_params)
|
||||
|
||||
def visit_let(self, let):
|
||||
new_var = self.visit(let.var)
|
||||
new_val = self.visit(let.value)
|
||||
new_body = self.visit(let.body)
|
||||
return Let(new_var, new_val, new_body)
|
||||
|
||||
def visit_call(self, call):
|
||||
new_fn = self.visit(call.op)
|
||||
new_args = [self.visit(arg) for arg in call.args]
|
||||
return Call(new_fn, new_args, call.attrs)
|
||||
|
||||
def visit_var(self, var):
|
||||
return var
|
||||
|
||||
def visit_global_id(self, global_var):
|
||||
return global_var
|
||||
|
||||
def visit_if(self, ite):
|
||||
return If(
|
||||
self.visit(ite.guard),
|
||||
self.visit(ite.true_b),
|
||||
self.visit(ite.false_b))
|
||||
|
||||
def visit_tuple(self, tup):
|
||||
return Tuple([self.visit(field) for field in tup.fields])
|
||||
|
||||
def visit_constant(self, const):
|
||||
return const
|
||||
|
||||
|
||||
@attr.s
|
||||
class NodeRef(object):
|
||||
"""A reference to a node, used for constructing the graph."""
|
||||
ident = attr.ib()
|
||||
index = attr.ib(default=0)
|
||||
version = attr.ib(default=0)
|
||||
|
||||
def to_json(self):
|
||||
return [self.ident, self.index, self.version]
|
||||
|
||||
|
||||
@attr.s
|
||||
class Node(object):
|
||||
"""The base class for nodes in the TVM runtime system graph input."""
|
||||
name = attr.ib()
|
||||
attrs = attr.ib()
|
||||
is_output = attr.ib()
|
||||
|
||||
def to_json(self):
|
||||
raise Exception("Abstract method, please implement me.")
|
||||
|
||||
|
||||
@attr.s
|
||||
class InputNode(Node):
|
||||
"""An input node in the TVM runtime system graph input."""
|
||||
name = attr.ib()
|
||||
attrs = attr.ib()
|
||||
is_output = attr.ib(default=False)
|
||||
|
||||
def to_json(self):
|
||||
return {
|
||||
"op": "null",
|
||||
"name": self.name,
|
||||
"inputs": []
|
||||
}
|
||||
|
||||
|
||||
@attr.s
|
||||
class OpNode(Node):
|
||||
"""An operator node in the TVM runtime system's graph input."""
|
||||
op_name = attr.ib()
|
||||
inputs = attr.ib()
|
||||
op_attrs = attr.ib()
|
||||
is_output = attr.ib(default=False)
|
||||
|
||||
def to_json(self):
|
||||
attrs = dict.copy(self.op_attrs)
|
||||
# Extend ops with extra info.
|
||||
attrs['func_name'] = self.op_name
|
||||
# When do we flatten?
|
||||
attrs['flatten_data'] = "0"
|
||||
# Fix me!
|
||||
attrs['num_inputs'] = str(len(self.inputs))
|
||||
attrs['num_outputs'] = "1"
|
||||
|
||||
return {
|
||||
"op": "tvm_op",
|
||||
"name": self.name,
|
||||
"attrs": attrs,
|
||||
"inputs": self.inputs
|
||||
}
|
||||
|
||||
|
||||
def shape_to_json(shape):
|
||||
return [sh.value for sh in shape]
|
||||
|
||||
|
||||
def from_tensor(typ):
|
||||
return (typ.dtype, shape_to_json(typ.shape))
|
||||
|
||||
|
||||
class GraphRuntimeCodegen(ExprMutator):
|
||||
"""The compiler from Relay to the TVM runtime system."""
|
||||
nodes = attr.ib()
|
||||
id_map = attr.ib()
|
||||
|
||||
def __init__(self, env):
|
||||
ExprMutator.__init__(self)
|
||||
self.nodes = []
|
||||
self.id_map = {}
|
||||
self.env = env
|
||||
|
||||
def add_node(self, node):
|
||||
"""
|
||||
Add a node to the graph.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
node: Node
|
||||
The node to add to the graph.
|
||||
|
||||
Returns
|
||||
-------
|
||||
node_ref: NodeRef
|
||||
A reference to the node.
|
||||
|
||||
"""
|
||||
self.nodes.append(node)
|
||||
ident = len(self.nodes) - 1
|
||||
return NodeRef(ident)
|
||||
|
||||
def add_binding(self, ident, ref):
|
||||
"""
|
||||
Add a identifier to node mapping.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ident: relay.Var
|
||||
The variable to map
|
||||
|
||||
ref: NodeRef
|
||||
The node the identifier points.
|
||||
"""
|
||||
self.id_map[ident] = ref
|
||||
|
||||
def let_bind(self, ident, node):
|
||||
"""
|
||||
Let bind node to ident.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ident: relay.Var
|
||||
The variable to map.
|
||||
|
||||
ref: NodeRef
|
||||
The node the identifier points.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ref: NodeRef
|
||||
Return reference to the node.
|
||||
"""
|
||||
ref = self.add_node(node)
|
||||
self.add_binding(ident, ref)
|
||||
return ref
|
||||
|
||||
def get_node(self, ref):
|
||||
"""
|
||||
Lookup a node by a node reference.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ref: NodeRef
|
||||
The reference to lookup.
|
||||
|
||||
Returns
|
||||
-------
|
||||
node: Node
|
||||
The node.
|
||||
"""
|
||||
return self.nodes[ref.ident]
|
||||
|
||||
def lookup(self, ident):
|
||||
"""
|
||||
Lookup a node by identifier.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ident: relay.Var
|
||||
The reference to lookup.
|
||||
|
||||
Returns
|
||||
-------
|
||||
node: Node
|
||||
The node.
|
||||
"""
|
||||
return self.id_map[ident]
|
||||
|
||||
def codegen(self, func):
|
||||
"""Compile a single function into a graph.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func: tvm.relay.Expr
|
||||
The function to compile.
|
||||
"""
|
||||
# First we convert all the parameters into input nodes.
|
||||
params = func.params
|
||||
|
||||
for param in params:
|
||||
dtype, shape = from_tensor(param.type_annotation)
|
||||
node = InputNode("{0}".format(param.name_hint), {
|
||||
"shape": shape,
|
||||
"dtype": dtype,
|
||||
})
|
||||
self.let_bind(param, node)
|
||||
|
||||
# Then we compile the body into a graph which can depend
|
||||
# on input variables.
|
||||
output_ref = self.visit(func.body)
|
||||
|
||||
# Finally we retreive return value of program, which will
|
||||
# become our output node.
|
||||
self.get_node(output_ref).is_output = True
|
||||
|
||||
def visit_let(self, let):
|
||||
"""
|
||||
Visit the let binding, by first traversing its value,
|
||||
then setting the metadata on the returned NodeRef.
|
||||
|
||||
Finally visit the body, and return the NodeRef corresponding
|
||||
to it.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
let: tvm.relay.Expr
|
||||
The let binding to transform.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ref: NodeRef
|
||||
The node reference to the body.
|
||||
"""
|
||||
ident = let.var
|
||||
val = let.value
|
||||
body = let.body
|
||||
|
||||
val_ref = self.visit(val)
|
||||
dtype, shape = from_tensor(val.checked_type())
|
||||
val_node = self.get_node(val_ref)
|
||||
val_node.attrs["dtype"] = dtype
|
||||
val_node.attrs["shape"] = shape
|
||||
self.add_binding(ident, val_ref)
|
||||
return self.visit(body)
|
||||
|
||||
def visit_var(self, var):
|
||||
return self.lookup(var)
|
||||
|
||||
def visit_call(self, call):
|
||||
"""Transform a ::tvm.relay.Call into an operator in the TVM graph."""
|
||||
inputs = []
|
||||
for arg in call.args:
|
||||
inputs.append(self.visit(arg).to_json())
|
||||
|
||||
if isinstance(call.op, Op):
|
||||
raise Exception(
|
||||
"Operators should be transformed away; try applying" +
|
||||
"the fuse_ops transformation to the expression.")
|
||||
elif isinstance(call.op, GlobalVar):
|
||||
func = self.env[call.op]
|
||||
elif isinstance(call.op, Function):
|
||||
func = call.op
|
||||
else:
|
||||
raise Exception(
|
||||
"TVM runtime does not support calls to {0}".format(type(call.op)))
|
||||
|
||||
if int(func.attrs.Primitive) != 1:
|
||||
raise Exception(
|
||||
"TVM only support calls to primitive functions " +
|
||||
"(i.e functions composed of fusable operator invocations)")
|
||||
|
||||
op_name = func.attrs.LoweredFunc.name
|
||||
|
||||
attrs = {'shape': shape_to_json(call.checked_type.shape),
|
||||
'dtype': call.checked_type.dtype}
|
||||
call_hash = str(ir_pass.structural_hash(call))
|
||||
op_node = OpNode("call_" + call_hash, attrs, op_name, inputs, {})
|
||||
return self.add_node(op_node)
|
||||
|
||||
def to_json(self):
|
||||
"""
|
||||
Convert the sequence of nodes stored by the compiler into the
|
||||
TVM graph runtime format.
|
||||
|
||||
Returns
|
||||
-------
|
||||
graph_json : str
|
||||
The generated JSON as a string.
|
||||
"""
|
||||
nodes = []
|
||||
# First we compute "nodes" field.
|
||||
for node in self.nodes:
|
||||
nodes.append(node.to_json())
|
||||
|
||||
arg_nodes = []
|
||||
heads = []
|
||||
# Compute "arg_nodes" and "heads" fields.
|
||||
for i, node in enumerate(self.nodes):
|
||||
if isinstance(node, InputNode):
|
||||
arg_nodes.append(i)
|
||||
|
||||
if node.is_output:
|
||||
# Need to fix this.
|
||||
heads.append(NodeRef(i).to_json())
|
||||
|
||||
def compute_node_row_ptr(nodes):
|
||||
"""Calculate the node_row_ptr field by doing a DFS backwards
|
||||
from the output and reversing the path.
|
||||
"""
|
||||
row_ptr = [len(nodes)]
|
||||
discovered = set()
|
||||
stack = []
|
||||
stack.append(len(nodes) - 1)
|
||||
while stack:
|
||||
i = stack.pop()
|
||||
if i not in discovered:
|
||||
discovered.add(i)
|
||||
row_ptr.append(i)
|
||||
node = nodes[i]
|
||||
if isinstance(node, OpNode):
|
||||
for inp in node.inputs:
|
||||
stack.append(inp[0])
|
||||
row_ptr.reverse()
|
||||
return row_ptr
|
||||
|
||||
# Compute "node_row_ptr".
|
||||
node_row_ptr = compute_node_row_ptr(self.nodes)
|
||||
|
||||
# Compute "attrs" field.
|
||||
attrs = {}
|
||||
|
||||
# These fields are mandatory.
|
||||
shapes = []
|
||||
storage_ids = []
|
||||
dtype = []
|
||||
dltype = []
|
||||
|
||||
for i, node in enumerate(self.nodes):
|
||||
storage_ids.append(i)
|
||||
shapes.append(node.attrs['shape'])
|
||||
if node.attrs['dtype'] == 'float32':
|
||||
dtype.append(0)
|
||||
dltype.append('float32')
|
||||
|
||||
attrs["shape"] = ["list_shape", shapes]
|
||||
attrs["storage_id"] = ["list_int", storage_ids]
|
||||
attrs["dtype"] = ["list_int", dtype]
|
||||
attrs["dltype"] = ["list_str", dltype]
|
||||
|
||||
json_dict = {
|
||||
"nodes": nodes,
|
||||
"arg_nodes": arg_nodes,
|
||||
"heads": heads,
|
||||
"attrs": attrs,
|
||||
"node_row_ptr": node_row_ptr
|
||||
}
|
||||
|
||||
return json.dumps(json_dict)
|
||||
|
||||
|
||||
def build(env, func, target=None):
|
||||
"""
|
||||
Compile a single function to the components needed by the
|
||||
TVM RTS.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func: relay.Expr
|
||||
The function to build.
|
||||
|
||||
target: optional str
|
||||
The target platform.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(graph_json, mod, params): tuple of (str, tvm.Module, dict)
|
||||
The outputs of building a Relay function for the TVM runtime.
|
||||
|
||||
"""
|
||||
if target is None:
|
||||
target = 'llvm'
|
||||
|
||||
comp = GraphRuntimeCodegen(env)
|
||||
# NB(@jroesch) This creates lowered functions, and generates names for them
|
||||
#
|
||||
# We need these names to emit the correct graph as these are names of the
|
||||
# functions contained in the module.
|
||||
lowered_ops = ir_pass.lower_ops(env, func)
|
||||
mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target)
|
||||
|
||||
# Therefore the call to compile must come after.
|
||||
comp.codegen(func)
|
||||
graph_json = comp.to_json()
|
||||
return graph_json, mod, None # params currently isn't supported by API
|
||||
|
||||
|
||||
def graph_evaluate(env, func, *args):
|
||||
"""
|
||||
Corresponding function to tvm.relay.eval.evaluate.
|
||||
|
||||
This function evaluates a Relay expression on the
|
||||
TVM graph_runtime.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
env: tvm.relay.Environment
|
||||
The global environment used.
|
||||
|
||||
expr: tvm.relay.Expr
|
||||
The expression to evaluate.
|
||||
|
||||
args: list of tvm.relay.Expr
|
||||
The arguments to apply to the expression, only works
|
||||
if the expression has a function type.
|
||||
|
||||
Returns
|
||||
-------
|
||||
value: tvm.NDArray
|
||||
The output Tensor produced by evaluating the expression.
|
||||
"""
|
||||
func = infer_type(func, env)
|
||||
func = ir_pass.fuse_ops(env, func)
|
||||
func = infer_type(func, env)
|
||||
graph_json, mod, params = build(env, func)
|
||||
assert params is None
|
||||
gmodule = graph_runtime.create(graph_json, mod, cpu(0))
|
||||
# Create map of inputs.
|
||||
inputs = {}
|
||||
for i, arg in enumerate(args):
|
||||
inputs[func.params[i].name_hint] = arg
|
||||
# Set the inputs here.
|
||||
gmodule.set_input(**inputs)
|
||||
# Run the module, and fetch the output.
|
||||
gmodule.run()
|
||||
return gmodule.get_output(0)
|
|
@ -0,0 +1,130 @@
|
|||
#pylint: disable=no-else-return
|
||||
"""An interface to the Realy interpreter."""
|
||||
from __future__ import absolute_import
|
||||
import numpy as np
|
||||
from .. import register_func, nd
|
||||
from .base import NodeBase, register_relay_node
|
||||
from . import _make
|
||||
from . import _interpreter
|
||||
from . import ir_pass
|
||||
from .expr import Call, Constant, GlobalVar
|
||||
from . import const
|
||||
from .._ffi.base import integer_types
|
||||
|
||||
class Value(NodeBase):
|
||||
"""Base class of all values.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@register_func("relay.from_scalar")
|
||||
def from_scalar(i, dtype=None):
|
||||
"""Convert a Python scalar to a Relay scalar."""
|
||||
if dtype is None:
|
||||
if isinstance(i, integer_types):
|
||||
dtype = 'int32'
|
||||
elif isinstance(i, float):
|
||||
dtype = 'float32'
|
||||
elif isinstance(i, bool):
|
||||
dtype = 'uint8'
|
||||
else:
|
||||
raise Exception("unable to infer dtype {0}".format(type(i)))
|
||||
|
||||
return TensorValue(nd.array(np.array(i, dtype=dtype)))
|
||||
|
||||
|
||||
@register_relay_node
|
||||
class TupleValue(Value):
|
||||
def __init__(self, *fields):
|
||||
self.__init_handle_by_constructor__(
|
||||
_make.TupleValue, fields)
|
||||
|
||||
def __getitem__(self, field_no):
|
||||
return self.fields[field_no]
|
||||
|
||||
|
||||
@register_relay_node
|
||||
class Closure(Value):
|
||||
pass
|
||||
|
||||
|
||||
@register_relay_node
|
||||
class TensorValue(Value):
|
||||
"""A Tensor value produced by the evaluator."""
|
||||
|
||||
def __init__(self, data):
|
||||
"""Allocate a new TensorValue and copy the data from `array` into
|
||||
the new array.
|
||||
"""
|
||||
if isinstance(data, np.ndarray):
|
||||
data = nd.array(data)
|
||||
|
||||
self.__init_handle_by_constructor__(
|
||||
_make.TensorValue, data)
|
||||
|
||||
def as_ndarray(self):
|
||||
"""Convert a Relay TensorValue into a tvm.ndarray."""
|
||||
return self.data
|
||||
|
||||
def asnumpy(self):
|
||||
"""Convert a Relay TensorValue into a numpy.ndarray."""
|
||||
return self.data.asnumpy()
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.data == other.data
|
||||
|
||||
|
||||
def _arg_to_ast(arg):
|
||||
if isinstance(arg, TensorValue):
|
||||
return Constant(arg.data)
|
||||
elif isinstance(arg, np.ndarray):
|
||||
return Constant(nd.array(arg))
|
||||
elif isinstance(arg, Constant):
|
||||
return arg
|
||||
else:
|
||||
return const(arg)
|
||||
|
||||
|
||||
def apply_passes(expr, env=None):
|
||||
ck_expr = ir_pass.infer_type(expr, env=env)
|
||||
fused_expr = ir_pass.fuse_ops(env, ck_expr)
|
||||
return fused_expr
|
||||
|
||||
|
||||
def evaluate(env, expr, *args):
|
||||
"""
|
||||
Evaluate a Relay expression on the interpreter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
env: tvm.relay.Environment
|
||||
The global environment used.
|
||||
|
||||
expr: tvm.relay.Expr
|
||||
The expression to evaluate.
|
||||
|
||||
args: list of tvm.relay.Expr
|
||||
The arguments to apply to the expression, only works
|
||||
if the expression has a function type.
|
||||
|
||||
Returns
|
||||
-------
|
||||
value: tvm.relay.eval.Value
|
||||
The value produced by evaluating the expression.
|
||||
"""
|
||||
# assert len(args) == 0
|
||||
relay_args = []
|
||||
for arg in args:
|
||||
relay_args.append(_arg_to_ast(arg))
|
||||
|
||||
# TODO: We need to move this optimization code into the optimizer/pass manager
|
||||
if isinstance(expr, GlobalVar):
|
||||
func = env[expr]
|
||||
func = apply_passes(func, env)
|
||||
env._add(expr, func, True)
|
||||
opt_expr = Call(expr, relay_args)
|
||||
# import pdb; pdb.set_trace()
|
||||
return _interpreter.evaluate(env, opt_expr)
|
||||
else:
|
||||
expr = Call(expr, relay_args)
|
||||
opt_expr = apply_passes(expr, env)
|
||||
return _interpreter.evaluate(env, opt_expr)
|
|
@ -240,3 +240,9 @@ def structural_hash(value):
|
|||
msg = ("found value of type {0} expected" +
|
||||
"relay.Expr or relay.Type").format(type(value))
|
||||
raise TypeError(msg)
|
||||
|
||||
def fuse_ops(expr, env):
|
||||
return _ir_pass.FuseOps(env, expr)
|
||||
|
||||
def lower_ops(env, expr, target='llvm'):
|
||||
return _ir_pass.LowerOps(env, expr, target)
|
||||
|
|
|
@ -1,2 +1,49 @@
|
|||
#pylint: disable=invalid-name
|
||||
#pylint: disable=invalid-name, unused-argument
|
||||
"""Backend compiler related feature registration"""
|
||||
import tvm
|
||||
import topi
|
||||
from . import register
|
||||
|
||||
def add_compute(attrs, inputs, output_type, target):
|
||||
assert len(inputs) == 2
|
||||
return [topi.add(inputs[0], inputs[1])]
|
||||
|
||||
def add_schedule(outputs, target):
|
||||
assert len(outputs) == 1
|
||||
return tvm.create_schedule(outputs[0].op)
|
||||
|
||||
register("add", "FTVMCompute", add_compute)
|
||||
register("add", "FTVMSchedule", add_schedule)
|
||||
|
||||
def subtract_compute(attrs, inputs, output_type, target):
|
||||
assert len(inputs) == 2
|
||||
return [topi.subtract(inputs[0], inputs[1])]
|
||||
|
||||
def subtract_schedule(outputs, target):
|
||||
assert len(outputs) == 1
|
||||
return tvm.create_schedule(outputs[0].op)
|
||||
|
||||
register("subtract", "FTVMCompute", subtract_compute)
|
||||
register("subtract", "FTVMSchedule", subtract_schedule)
|
||||
|
||||
def multiply_compute(attrs, inputs, output_type, target):
|
||||
assert len(inputs) == 2
|
||||
return [topi.multiply(inputs[0], inputs[1])]
|
||||
|
||||
def multiply_schedule(outputs, target):
|
||||
assert len(outputs) == 1
|
||||
return tvm.create_schedule(outputs[0].op)
|
||||
|
||||
register("multiply", "FTVMCompute", multiply_compute)
|
||||
register("multiply", "FTVMSchedule", multiply_schedule)
|
||||
|
||||
def equal_compute(attrs, inputs, output_type, target):
|
||||
assert len(inputs) == 2
|
||||
return [topi.equal(inputs[0], inputs[1])]
|
||||
|
||||
def equal_schedule(outputs, target):
|
||||
assert len(outputs) == 1
|
||||
return tvm.create_schedule(outputs[0].op)
|
||||
|
||||
register("equal", "FTVMCompute", equal_compute)
|
||||
register("equal", "FTVMSchedule", equal_schedule)
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
#pylint: disable=invalid-name, unused-argument
|
||||
"""Backend compiler related feature registration"""
|
||||
import tvm
|
||||
import topi
|
||||
from .. import register
|
||||
|
||||
def dense_compiler(attrs, inputs, output_type):
|
||||
assert len(inputs) == 2
|
||||
return [topi.nn.dense(inputs[0], inputs[1])]
|
||||
|
||||
def dense_schedule(outputs, target):
|
||||
assert len(outputs) == 1
|
||||
return tvm.create_schedule(outputs[0].op)
|
||||
|
||||
register("nn.dense", "FTVMCompute", dense_compiler)
|
||||
register("nn.dense", "FTVMSchedule", dense_schedule)
|
|
@ -3,7 +3,8 @@ from ..._ffi.function import _init_api
|
|||
|
||||
from ..base import register_relay_node
|
||||
from ..expr import Expr
|
||||
|
||||
from ...api import register_func
|
||||
from ...build_module import lower, build
|
||||
|
||||
@register_relay_node
|
||||
class Op(Expr):
|
||||
|
@ -75,3 +76,11 @@ def register(op_name, attr_key, value=None, level=10):
|
|||
|
||||
|
||||
_init_api("relay.op", __name__)
|
||||
|
||||
@register_func("relay.op.compiler._lower")
|
||||
def _lower(name, schedule, inputs, outputs):
|
||||
return lower(schedule, list(inputs) + list(outputs), name=name)
|
||||
|
||||
@register_func("relay.op.compiler._build")
|
||||
def _build(lowered_funcs):
|
||||
return build(lowered_funcs, target="llvm")
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
"""
|
||||
a simple multilayer perceptron
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from tvm import relay
|
||||
from .init import create_workload
|
||||
|
||||
|
|
|
@ -0,0 +1,432 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file src/tvm/relay/interpreter.cc
|
||||
* \brief An interpreter for the Relay IR.
|
||||
*/
|
||||
|
||||
#include <tvm/codegen.h>
|
||||
#include <tvm/packed_func_ext.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/interpreter.h>
|
||||
#include <tvm/relay/logging.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/build_module.h>
|
||||
#include "./ir/type_functor.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
using namespace runtime;
|
||||
|
||||
inline const PackedFunc& GetPackedFunc(const std::string& name) {
|
||||
const PackedFunc* pf = tvm::runtime::Registry::Get(name);
|
||||
CHECK(pf != nullptr) << "Cannot find function " << name << " in registry";
|
||||
return *pf;
|
||||
}
|
||||
|
||||
/* Value Implementation */
|
||||
Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) {
|
||||
NodePtr<ClosureNode> n = make_node<ClosureNode>();
|
||||
n->env = std::move(env);
|
||||
n->func = std::move(func);
|
||||
return Closure(n);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._make.Closure")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = ClosureNode::make(args[0], args[1]);
|
||||
});
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||
.set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) {
|
||||
p->stream << "ClosureNode(" << node->func << ")";
|
||||
});
|
||||
|
||||
TupleValue TupleValueNode::make(tvm::Array<Value> value) {
|
||||
NodePtr<TupleValueNode> n = make_node<TupleValueNode>();
|
||||
n->fields = value;
|
||||
return TupleValue(n);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._make.TupleValue")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = TupleValueNode::make(args[0]);
|
||||
});
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||
.set_dispatch<TupleValueNode>([](const TupleValueNode* node,
|
||||
tvm::IRPrinter* p) {
|
||||
p->stream << "TupleValueNode(" << node->fields << ")";
|
||||
});
|
||||
|
||||
TensorValue TensorValueNode::make(runtime::NDArray data) {
|
||||
NodePtr<TensorValueNode> n = make_node<TensorValueNode>();
|
||||
n->data = std::move(data);
|
||||
return TensorValue(n);
|
||||
}
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||
.set_dispatch<TensorValueNode>([](const TensorValueNode* node,
|
||||
tvm::IRPrinter* p) {
|
||||
auto to_str = GetPackedFunc("relay._tensor_value_repr");
|
||||
std::string data_str = to_str(GetRef<TensorValue>(node));
|
||||
p->stream << "TensorValueNode(" << data_str << ")";
|
||||
});
|
||||
|
||||
TensorValue TensorValueNode::FromType(const Type& t) {
|
||||
if (auto tt_node = t.as<TensorTypeNode>()) {
|
||||
std::vector<int64_t> dims;
|
||||
|
||||
for (auto dim : tt_node->shape) {
|
||||
auto int_node = dim.as<tvm::ir::IntImm>();
|
||||
CHECK(int_node) << "expected concrete dimensions";
|
||||
dims.push_back(int_node->value);
|
||||
}
|
||||
|
||||
DLDataType dtype;
|
||||
DLContext context;
|
||||
|
||||
switch (tt_node->dtype.code()) {
|
||||
case halideir_type_int:
|
||||
dtype.code = kDLInt;
|
||||
break;
|
||||
case halideir_type_uint:
|
||||
dtype.code = kDLUInt;
|
||||
break;
|
||||
case halideir_type_float:
|
||||
dtype.code = kDLFloat;
|
||||
break;
|
||||
default:
|
||||
throw dmlc::Error("can not convert HalideIR type into DLTensor dtype");
|
||||
}
|
||||
|
||||
dtype.bits = tt_node->dtype.bits();
|
||||
dtype.lanes = tt_node->dtype.lanes();
|
||||
|
||||
// TODO(@jroesch): Is this the right place to place the tensor?
|
||||
context.device_type = DLDeviceType::kDLCPU;
|
||||
context.device_id = 0;
|
||||
runtime::NDArray data = NDArray::Empty(dims, dtype, context);
|
||||
return TensorValueNode::make(data);
|
||||
} else {
|
||||
LOG(FATAL) << "expected a tensor type";
|
||||
return TensorValue();
|
||||
}
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._make.TensorValue")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
runtime::NDArray data = args[0];
|
||||
*ret = TensorValueNode::make(data);
|
||||
});
|
||||
|
||||
/* Evaluator Implementation. */
|
||||
struct EvalError : dmlc::Error {
|
||||
explicit EvalError(const std::string& msg) : Error(msg) {}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief A stack frame in the Relay interpreter.
|
||||
*
|
||||
* Contains a mapping from relay::Var to relay::Value.
|
||||
*/
|
||||
struct Frame {
|
||||
/*! \brief The set of local variables and arguments for the frame. */
|
||||
tvm::Map<Var, Value> locals;
|
||||
|
||||
explicit Frame(tvm::Map<Var, Value> locals) : locals(locals) {}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief The call stack in the Relay interpreter.
|
||||
*
|
||||
* Contains a stack of frames; each corresponding to
|
||||
* a function call.
|
||||
*/
|
||||
struct Stack {
|
||||
/*! \brief The stack frames. */
|
||||
std::vector<Frame> frames;
|
||||
Stack() : frames() { frames.push_back(Frame({})); }
|
||||
|
||||
Frame& current_frame() { return frames.back(); }
|
||||
|
||||
Value Lookup(const Var& local) {
|
||||
for (auto frame = frames.rbegin(); frame != frames.rend(); frame++) {
|
||||
auto elem = frame->locals.find(local);
|
||||
if (elem != frame->locals.end()) {
|
||||
return (*elem).second;
|
||||
}
|
||||
}
|
||||
|
||||
LOG(FATAL) << "could not find variable binding for " << local
|
||||
<< "address= " << local.operator->();
|
||||
return Value();
|
||||
}
|
||||
/*!
|
||||
* A wrapper around Frame to add RAII semantics to pushing and popping
|
||||
* stack frames.
|
||||
*/
|
||||
struct LocalFrame {
|
||||
Stack& st;
|
||||
explicit LocalFrame(Stack& st, const Frame& fr) : st(st) {
|
||||
st.frames.push_back(fr);
|
||||
}
|
||||
~LocalFrame() { st.frames.pop_back(); }
|
||||
};
|
||||
};
|
||||
|
||||
/*! \brief The equal comparator for expressions. */
|
||||
struct ExprEqual {
|
||||
bool operator()(const Expr& a, const Expr& b) const {
|
||||
return AlphaEqual(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
struct Interpreter : ExprFunctor<Value(const Expr& n)> {
|
||||
Environment env;
|
||||
Stack stack;
|
||||
using JitKey = Function;
|
||||
|
||||
using OpMap = std::unordered_map<JitKey, PackedFunc, StructuralHash, ExprEqual>;
|
||||
|
||||
OpMap operator_map_;
|
||||
|
||||
template <typename T>
|
||||
T with_frame(const Frame& fr, const std::function<T()>& f) {
|
||||
Stack::LocalFrame lf(stack, fr);
|
||||
return f();
|
||||
}
|
||||
|
||||
Interpreter(Environment env) : env(env), operator_map_() {}
|
||||
Interpreter(Environment env, OpMap operator_map) : env(env), operator_map_(operator_map) {}
|
||||
|
||||
void extend(const Var& id, Value v) {
|
||||
this->stack.current_frame().locals.Set(id, v);
|
||||
}
|
||||
|
||||
inline Value Lookup(const Var& local) {
|
||||
return this->stack.Lookup(local);
|
||||
}
|
||||
|
||||
Value Eval(const Expr& expr) {
|
||||
return (*this)(expr);
|
||||
}
|
||||
|
||||
Value VisitExpr(const Expr& expr) override {
|
||||
RELAY_LOG(INFO) << "VisitExpr: " << expr << std::endl;
|
||||
auto ret = ExprFunctor<Value(const Expr& n)>::VisitExpr(expr);
|
||||
return ret;
|
||||
}
|
||||
|
||||
Value VisitExpr_(const VarNode* var_node) override {
|
||||
return Lookup(GetRef<Var>(var_node));
|
||||
}
|
||||
|
||||
Value VisitExpr_(const GlobalVarNode* op) override {
|
||||
return Eval(this->env->Lookup(GetRef<GlobalVar>(op)));
|
||||
}
|
||||
|
||||
Value VisitExpr_(const OpNode* id) override {
|
||||
// TODO(@jroesch): Eta-expand and return in this case.
|
||||
throw EvalError(
|
||||
"internal error, need to wrap intrinsic into call synthetic call node "
|
||||
"in "
|
||||
"this case, eta expand");
|
||||
}
|
||||
|
||||
Value VisitExpr_(const ConstantNode* op) override {
|
||||
return TensorValueNode::make(op->data);
|
||||
}
|
||||
|
||||
Value VisitExpr_(const TupleNode* op) override {
|
||||
std::vector<Value> values;
|
||||
|
||||
for (const auto& field : op->fields) {
|
||||
Value field_value = Eval(field);
|
||||
values.push_back(field_value);
|
||||
}
|
||||
|
||||
return TupleValueNode::make(values);
|
||||
}
|
||||
|
||||
Value VisitExpr_(const FunctionNode* func_node) override {
|
||||
auto func = GetRef<Function>(func_node);
|
||||
tvm::Map<Var, Value> captured_env;
|
||||
Array<Var> free_vars = FreeVars(func);
|
||||
|
||||
for (const auto& var : free_vars) {
|
||||
captured_env.Set(var, Eval(var));
|
||||
}
|
||||
|
||||
return ClosureNode::make(captured_env, func);
|
||||
}
|
||||
|
||||
inline Value InvokeCompiledOp(PackedFunc func, const Array<Value>& args,
|
||||
Type ret_type) {
|
||||
// Marshal the arguments.
|
||||
auto arg_len = args.size() + 1;
|
||||
std::vector<TVMValue> values(arg_len);
|
||||
std::vector<int> codes(arg_len);
|
||||
TVMArgsSetter setter(values.data(), codes.data());
|
||||
TVMRetValue ret;
|
||||
|
||||
// We need real type information to properly allocate the structure.
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (const TensorValueNode* tv = args[i].as<TensorValueNode>()) {
|
||||
setter(i, tv->data);
|
||||
}
|
||||
}
|
||||
|
||||
// TVM's calling convention is that the final argument is the output
|
||||
// buffer. To preserve the illusion of being a functional language
|
||||
// we need to allocate space for the output buffer based on the
|
||||
// return type.
|
||||
CHECK(ret_type.as<TensorTypeNode>());
|
||||
|
||||
auto out_tensor = TensorValueNode::FromType(ret_type);
|
||||
|
||||
setter(arg_len - 1, out_tensor->data);
|
||||
func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &ret);
|
||||
return out_tensor;
|
||||
}
|
||||
|
||||
Value Invoke(const Closure& closure, const tvm::Array<Value>& args) {
|
||||
// Get a reference to the function inside the closure.
|
||||
auto func = closure->func;
|
||||
auto compiled = operator_map_.find(func);
|
||||
tvm::Array<Function> funcs;
|
||||
for (auto op : operator_map_) {
|
||||
funcs.push_back(op.first);
|
||||
}
|
||||
|
||||
// This case we know we have precompiled the operator.
|
||||
if (compiled != operator_map_.end()) {
|
||||
auto func_ty = func->func_type_annotation();
|
||||
return InvokeCompiledOp(compiled->second, args, func_ty->ret_type);
|
||||
}
|
||||
|
||||
// Allocate a frame with the parameters and free variables.
|
||||
tvm::Map<Var, Value> locals;
|
||||
|
||||
CHECK_EQ(func->params.size(), args.size());
|
||||
|
||||
for (size_t i = 0; i < func->params.size(); i++) {
|
||||
CHECK_EQ(locals.count(func->params[i]), 0);
|
||||
locals.Set(func->params[i], args[i]);
|
||||
}
|
||||
|
||||
// Add the var to value mappings from the Closure's environment.
|
||||
for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
|
||||
CHECK_EQ(locals.count((*it).first), 0);
|
||||
locals.Set((*it).first, (*it).second);
|
||||
}
|
||||
|
||||
return with_frame<Value>(Frame(locals), [&]() { return Eval(func->body); });
|
||||
}
|
||||
|
||||
Value VisitExpr_(const CallNode* call) override {
|
||||
tvm::Array<Value> args;
|
||||
for (auto arg : call->args) {
|
||||
args.push_back(Eval(arg));
|
||||
}
|
||||
|
||||
// We should not find operators after running fusion,
|
||||
// and operator lowering.
|
||||
//
|
||||
// We have some functions cotaining chunks of operators
|
||||
// which will be loaded into operator map.
|
||||
if (auto op_node = call->op.as<OpNode>()) {
|
||||
LOG(FATAL) << "found " << op_node->name
|
||||
<< "; operators should be removed by future passes; try "
|
||||
"fusing and lowering";
|
||||
}
|
||||
|
||||
// Now we just evaluate and expect to find a closure.
|
||||
Value fn_val = Eval(call->op);
|
||||
if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
|
||||
auto closure = GetRef<Closure>(closure_node);
|
||||
return this->Invoke(closure, args);
|
||||
} else {
|
||||
throw EvalError(
|
||||
"internal error: type error, expected function value in the call "
|
||||
"position");
|
||||
}
|
||||
}
|
||||
|
||||
Value VisitExpr_(const LetNode* op) override {
|
||||
auto value = Eval(op->value);
|
||||
this->extend(op->var, value);
|
||||
return Eval(op->body);
|
||||
}
|
||||
|
||||
Value VisitExpr_(const TupleGetItemNode* op) override {
|
||||
Value val = Eval(op->tuple);
|
||||
auto product_node = val.as<TupleValueNode>();
|
||||
CHECK(product_node)
|
||||
<< "interal error: when evaluating TupleGetItem expected a tuple value";
|
||||
CHECK_LT(static_cast<size_t>(op->index), product_node->fields.size())
|
||||
<< "internal error: index out of bounds";
|
||||
return product_node->fields[op->index];
|
||||
}
|
||||
|
||||
Value VisitExpr_(const IfNode* op) override {
|
||||
Value v = Eval(op->cond);
|
||||
if (const TensorValueNode* bv = v.as<TensorValueNode>()) {
|
||||
// TODO(@jroesch, @MK): Refactor code into helper from DCE.
|
||||
if (reinterpret_cast<uint8_t*>(bv->data->data)[0]) {
|
||||
return Eval(op->true_branch);
|
||||
} else {
|
||||
return Eval(op->false_branch);
|
||||
}
|
||||
} else {
|
||||
throw EvalError("type error, type system should have caught this");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
|
||||
Interpreter::OpMap op_map;
|
||||
auto lowered_ops = LowerOps(env, e);
|
||||
RELAY_LOG(INFO) << "LoweredFuncs: " << lowered_ops << std::endl;
|
||||
if (lowered_ops.size()) {
|
||||
const PackedFunc* fbuild_ptr = Registry::Get("relay.op.compiler._build");
|
||||
CHECK(fbuild_ptr) << "Could not find registered function: relay.op.compiler._build";
|
||||
auto fbuild = *fbuild_ptr;
|
||||
|
||||
// Collect the set of lowered functions to build a module.
|
||||
Array<LoweredFunc> lowered_funcs;
|
||||
for (auto lop : lowered_ops) {
|
||||
lowered_funcs.push_back(lop->lowered_func);
|
||||
}
|
||||
|
||||
Module module = fbuild(lowered_funcs);
|
||||
|
||||
// Loop over the lowered operations to map them into the operator map.
|
||||
for (auto lop : lowered_ops) {
|
||||
Function func = lop->func;
|
||||
LoweredFunc lf = lop->lowered_func;
|
||||
|
||||
RELAY_LOG(INFO) << "LoweredFunc: " << lf->name << std::endl;
|
||||
auto op_impl = module.GetFunction(lf->name);
|
||||
op_map.insert({func, op_impl});
|
||||
}
|
||||
}
|
||||
|
||||
return op_map;
|
||||
}
|
||||
|
||||
Value Evaluate(Environment env, Expr e) {
|
||||
auto op_map = CompileOperators(env, e);
|
||||
Interpreter interp(env, op_map);
|
||||
return interp.Eval(e);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._interpreter.evaluate")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
Environment env = args[0];
|
||||
Expr expr = args[1];
|
||||
*ret = Evaluate(env, expr);
|
||||
});
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
|
@ -66,3 +66,5 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
|||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
||||
|
||||
|
|
|
@ -49,9 +49,16 @@ void EnvironmentNode::Add(const GlobalVar& var,
|
|||
<< "Environment#update changes type, not possible in this mode.";
|
||||
}
|
||||
this->functions.Set(var, checked_func);
|
||||
// set gloval var map
|
||||
|
||||
auto it = global_var_map_.find(var->name_hint);
|
||||
if (it != global_var_map_.end()) {
|
||||
CHECK_EQ((*it).second, var);
|
||||
} else {
|
||||
// set global var map
|
||||
CHECK(!global_var_map_.count(var->name_hint))
|
||||
<< "Duplicate global function name " << var->name_hint;
|
||||
}
|
||||
|
||||
global_var_map_.Set(var->name_hint, var);
|
||||
}
|
||||
|
||||
|
@ -94,7 +101,7 @@ TVM_REGISTER_API("relay._make.Environment")
|
|||
TVM_REGISTER_API("relay._env.Environment_Add")
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
Environment env = args[0];
|
||||
env->Add(args[1], args[2], false);
|
||||
env->Add(args[1], args[2], args[3]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._env.Environment_GetGlobalVar")
|
||||
|
|
|
@ -26,7 +26,10 @@ TVM_REGISTER_API("relay._make.Constant")
|
|||
|
||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||
.set_dispatch<ConstantNode>([](const ConstantNode* node, tvm::IRPrinter* p) {
|
||||
p->stream << "Constant(TODO)";
|
||||
const PackedFunc* fprint = Registry::Get("relay._constant_repr");
|
||||
CHECK(fprint) << "unable to find printing function for constants";
|
||||
std::string data = (*fprint)(GetRef<Constant>(node));
|
||||
p->stream << "Constant(" << data << ")";
|
||||
});
|
||||
|
||||
TensorType ConstantNode::tensor_type() const {
|
||||
|
@ -104,12 +107,14 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
|||
Function FunctionNode::make(tvm::Array<Var> params,
|
||||
Expr body,
|
||||
Type ret_type,
|
||||
tvm::Array<TypeVar> type_params) {
|
||||
tvm::Array<TypeVar> type_params,
|
||||
tvm::Attrs attrs) {
|
||||
NodePtr<FunctionNode> n = make_node<FunctionNode>();
|
||||
n->params = std::move(params);
|
||||
n->body = std::move(body);
|
||||
n->ret_type = std::move(ret_type);
|
||||
n->type_params = std::move(type_params);
|
||||
n->attrs = std::move(attrs);
|
||||
return Function(n);
|
||||
}
|
||||
|
||||
|
@ -121,6 +126,39 @@ FuncType FunctionNode::func_type_annotation() const {
|
|||
return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
|
||||
}
|
||||
|
||||
NodeRef FunctionGetAttr(const Function& func, const std::string& key) {
|
||||
if (!func->attrs.defined()) { return NodeRef(); }
|
||||
|
||||
const DictAttrsNode* dict_attrs = func->attrs.as<DictAttrsNode>();
|
||||
CHECK(dict_attrs);
|
||||
auto it = dict_attrs->dict.find(key);
|
||||
if (it != dict_attrs->dict.end()) {
|
||||
return (*it).second;
|
||||
} else {
|
||||
return NodeRef();
|
||||
}
|
||||
}
|
||||
|
||||
Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data) {
|
||||
const DictAttrsNode* dattrs = func->attrs.as<DictAttrsNode>();
|
||||
Attrs func_attrs;
|
||||
if (dattrs) {
|
||||
Map<std::string, NodeRef> dict = dattrs->dict;
|
||||
dict.Set(key, data);
|
||||
func_attrs = DictAttrsNode::make(dict);
|
||||
} else {
|
||||
Map<std::string, NodeRef> dict = {{key, data}};
|
||||
func_attrs = DictAttrsNode::make(dict);
|
||||
}
|
||||
|
||||
return FunctionNode::make(
|
||||
func->params,
|
||||
func->body,
|
||||
func->ret_type,
|
||||
func->type_params,
|
||||
func_attrs);
|
||||
}
|
||||
|
||||
TVM_REGISTER_NODE_TYPE(FunctionNode);
|
||||
|
||||
TVM_REGISTER_API("relay._make.Function")
|
||||
|
@ -132,7 +170,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
|||
.set_dispatch<FunctionNode>([](const FunctionNode* node,
|
||||
tvm::IRPrinter* p) {
|
||||
p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
|
||||
<< ", " << node->body << ", " << node->type_params << ")";
|
||||
<< ", " << node->body << ", " << node->type_params << ", "
|
||||
<< node->attrs << ")";
|
||||
});
|
||||
|
||||
Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
|
||||
|
|
|
@ -92,7 +92,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
|
|||
body.same_as(op->body)) {
|
||||
return GetRef<Expr>(op);
|
||||
} else {
|
||||
return FunctionNode::make(params, body, ret_type, ty_params);
|
||||
return FunctionNode::make(params, body, ret_type, ty_params, op->attrs);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -198,6 +198,7 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) {
|
|||
|
||||
void ExprVisitor::VisitExpr_(const CallNode* op) {
|
||||
this->VisitExpr(op->op);
|
||||
|
||||
for (auto ty_arg : op->type_args) {
|
||||
this->VisitType(ty_arg);
|
||||
}
|
||||
|
|
|
@ -285,11 +285,11 @@ class RelayHashHandler:
|
|||
int var_counter = 0;
|
||||
};
|
||||
|
||||
size_t StructuralHash(const Type& type) {
|
||||
size_t StructuralHash::operator()(const Type& type) const {
|
||||
return RelayHashHandler().TypeHash(type);
|
||||
}
|
||||
|
||||
size_t StructuralHash(const Expr& expr) {
|
||||
size_t StructuralHash::operator()(const Expr& expr) const {
|
||||
return RelayHashHandler().ExprHash(expr);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
*
|
||||
* \file src/tvm/relay/pass/fuse_ops.cc
|
||||
*
|
||||
* \brief Fuse Relay eligble sequences of Relay operators into a single one.
|
||||
*
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/runtime/module.h>
|
||||
#include <tvm/lowered_func.h>
|
||||
#include <tvm/operation.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/logging.h>
|
||||
#include "../ir/type_functor.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
using namespace runtime;
|
||||
|
||||
struct AbstractFusableOps : ExprMutator {
|
||||
Environment env;
|
||||
Array<GlobalVar> fusable_funcs;
|
||||
int counter = 0;
|
||||
size_t expr_hash;
|
||||
|
||||
AbstractFusableOps(Environment env, size_t expr_hash) : env(env), expr_hash(expr_hash) {}
|
||||
|
||||
Expr VisitExpr_(const CallNode* call) {
|
||||
if (auto op_node = call->op.as<OpNode>()) {
|
||||
// Placeholder fusion algorithm which abstracts
|
||||
// single definitions into functions only.
|
||||
Array<Var> params;
|
||||
Array<Expr> inner_args;
|
||||
Array<Expr> args;
|
||||
|
||||
int param_number = 0;
|
||||
for (auto arg : call->args) {
|
||||
auto name = std::string("p") + std::to_string(param_number++);
|
||||
auto type = arg->checked_type();
|
||||
auto var = VarNode::make(name, type);
|
||||
params.push_back(var);
|
||||
inner_args.push_back(var);
|
||||
args.push_back(VisitExpr(arg));
|
||||
}
|
||||
|
||||
auto body = CallNode::make(call->op, inner_args, call->attrs);
|
||||
auto func = FunctionNode::make(params, body, call->checked_type(), {});
|
||||
func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
|
||||
std::string func_name = "fused_";
|
||||
func_name += op_node->name;
|
||||
func_name += "_";
|
||||
func_name += std::to_string(counter++);
|
||||
func_name += "_";
|
||||
func_name += std::to_string(expr_hash);
|
||||
auto gv = GlobalVarNode::make(func_name);
|
||||
env->Add(gv, func);
|
||||
fusable_funcs.push_back(gv);
|
||||
return CallNode::make(gv, args, Attrs());
|
||||
} else {
|
||||
return ExprMutator::VisitExpr_(call);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Expr FuseOps(const Environment& env, const Expr& e) {
|
||||
// First we convert all chains of fusable ops into
|
||||
// abstracted functions which we mark as primtive
|
||||
// then we convert these primtive functions into
|
||||
// new operators.
|
||||
auto abstract = AbstractFusableOps(env, StructuralHash()(e));
|
||||
auto abstracted_e = abstract.VisitExpr(e);
|
||||
RELAY_LOG(INFO) << "FuseOps: before=" << e
|
||||
<< "Fuse: after=" << abstracted_e;
|
||||
return abstracted_e;
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.FuseOps")
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = FuseOps(args[1], args[0]);
|
||||
});
|
||||
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
|
@ -0,0 +1,222 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
*
|
||||
* \file src/tvm/relay/pass/lower_ops.cc
|
||||
*
|
||||
* \brief Lower a Relay program to set of TVM operators.
|
||||
*
|
||||
*/
|
||||
#include <tvm/lowered_func.h>
|
||||
#include <tvm/operation.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/logging.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/runtime/module.h>
|
||||
#include <tvm/relay/build_module.h>
|
||||
#include "../ir/type_functor.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
using namespace runtime;
|
||||
|
||||
LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) {
|
||||
auto node = make_node<LoweredOpNode>();
|
||||
node->func = func;
|
||||
node->lowered_func = lowered_func;
|
||||
return LoweredOp(node);
|
||||
}
|
||||
|
||||
struct AbstractLocalFunctions : ExprMutator {
|
||||
Environment env;
|
||||
size_t expr_hash;
|
||||
int counter = 0;
|
||||
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
|
||||
explicit AbstractLocalFunctions(Environment env)
|
||||
: env(env), expr_hash(0), counter(0), visited_funcs() {}
|
||||
|
||||
Expr Abstract(const Expr& e) {
|
||||
expr_hash = StructuralHash()(e);
|
||||
return VisitExpr(e);
|
||||
}
|
||||
|
||||
Expr VisitExpr_(const GlobalVarNode* gvar_node) final {
|
||||
auto gvar = GetRef<GlobalVar>(gvar_node);
|
||||
auto it = visited_funcs.find(gvar);
|
||||
if (it == visited_funcs.end()) {
|
||||
auto func = env->Lookup(gvar);
|
||||
visited_funcs.insert(gvar);
|
||||
auto new_func = FunctionNode::make(
|
||||
func->params,
|
||||
VisitExpr(func->body),
|
||||
func->ret_type,
|
||||
func->type_params,
|
||||
func->attrs);
|
||||
env->Update(gvar, new_func);
|
||||
}
|
||||
return gvar;
|
||||
}
|
||||
|
||||
Expr VisitExpr_(const FunctionNode* func_node) final {
|
||||
Function func = GetRef<Function>(func_node);
|
||||
auto free_vars = FreeVars(func);
|
||||
Array<Var> params;
|
||||
for (auto free_var : free_vars) {
|
||||
auto var = VarNode::make("free_var", free_var->checked_type());
|
||||
params.push_back(var);
|
||||
}
|
||||
std::string abs_func = "abstracted_func_";
|
||||
abs_func += std::to_string(counter++);
|
||||
abs_func += std::to_string(expr_hash);
|
||||
auto gv = GlobalVarNode::make(abs_func);
|
||||
auto lifted_func = FunctionNode::make(params, func, Type(), {}, {});
|
||||
env->Add(gv, lifted_func);
|
||||
Array<Expr> args;
|
||||
for (auto free_var : free_vars) {
|
||||
args.push_back(free_var);
|
||||
}
|
||||
return CallNode::make(gv, args, {});
|
||||
}
|
||||
};
|
||||
|
||||
struct LiveFunctions : ExprVisitor {
|
||||
Environment env;
|
||||
explicit LiveFunctions(Environment env) : env(env), global_funcs() {}
|
||||
|
||||
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
|
||||
std::unordered_set<GlobalVar, NodeHash, NodeEqual> global_funcs;
|
||||
|
||||
void Live(const Expr& e) {
|
||||
CHECK(!e.as<FunctionNode>())
|
||||
<< "functions should of been transformed away by previous pass";
|
||||
VisitExpr(e);
|
||||
}
|
||||
|
||||
void VisitExpr_(const FunctionNode* func_node) {
|
||||
LOG(FATAL) << "functions should of been transformed away by previous pass";
|
||||
}
|
||||
|
||||
void VisitExpr_(const GlobalVarNode* var_node) final {
|
||||
GlobalVar var = GetRef<GlobalVar>(var_node);
|
||||
auto it = visited_funcs.find(var);
|
||||
if (it == visited_funcs.end()) {
|
||||
auto func = env->Lookup(var);
|
||||
visited_funcs.insert(var);
|
||||
// The last pass has trasnformed functions of the form:
|
||||
//
|
||||
// let x = fn (p_1, ..., p_n) { ... };
|
||||
// ...
|
||||
//
|
||||
// into, a top-level declaration:
|
||||
//
|
||||
// def abs_f(fv_1, ..., fv_n) {
|
||||
// return (fn (p_1...,p_N) { ... };)
|
||||
// }
|
||||
//
|
||||
// and:
|
||||
//
|
||||
// let x = abs_f(fv_1, ... fv_n);
|
||||
//
|
||||
// The only other case we can handle is
|
||||
//
|
||||
// fn foo(...) { body }
|
||||
//
|
||||
// We just search through the body in this case.
|
||||
if (auto inner_func = func->body.as<FunctionNode>()) {
|
||||
return VisitExpr(inner_func->body);
|
||||
} else {
|
||||
return VisitExpr(func->body);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void VisitExpr_(const CallNode* call) final {
|
||||
RELAY_LOG(INFO) << "LiveOps: CallNode=" << GetRef<Call>(call);
|
||||
if (auto gv_node = call->op.as<GlobalVarNode>()) {
|
||||
GlobalVar gvar = GetRef<GlobalVar>(gv_node);
|
||||
Function func = env->Lookup(gvar);
|
||||
|
||||
auto attr = FunctionGetAttr(func, "Primitive");
|
||||
|
||||
if (attr.defined() && Downcast<Integer>(attr)->value == 1) {
|
||||
global_funcs.insert(gvar);
|
||||
} else {
|
||||
VisitExpr(gvar);
|
||||
}
|
||||
|
||||
// Finally we need to ensure to visit all the args no matter what.
|
||||
for (auto arg : call->args) {
|
||||
VisitExpr(arg);
|
||||
}
|
||||
} else {
|
||||
return ExprVisitor::VisitExpr_(call);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using FCompute = TypedPackedFunc<Array<Tensor>(
|
||||
const Attrs&, const Array<Tensor>&, Type, std::string)>;
|
||||
using FSchedule = TypedPackedFunc<Schedule(const Array<Tensor>&, std::string)>;
|
||||
|
||||
/*! \brief Return the set of operators in their TVM format. */
|
||||
Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
|
||||
const std::string& target) {
|
||||
RELAY_LOG(INFO) << "LowerOps: e=" << e;
|
||||
auto flower_ptr = Registry::Get("relay.op.compiler._lower");
|
||||
CHECK(flower_ptr);
|
||||
PackedFunc flower = *flower_ptr;
|
||||
|
||||
auto abstracted_e = AbstractLocalFunctions(env).Abstract(e);
|
||||
auto live_funcs = LiveFunctions(env);
|
||||
live_funcs.VisitExpr(abstracted_e);
|
||||
|
||||
auto schedule_reg = Op::GetAttr<FSchedule>("FTVMSchedule");
|
||||
auto compute_reg = Op::GetAttr<FCompute>("FTVMCompute");
|
||||
|
||||
Array<LoweredOp> lowered_funcs;
|
||||
|
||||
for (auto func_name : live_funcs.global_funcs) {
|
||||
auto func = env->Lookup(func_name);
|
||||
auto call = Downcast<Call>(func->body);
|
||||
auto op_node = call->op.as<OpNode>();
|
||||
CHECK(op_node) << "violated invariant that primtiive calls contain a single op call";
|
||||
auto op = GetRef<Op>(op_node);
|
||||
RELAY_LOG(INFO) << "LowerOps: Lowering " << op->name;
|
||||
|
||||
CHECK(IsPrimitiveOp(op)) << "failed to lower "
|
||||
<< op->name << "can only lower primitve operations";
|
||||
|
||||
Array<Tensor> inputs;
|
||||
std::string input_name = "in";
|
||||
int i = 0;
|
||||
for (auto type_arg : call->type_args) {
|
||||
auto tt = Downcast<TensorType>(type_arg);
|
||||
inputs.push_back(PlaceholderOpNode::make(input_name + std::to_string(i),
|
||||
tt->shape, tt->dtype)
|
||||
.output(0));
|
||||
i++;
|
||||
}
|
||||
|
||||
auto output_tt = op->op_type->ret_type;
|
||||
Array<Tensor> outputs =
|
||||
compute_reg[op](call->attrs, inputs, output_tt, target);
|
||||
auto schedule = schedule_reg[op](outputs, target);
|
||||
size_t hash = StructuralHash()(func);
|
||||
LoweredFunc lf =
|
||||
flower(op->name + std::to_string(hash), schedule, inputs, outputs);
|
||||
func = FunctionSetAttr(func, "LoweredFunc", lf);
|
||||
env->Add(func_name, func, true);
|
||||
lowered_funcs.push_back(LoweredOpNode::make(func, lf));
|
||||
}
|
||||
|
||||
return lowered_funcs;
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.LowerOps")
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = LowerOps(args[0], args[1], args[2]);
|
||||
});
|
||||
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
|
@ -298,8 +298,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
|||
auto* fn_ty_node = ftype.as<FuncTypeNode>();
|
||||
|
||||
CHECK(fn_ty_node != nullptr)
|
||||
<< "only expressions with function types can be called, at "
|
||||
<< call->span;
|
||||
<< "only expressions with function types can be called, found "
|
||||
<< ftype << " at " << call->span;
|
||||
|
||||
Array<Type> type_args;
|
||||
FuncType fn_ty = Instantiate(fn_ty_node, &type_args);
|
||||
|
@ -505,12 +505,16 @@ Expr TypeInferencer::Infer(Expr expr) {
|
|||
// Step 1: Solve the constraints.
|
||||
solver_.Solve();
|
||||
// Step 2: Attach resolved types to checked_type field.
|
||||
return Resolver(type_map_, &solver_).VisitExpr(expr);
|
||||
auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr);
|
||||
CHECK(WellFormed(resolved_expr));
|
||||
return resolved_expr;
|
||||
}
|
||||
|
||||
|
||||
Expr InferType(const Expr& expr, const Environment& env) {
|
||||
return TypeInferencer(env).Infer(expr);
|
||||
auto e = TypeInferencer(env).Infer(expr);
|
||||
CHECK(WellFormed(e));
|
||||
return e;
|
||||
}
|
||||
|
||||
Function InferType(const Function& func,
|
||||
|
@ -522,6 +526,7 @@ Function InferType(const Function& func,
|
|||
Expr func_ret = TypeInferencer(env).Infer(func_copy);
|
||||
auto map_node = env->functions.CopyOnWrite();
|
||||
map_node->data.erase(var.node_);
|
||||
CHECK(WellFormed(func_ret));
|
||||
return Downcast<Function>(func_ret);
|
||||
}
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
*
|
||||
* \file util.cc
|
||||
*
|
||||
* \brief simple util for relay.
|
||||
* \brief Utility functions for Relay.
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
import numpy as np
|
||||
|
||||
from tvm import relay
|
||||
from tvm.relay.ir_pass import infer_type
|
||||
from tvm.relay.interpreter import evaluate
|
||||
from tvm.relay.graph_runtime_codegen import graph_evaluate
|
||||
from tvm.relay.scope_builder import ScopeBuilder
|
||||
from tvm.relay.op import add
|
||||
from tvm.relay.env import Environment
|
||||
|
||||
# @tq, @jr should we put this in testing ns?
|
||||
def check_rts(env, expr, args, expected_result):
|
||||
"""
|
||||
Check that evaluating `expr` applied to the arguments produces
|
||||
`result` on both the evaluator and TVM runtime.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr:
|
||||
The expression to evaluate
|
||||
|
||||
args: list of Expr
|
||||
The arguments to supply the expr.
|
||||
|
||||
expected_result:
|
||||
The expected result of running the expression.
|
||||
"""
|
||||
eval_result = evaluate(env, expr, *args)
|
||||
rts_result = graph_evaluate(env, expr, *args)
|
||||
np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
|
||||
|
||||
def test_add_op_scalar():
|
||||
"""
|
||||
Program:
|
||||
fn (x, y) {
|
||||
return x + y;
|
||||
}
|
||||
"""
|
||||
env = Environment()
|
||||
x = relay.var('x', shape=())
|
||||
y = relay.var('y', shape=())
|
||||
func = relay.Function([x, y], add(x, y))
|
||||
x_data = np.array(10.0, dtype='float32')
|
||||
y_data = np.array(1.0, dtype='float32')
|
||||
check_rts(env, func, [x_data, y_data], x_data + y_data)
|
||||
|
||||
def test_add_op_tensor():
|
||||
"""
|
||||
Program:
|
||||
fn (x, y) {
|
||||
return x + y;
|
||||
}
|
||||
"""
|
||||
env = Environment()
|
||||
x = relay.var('x', shape=(10, 5))
|
||||
y = relay.var('y', shape=(10, 5))
|
||||
func = relay.Function([x, y], add(x, y))
|
||||
x_data = np.random.rand(10, 5).astype('float32')
|
||||
y_data = np.random.rand(10, 5).astype('float32')
|
||||
check_rts(env, func, [x_data, y_data], x_data + y_data)
|
||||
|
||||
def test_add_op_broadcast():
|
||||
"""
|
||||
Program:
|
||||
fn (x, y) {
|
||||
return x + y;
|
||||
}
|
||||
"""
|
||||
env = Environment()
|
||||
x = relay.var('x', shape=(10, 5))
|
||||
y = relay.var('y', shape=(1, 5))
|
||||
func = relay.Function([x, y], add(x, y))
|
||||
x_data = np.random.rand(10, 5).astype('float32')
|
||||
y_data = np.random.rand(1, 5).astype('float32')
|
||||
check_rts(env, func, [x_data, y_data], x_data + y_data)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_add_op_scalar()
|
||||
test_add_op_tensor()
|
||||
test_add_op_broadcast()
|
|
@ -0,0 +1,142 @@
|
|||
import numpy as np
|
||||
import tvm
|
||||
from tvm import relay
|
||||
from tvm.relay.interpreter import Value, TupleValue, evaluate
|
||||
from tvm.relay import op
|
||||
from tvm.relay.scope_builder import ScopeBuilder
|
||||
from tvm.relay import testing
|
||||
|
||||
|
||||
def check_eval(expr, args, expected_result, env=None, rtol=1e-07):
|
||||
if env is None:
|
||||
env = relay.env.Environment({})
|
||||
|
||||
result = evaluate(env, expr, *args)
|
||||
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
|
||||
|
||||
|
||||
def test_from_scalar():
|
||||
np.testing.assert_allclose(Value.from_scalar(1, 'int32').asnumpy(), 1)
|
||||
np.testing.assert_allclose(Value.from_scalar(10.0, 'float32').asnumpy(), 10.0)
|
||||
np.testing.assert_allclose(Value.from_scalar(True).asnumpy(), True)
|
||||
|
||||
|
||||
def test_tuple_value():
|
||||
tv = TupleValue(Value.from_scalar(
|
||||
1), Value.from_scalar(2), Value.from_scalar(3))
|
||||
np.testing.assert_allclose(tv[0].asnumpy(), 1)
|
||||
np.testing.assert_allclose(tv[1].asnumpy(), 2)
|
||||
np.testing.assert_allclose(tv[2].asnumpy(), 3)
|
||||
|
||||
|
||||
def test_id():
|
||||
x = relay.var('x', 'float32')
|
||||
ident = relay.Function([x], x)
|
||||
env = relay.env.Environment({})
|
||||
res = evaluate(env, ident, 1.0)
|
||||
check_eval(ident, [1.0], 1.0)
|
||||
|
||||
|
||||
def test_add_const():
|
||||
two = op.add(relay.const(1), relay.const(1))
|
||||
func = relay.Function([], two)
|
||||
check_eval(func, [], 2)
|
||||
|
||||
|
||||
def test_mul_param():
|
||||
x = relay.var('x', shape=(10, 10))
|
||||
y = relay.var('y', shape=(1, 10))
|
||||
func = relay.Function([x, y], op.multiply(x, y))
|
||||
x_data = np.random.rand(10, 10).astype('float32')
|
||||
y_data = np.random.rand(1, 10).astype('float32')
|
||||
check_eval(func, [x_data, y_data], x_data * y_data)
|
||||
|
||||
|
||||
# failing due to numeric issues
|
||||
|
||||
# def test_dense():
|
||||
# x = relay.var('x', shape=(10, 10))
|
||||
# w = relay.var('w', shape=(10, 10))
|
||||
# y = op.nn.dense(x, w)
|
||||
# func = relay.Function([x, w], y)
|
||||
# x_data = np.random.rand(10, 10).astype('float32')
|
||||
# w_data = np.random.rand(10, 10).astype('float32')
|
||||
# check_eval(func, [x_data, w_data], x_data @ w_data, rtol=0.1)
|
||||
|
||||
# def test_linear():
|
||||
# x = relay.var('x', shape=(10, 10))
|
||||
# w = relay.var('w', shape=(10, 10))
|
||||
# b = relay.var('b', shape=(10,))
|
||||
# y = op.add(op.nn.dense(x, w), b)
|
||||
# func = relay.Function([x, w, b], y)
|
||||
# x_data = np.random.rand(10, 10).astype('float32')
|
||||
# w_data = np.random.rand(10, 10).astype('float32')
|
||||
# b_data = np.random.rand(10).astype('float32')
|
||||
# check_eval(func, [x_data, w_data, b_data], x_data @ w_data + b_data)
|
||||
|
||||
def test_equal():
|
||||
i = relay.var('i', shape=[], dtype='int32')
|
||||
j = relay.var('i', shape=[], dtype='int32')
|
||||
z = op.equal(i, j)
|
||||
func = relay.Function([i, j], z, ret_type=relay.TensorType([], 'bool'))
|
||||
i_data = relay.const(0)
|
||||
j_data = relay.const(0)
|
||||
check_eval(func, [i_data, j_data], True)
|
||||
|
||||
def test_subtract():
|
||||
i = relay.var('i', shape=[], dtype='int32')
|
||||
sub = op.subtract(i, relay.const(1, dtype='int32'))
|
||||
func = relay.Function([i], sub, ret_type=relay.TensorType([], 'int32'))
|
||||
i_data = np.array(1, dtype='int32')
|
||||
check_eval(func, [i_data], 0)
|
||||
|
||||
def test_simple_loop():
|
||||
env = relay.env.Environment({})
|
||||
sum_up = relay.GlobalVar('sum_up')
|
||||
i = relay.var('i', shape=[], dtype='int32')
|
||||
sb = ScopeBuilder()
|
||||
with sb.if_scope(op.equal(i, relay.const(0, dtype='int32'))):
|
||||
sb.ret(i)
|
||||
with sb.else_scope():
|
||||
one_less = op.subtract(i, relay.const(1, dtype='int32'))
|
||||
rec_call = relay.Call(sum_up, [one_less])
|
||||
sb.ret(op.add(rec_call, i))
|
||||
func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
|
||||
env[sum_up] = func
|
||||
i_data = np.array(10, dtype='int32')
|
||||
check_eval(sum_up, [i_data], sum(range(1, 11)), env=env)
|
||||
|
||||
def test_loop():
|
||||
env = relay.env.Environment({})
|
||||
sum_up = relay.GlobalVar('sum_up')
|
||||
i = relay.var('i', shape=[], dtype='int32')
|
||||
accum = relay.var('accum', shape=[], dtype='int32')
|
||||
sb = ScopeBuilder()
|
||||
with sb.if_scope(op.equal(i, relay.const(0))):
|
||||
sb.ret(accum)
|
||||
with sb.else_scope():
|
||||
one_less = op.subtract(i, relay.const(1))
|
||||
new_accum = op.add(accum, i)
|
||||
sb.ret(relay.Call(sum_up, [one_less, new_accum]))
|
||||
func = relay.Function([i, accum], sb.get())
|
||||
env[sum_up] = func
|
||||
i_data = np.array(10, dtype='int32')
|
||||
accum_data = np.array(0, dtype='int32')
|
||||
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), env=env)
|
||||
|
||||
def test_mlp():
|
||||
pass
|
||||
# net = testing.mlp.get_workload(1)
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_id()
|
||||
test_add_const()
|
||||
# test_dense()
|
||||
# test_linear()
|
||||
test_equal()
|
||||
test_subtract()
|
||||
test_simple_loop()
|
||||
test_loop()
|
||||
test_mlp()
|
||||
|
|
@ -5,6 +5,16 @@ import tvm
|
|||
import numpy as np
|
||||
from tvm.relay.ir_pass import infer_type
|
||||
from tvm import relay
|
||||
from tvm.relay import op
|
||||
from tvm.relay.scope_builder import ScopeBuilder
|
||||
|
||||
|
||||
def assert_has_type(expr, typ, env=relay.env.Environment({})):
|
||||
checked_expr = infer_type(expr, env)
|
||||
checked_type = checked_expr.checked_type
|
||||
if checked_type != typ:
|
||||
raise RuntimeError("Type mismatch %s vs %s" % (
|
||||
checked_type, typ))
|
||||
|
||||
|
||||
def test_monomorphic_let():
|
||||
|
@ -16,6 +26,31 @@ def test_monomorphic_let():
|
|||
assert xchecked.checked_type == relay.scalar_type("float64")
|
||||
|
||||
|
||||
def test_single_op():
|
||||
"Program: fn (x : float32) { let t1 = f(x); t1 }"
|
||||
x = relay.var('x', shape=[])
|
||||
func = relay.Function([x], op.log(x))
|
||||
ttype = relay.TensorType([], dtype='float32')
|
||||
assert_has_type(func, relay.FuncType([ttype], ttype))
|
||||
|
||||
|
||||
def test_add_broadcast_op():
|
||||
"""
|
||||
Program:
|
||||
fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] {
|
||||
return x + y;
|
||||
}
|
||||
"""
|
||||
pass
|
||||
# x = relay.var('x', shape=(10, 4))
|
||||
# y = relay.var('y', shape=(5, 10, 1))
|
||||
# z = x + y
|
||||
# func = relay.Function([x, y], z)
|
||||
# ttype = relay.TensorType((5, 5, 5), 'float32')
|
||||
# expected_ty = relay.FuncType([ttype, ttype], ttype)
|
||||
# assert_has_type(func.to_func(), expected_ty)
|
||||
|
||||
|
||||
def test_dual_op():
|
||||
"""Program:
|
||||
fn (x : Tensor[f32, (10, 10)]) {
|
||||
|
@ -41,7 +76,6 @@ def test_decl():
|
|||
return log(x);
|
||||
}
|
||||
"""
|
||||
sb = relay.ScopeBuilder()
|
||||
tp = relay.TensorType((10, 10))
|
||||
x = relay.var("x", tp)
|
||||
f = relay.Function([x], relay.log(x))
|
||||
|
@ -76,6 +110,24 @@ def test_recursion():
|
|||
assert "%3 = @f(%1, %2)" in env.astext()
|
||||
assert env[f].checked_type == relay.FuncType([ti32, tf32], tf32)
|
||||
|
||||
# This currently fails and should pass under the type system.
|
||||
#
|
||||
# This test is to illustrate problem with our weak form of
|
||||
# unification.
|
||||
#
|
||||
|
||||
|
||||
def test_incomplete_call():
|
||||
sb = ScopeBuilder()
|
||||
x = relay.var('x', dtype='int32')
|
||||
f = relay.var('f')
|
||||
func = relay.Function([x, f], relay.Call(f, [x]))
|
||||
|
||||
try:
|
||||
relay.ir_pass.infer_type(func)
|
||||
assert False
|
||||
except tvm.TVMError as e:
|
||||
assert True
|
||||
|
||||
def test_tuple():
|
||||
tp = relay.TensorType((10,))
|
||||
|
@ -84,13 +136,13 @@ def test_tuple():
|
|||
assert (relay.ir_pass.infer_type(res).checked_type ==
|
||||
relay.TupleType([tp, tp]))
|
||||
|
||||
|
||||
def test_free_expr():
|
||||
x = relay.var("x", "float32")
|
||||
y = relay.add(x, x)
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
assert yy.checked_type == relay.scalar_type("float32")
|
||||
|
||||
|
||||
def test_type_args():
|
||||
x = relay.var("x", shape=(10, 10))
|
||||
y = relay.var("y", shape=(1, 10))
|
||||
|
@ -107,6 +159,7 @@ def test_type_args():
|
|||
assert sh2[0].value == 1
|
||||
assert sh2[1].value == 10
|
||||
|
||||
|
||||
def test_self_reference():
|
||||
"""
|
||||
Program:
|
||||
|
@ -117,30 +170,40 @@ def test_self_reference():
|
|||
a = relay.TypeVar("a")
|
||||
x = relay.var("x", a)
|
||||
sb = relay.ScopeBuilder()
|
||||
|
||||
f = relay.Function([x], x)
|
||||
fx = relay.Call(f, [x])
|
||||
assert relay.ir_pass.infer_type(x).checked_type == a
|
||||
assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a)
|
||||
assert relay.ir_pass.infer_type(fx).checked_type == a
|
||||
|
||||
|
||||
def test_global_var_cow_issue():
|
||||
env = relay.env.Environment({})
|
||||
gv = relay.GlobalVar("foo")
|
||||
x = relay.var('x', shape=[])
|
||||
func = relay.Function([x], relay.Call(gv, [x]), relay.TensorType([], 'float32'))
|
||||
func = relay.Function([x], relay.Call(gv, [x]),
|
||||
relay.TensorType([], 'float32'))
|
||||
env[gv] = func
|
||||
# They should both point to the same global variable if global variables are
|
||||
# stable across type checking.
|
||||
assert gv == func.body.op
|
||||
|
||||
|
||||
def test_equal():
|
||||
i = relay.var('i', shape=[], dtype='int32')
|
||||
eq = op.equal(i, relay.const(0, dtype='int32'))
|
||||
# This should fail ....
|
||||
func = relay.Function([i], eq, ret_type=relay.TensorType([], 'int32'))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_free_expr()
|
||||
test_dual_op()
|
||||
test_single_op()
|
||||
test_recursion()
|
||||
test_monomorphic_let()
|
||||
test_decl()
|
||||
test_recursion()
|
||||
test_tuple()
|
||||
test_incomplete_call()
|
||||
test_free_expr()
|
||||
test_type_args()
|
||||
test_self_reference()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#!/bin/bash
|
||||
export PYTHONPATH=python:apps/extension/python
|
||||
export PYTHONPATH=python:topi/python:apps/extension/python
|
||||
export LD_LIBRARY_PATH=build:${LD_LIBRARY_PATH}
|
||||
|
||||
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
|
||||
|
|
Загрузка…
Ссылка в новой задаче