[RELAY][RUNTIME] Add Relay interpreter and compiler for TVM runtime system. (#1954)

Jared Roesch 2018-10-30 15:29:36 -07:00 committed by Tianqi Chen
Parent 07399e0239
Commit 10ea05e645
29 changed files: 2168 additions and 59 deletions


@ -22,8 +22,15 @@ namespace tvm {
* You can find more about Relay by reading the language reference.
*/
namespace relay {
#define RELAY_DEBUG(...) \
{ auto fdebug = runtime::Registry::Get("relay.debug"); \
CHECK(fdebug) << "Could not find Relay Python debugger function."; \
(*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
}
/*!
* \brief we always used NodeRef for referencing nodes.
* \brief We always used NodeRef for referencing nodes.
*
* By default, NodeRef is a std::shared_ptr of node
*/


@ -0,0 +1,76 @@
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/build_module.h
* \brief The passes and data structures needed to build a
* tvm::Module from a Relay program.
*/
#ifndef TVM_RELAY_BUILD_MODULE_H_
#define TVM_RELAY_BUILD_MODULE_H_
#include <tvm/lowered_func.h>
#include <tvm/relay/environment.h>
#include <tvm/relay/expr.h>
#include <string>
namespace tvm {
namespace relay {
/*! \brief A lowered Relay operation.
*
* A lowered operation is a pair containing the "primitive" function used
* to produce the lowered function as well as the lowered function itself.
*/
class LoweredOp;
/*! \brief Lowered operation container. */
class LoweredOpNode : public Node {
public:
/*!
* \brief The primitive function to be lowered.
*
* A primitive function consists only of calls to relay::Op which
* can be fused.
*/
Function func;
/*!
* \brief The lowered function.
*/
LoweredFunc lowered_func;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("func", &func);
v->Visit("lowered_func", &lowered_func);
}
TVM_DLL static LoweredOp make(
Function func,
LoweredFunc lowered_func);
static constexpr const char* _type_key = "relay.LoweredOp";
TVM_DECLARE_NODE_TYPE_INFO(LoweredOpNode, Node);
};
RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef);
/*!
* \brief Lower the operations contained in a Relay expression.
*
* The lowering pass will only lower functions marked as primitive;
* the FuseOps pass provides this marking if it is run before LowerOps.
*
* \note This will do a reachability analysis and lower all definitions
* reachable from the provided expression.
*
* \param env The environment.
* \param expr The expression with operations to be lowered.
* \param target The target to lower the functions to.
*
* \return The set of lowered operations.
*/
Array<LoweredOp> LowerOps(const Environment& env, const Expr& expr,
const std::string& target = "llvm");
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BUILD_MODULE_H_
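A brief usage sketch of this lowering entry point (an editor's illustration, not part of this diff): LowerOps is exposed to Python as ir_pass.lower_ops later in this change, and it expects a type-checked function whose operator calls have already been abstracted into Primitive functions by FuseOps. Access to lowered_func.name is assumed from the LoweredFunc node's reflection.

from tvm import relay
from tvm.relay import ir_pass
from tvm.relay.op import add
from tvm.relay.env import Environment

env = Environment({})
x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(10, 5))
func = relay.Function([x, y], add(x, y))

func = ir_pass.infer_type(func, env)      # populate checked_type fields
func = ir_pass.fuse_ops(env, func)        # abstract the add call into a Primitive function
func = ir_pass.infer_type(func, env)
for lop in ir_pass.lower_ops(env, func):  # array of LoweredOp; target defaults to "llvm"
    print(lop.lowered_func.name)          # name of the generated TVM LoweredFunc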


@ -213,12 +213,18 @@ class FunctionNode : public ExprNode {
*/
tvm::Array<TypeVar> type_params;
/*!
* \brief The attributes which store metadata about functions.
*/
tvm::Attrs attrs;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("span", &span);
v->Visit("attrs", &attrs);
v->Visit("_checked_type_", &checked_type_);
}
@ -233,7 +239,8 @@ class FunctionNode : public ExprNode {
TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> ty_params);
tvm::Array<TypeVar> ty_params,
tvm::Attrs attrs = Attrs());
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
@ -241,6 +248,11 @@ class FunctionNode : public ExprNode {
RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr);
TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key);
TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data);
/*!
* \brief Call corresponds to operator invocation.
* Corresponds to the operator in computational graph terminology.


@ -0,0 +1,140 @@
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/interpreter.h
* \brief An interpreter for Relay.
*
* This file implements a simple reference interpreter for Relay programs.
* Given a Relay environment and a Relay expression, it produces a value.
*
* The interpreter's values are a naive representation of the values that
* can be produced by a Relay program and are exposed via tvm::Node's
* system to Python for introspection and debugging.
*
* The interpreter's intent is to serve as a reference semantics for the Relay IR,
* as well as for debugging and testing.
*/
#ifndef TVM_RELAY_INTERPRETER_H_
#define TVM_RELAY_INTERPRETER_H_
#include <tvm/relay/environment.h>
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
/*!
* \brief A Relay value.
*/
class Value;
/*! \brief Evaluate an expression using the interpreter producing a value.
*
* The resulting value can be passed to Python, making it easy to use
* for testing and debugging.
*
* The interpreter interprets program fragments that are not supported by the
* TVM runtime. Although the interpreter is naively implemented, it uses
* TVM operators to evaluate all operators.
*
* Our intent is that this will never be the most efficient implementation of
* Relay's semantics, but a readable and clear one.
*/
Value Evaluate(Environment env, Expr e);
/*! \brief The base container type of Relay values. */
class ValueNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Value";
TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode);
};
class Value : public NodeRef {
public:
Value() {}
explicit Value(NodePtr<Node> n) : NodeRef(n) {}
const ValueNode* operator->() const {
return static_cast<const ValueNode*>(node_.get());
}
using ContainerType = ValueNode;
};
/*! \brief A Relay closure, i.e., a scope and a function. */
class Closure;
/*! \brief The container type of Closures. */
class ClosureNode : public ValueNode {
public:
/*! \brief The set of free variables in the closure.
*
* These are the captured variables which are required for
* evaluation when we call the closure.
*/
tvm::Map<Var, Value> env;
/*! \brief The function which implements the closure.
*
* \note May reference the variables contained in the env.
*/
Function func;
ClosureNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("env", &env);
v->Visit("func", &func);
}
TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func);
static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_NODE_TYPE_INFO(ClosureNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value);
/*! \brief A tuple value. */
class TupleValue;
/*! \brief Tuple (x, ... y). */
struct TupleValueNode : ValueNode {
tvm::Array<Value> fields;
TupleValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }
TVM_DLL static TupleValue make(tvm::Array<Value> value);
static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_NODE_TYPE_INFO(TupleValueNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(TupleValue, TupleValueNode, Value);
/*! \brief A tensor value. */
class TensorValue;
/*! \brief The tensor value container, wrapping an NDArray. */
struct TensorValueNode : ValueNode {
runtime::NDArray data;
TensorValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); }
/*! \brief Build a value from an NDArray. */
TVM_DLL static TensorValue make(runtime::NDArray data);
/*! \brief Construct an empty tensor value from t. */
TVM_DLL static TensorValue FromType(const Type& t);
static constexpr const char* _type_key = "relay.TensorValue";
TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_INTERPRETER_H_
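A short sketch of how these value containers appear on the Python side (an editor's illustration, not part of this diff), using the tvm.relay.interpreter wrappers added below; the concrete values are arbitrary.

import numpy as np
from tvm.relay.interpreter import Value, TupleValue, TensorValue

tensor = TensorValue(np.ones((2, 2), dtype='float32'))  # wraps a tvm NDArray
scalar = Value.from_scalar(3)                           # int32 scalar TensorValue
pair = TupleValue(tensor, scalar)                       # relay.TupleValue
print(pair[0].asnumpy(), pair[1].asnumpy())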


@ -8,6 +8,7 @@
#include <tvm/relay/environment.h>
#include <tvm/relay/expr.h>
#include <string>
namespace tvm {
namespace relay {
@ -20,7 +21,8 @@ namespace relay {
* populated with the result type.
*
* \param expr The expression to type check.
* \param env The environment used for referencing global functions, can be None.
* \param env The environment used for referencing global functions, can be
* None.
*
* \return A type checked expression with its checked_type field populated.
*/
@ -35,7 +37,8 @@ Expr InferType(const Expr& expr, const Environment& env);
* \return A type checked Function with its checked_type field populated.
* \note this function mutates env and is not thread-safe.
*/
Function InferType(const Function& f, const Environment& env, const GlobalVar& var);
Function InferType(const Function& f, const Environment& env,
const GlobalVar& var);
/*!
* \brief Check that types are well kinded by applying "kinding rules".
@ -94,28 +97,30 @@ bool AlphaEqual(const Type& t1, const Type& t2);
*
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
*
* `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice, although x is not shadowed.
* `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice,
* although x is not shadowed.
*
* \param e the expression to check.
* \param expr the expression to check.
*
* \return true iff all Var in e is bound at most once.
* \return true iff all Var in expr is bound at most once.
*/
bool WellFormed(const Expr& e);
bool WellFormed(const Expr& expr);
/*! \brief Get free Vars from expr in PostDFS order.
/*! \brief Get free type parameters from expression expr.
*
* Free variables are variables that are not bound by a
* let or a function parameter in the context.
*
* \param expr the expression.
*
* \return List of free vars, in the PostDFS order visited by expr.
* \return List of free vars, in the PostDFS order in the expression.
*/
tvm::Array<Var> FreeVars(const Expr& expr);
/*! \brief Get free TypeVars from expression expr.
*
* Free type parameters are type parameters that are not bound by a function type in the context.
* Free type parameters are type parameters that are not bound by a function
* type in the context.
*
* \param expr the expression.
*
@ -125,10 +130,12 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
/*! \brief Remove expressions which do not affect the program result.
*
* It will remove let binding that are not referenced, and if branch that are not entered.
* It will remove let bindings which are not referenced, and branches that will
* not be entered.
*
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of the expression does not depend on a.
* Another example is `if (true) then 1 else 2` will be optimized into 1.
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of
* the expression does not depend on a. Another example is `if (true) then 1
* else 2` will be optimized into 1.
*
* \param e the expression to optimize.
*
@ -136,27 +143,30 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
*/
Expr DeadCodeElimination(const Expr& e);
/*! \brief Hash a Relay type.
*
* Implements structural hashing of a Relay type.
*
* \param type the type to hash.
*
* \return the hash value.
*/
size_t StructuralHash(const Type& type);
/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
*
* \param expr the expression to hash.
*
* \return the hash value.
*/
size_t StructuralHash(const Expr& expr);
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
*
* Implements structural hashing of a Relay type.
*
* \param type the type to hash.
*
* \return the hash value.
*/
size_t operator()(const Type& type) const;
/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
*
* \param expr the expression to hash.
*
* \return the hash value.
*/
size_t operator()(const Expr& expr) const;
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_H_
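A small Python sketch of the analyses declared above (an editor's illustration, not part of this diff), assuming the existing ir_pass wrappers free_vars and structural_hash; y is free in the function below because only x is a parameter.

from tvm import relay
from tvm.relay import ir_pass
from tvm.relay.op import add

x = relay.var('x', shape=(10,))
y = relay.var('y', shape=(10,))
f = relay.Function([x], add(x, y))                    # y is captured, not bound

print([v.name_hint for v in ir_pass.free_vars(f)])    # ['y'], in PostDFS order
print(ir_pass.structural_hash(f))                     # structural hash of the function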


@ -1,5 +1,7 @@
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler."""
from __future__ import absolute_import
from ..api import register_func
from . import base
from . import ty
from . import expr
@ -15,6 +17,7 @@ from . import nn
from . import vision
from . import image
from .scope_builder import ScopeBuilder
# Span
@ -46,6 +49,21 @@ Let = expr.Let
If = expr.If
TupleGetItem = expr.TupleGetItem
# helper functions
var = expr.var
const = expr.const
@register_func("relay._tensor_value_repr")
def _tensor_value_repr(tv):
return str(tv.data.asnumpy())
@register_func("relay._constant_repr")
def _tensor_constant_repr(tv):
return str(tv.data.asnumpy())
# pylint: disable=unused-argument
@register_func("relay.debug")
def _debug(*args):
import pdb
pdb.set_trace()


@ -0,0 +1,4 @@
"""The interface to the Evaluator exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay._interpreter", __name__)


@ -45,9 +45,12 @@ class Environment(RelayNode):
func: Function
The function.
"""
return self._add(var, func)
def _add(self, var, func, update=False):
if isinstance(var, _base.string_types):
var = _expr.GlobalVar(var)
_env.Environment_Add(self, var, func)
return _env.Environment_Add(self, var, func, update)
def __getitem__(self, var):
"""Lookup a global function by name or by variable.

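A sketch of the new update path (an editor's illustration, not part of this diff): the interpreter frontend re-adds a global after running passes by calling _add with update=True, which now reaches Environment_Add. The GlobalVar re-export and the public add method are assumed from the rest of the frontend.

from tvm import relay
from tvm.relay.env import Environment
from tvm.relay.op import add

env = Environment({})
x = relay.var('x', shape=())
double = relay.GlobalVar('double')
env.add(double, relay.Function([x], add(x, x)))          # initial definition
env._add(double, relay.Function([x], add(x, x)), True)   # overwrite the same global in place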

@ -0,0 +1,551 @@
"""
A compiler from a Relay expression to TVM's graph runtime.
The compiler is built from a few pieces.
First we define a compiler from a single Relay expression to the
graph language. We require the expression to be a function.
The function's parameters correspond to the placeholder/inputs
and model parameters found in the computation graph representation.
The body of the function represents the computation graph.
The compiler's output is a program in the graph language, which is composed
of Node, NodeRef, InputNode, and OpNode. This "little language" represents
programs in TVM's graph format.
To connect to the graph runtime, we use a printer that converts our graph format
into TVM's JSON format. The resulting string can be loaded by
contrib.graph_runtime or any other TVM runtime compatible system.
We expose this functionality in compile_to_tvm.
"""
from __future__ import absolute_import
import json
import attr
from . import ir_pass
from .op import Op
from .expr import Var, Function, Call, If, GlobalVar, Constant, Let, Tuple
from ..build_module import build as tvm_build_module
from .. contrib import graph_runtime
from .ir_pass import infer_type
from .. import cpu
class AbstractExprVisitor(object):
"""A visitor over Expr in Python."""
def __init__(self):
self.memo_map = {}
# pylint: disable=no-else-return
def visit(self, expr):
"""Apply the visitor to an expression."""
found = self.memo_map.get(expr)
if found:
return found
if isinstance(expr, Function):
res = self.visit_function(expr)
elif isinstance(expr, Call):
res = self.visit_call(expr)
elif isinstance(expr, Let):
res = self.visit_let(expr)
elif isinstance(expr, Var):
res = self.visit_var(expr)
elif isinstance(expr, GlobalVar):
res = self.visit_global_var(expr)
elif isinstance(expr, If):
res = self.visit_if(expr)
elif isinstance(expr, Tuple):
res = self.visit_tuple(expr)
elif isinstance(expr, Constant):
res = self.visit_constant(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))
self.memo_map[expr] = res
return res
def visit_function(self, _):
raise Exception("Abstract method please implement me.")
def visit_let(self, _):
raise Exception("Abstract method please implement me.")
def visit_call(self, _):
raise Exception("Abstract method please implement me.")
def visit_var(self, _):
raise Exception("Abstract method please implement me.")
def visit_type(self, typ):
return typ
def visit_if(self, _):
raise Exception("Abstract method please implement me.")
def visit_tuple(self, _):
raise Exception("Abstract method please implement me.")
def visit_constant(self, _):
raise Exception("Abstract method please implement me.")
def visit_global_var(self, _):
raise Exception("Abstract method please implement me.")
class ExprMutator(AbstractExprVisitor):
"""A functional visitor over Expr in Python."""
def visit_function(self, fn):
new_body = self.visit(fn.body)
return Function(
list(fn.params),
new_body, fn.ret_type,
fn.type_params)
def visit_let(self, let):
new_var = self.visit(let.var)
new_val = self.visit(let.value)
new_body = self.visit(let.body)
return Let(new_var, new_val, new_body)
def visit_call(self, call):
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
return Call(new_fn, new_args, call.attrs)
def visit_var(self, var):
return var
def visit_global_var(self, global_var):
return global_var
def visit_if(self, ite):
return If(
self.visit(ite.guard),
self.visit(ite.true_b),
self.visit(ite.false_b))
def visit_tuple(self, tup):
return Tuple([self.visit(field) for field in tup.fields])
def visit_constant(self, const):
return const
@attr.s
class NodeRef(object):
"""A reference to a node, used for constructing the graph."""
ident = attr.ib()
index = attr.ib(default=0)
version = attr.ib(default=0)
def to_json(self):
return [self.ident, self.index, self.version]
@attr.s
class Node(object):
"""The base class for nodes in the TVM runtime system graph input."""
name = attr.ib()
attrs = attr.ib()
is_output = attr.ib()
def to_json(self):
raise Exception("Abstract method, please implement me.")
@attr.s
class InputNode(Node):
"""An input node in the TVM runtime system graph input."""
name = attr.ib()
attrs = attr.ib()
is_output = attr.ib(default=False)
def to_json(self):
return {
"op": "null",
"name": self.name,
"inputs": []
}
@attr.s
class OpNode(Node):
"""An operator node in the TVM runtime system's graph input."""
op_name = attr.ib()
inputs = attr.ib()
op_attrs = attr.ib()
is_output = attr.ib(default=False)
def to_json(self):
attrs = dict.copy(self.op_attrs)
# Extend ops with extra info.
attrs['func_name'] = self.op_name
# When do we flatten?
attrs['flatten_data'] = "0"
# Fix me!
attrs['num_inputs'] = str(len(self.inputs))
attrs['num_outputs'] = "1"
return {
"op": "tvm_op",
"name": self.name,
"attrs": attrs,
"inputs": self.inputs
}
def shape_to_json(shape):
return [sh.value for sh in shape]
def from_tensor(typ):
return (typ.dtype, shape_to_json(typ.shape))
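# Editor's note (illustration, not part of this diff): for a single fused call
# over inputs x and y, the "nodes" list assembled by GraphRuntimeCodegen.to_json()
# looks roughly like this (names and hash values are illustrative):
#
#   [{"op": "null",   "name": "x", "inputs": []},
#    {"op": "null",   "name": "y", "inputs": []},
#    {"op": "tvm_op", "name": "call_<hash>",
#     "attrs": {"func_name": "<lowered func name>", "flatten_data": "0",
#               "num_inputs": "2", "num_outputs": "1"},
#     "inputs": [[0, 0, 0], [1, 0, 0]]}]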
class GraphRuntimeCodegen(ExprMutator):
"""The compiler from Relay to the TVM runtime system."""
nodes = attr.ib()
id_map = attr.ib()
def __init__(self, env):
ExprMutator.__init__(self)
self.nodes = []
self.id_map = {}
self.env = env
def add_node(self, node):
"""
Add a node to the graph.
Parameters
----------
node: Node
The node to add to the graph.
Returns
-------
node_ref: NodeRef
A reference to the node.
"""
self.nodes.append(node)
ident = len(self.nodes) - 1
return NodeRef(ident)
def add_binding(self, ident, ref):
"""
Add an identifier-to-node mapping.
Parameters
----------
ident: relay.Var
The variable to map
ref: NodeRef
The node the identifier points to.
"""
self.id_map[ident] = ref
def let_bind(self, ident, node):
"""
Let bind node to ident.
Parameters
----------
ident: relay.Var
The variable to map.
node: Node
The node to bind to the identifier.
Returns
-------
ref: NodeRef
Return reference to the node.
"""
ref = self.add_node(node)
self.add_binding(ident, ref)
return ref
def get_node(self, ref):
"""
Lookup a node by a node reference.
Parameters
----------
ref: NodeRef
The reference to lookup.
Returns
-------
node: Node
The node.
"""
return self.nodes[ref.ident]
def lookup(self, ident):
"""
Lookup a node by identifier.
Parameters
----------
ident: relay.Var
The reference to lookup.
Returns
-------
node: Node
The node.
"""
return self.id_map[ident]
def codegen(self, func):
"""Compile a single function into a graph.
Parameters
----------
func: tvm.relay.Expr
The function to compile.
"""
# First we convert all the parameters into input nodes.
params = func.params
for param in params:
dtype, shape = from_tensor(param.type_annotation)
node = InputNode("{0}".format(param.name_hint), {
"shape": shape,
"dtype": dtype,
})
self.let_bind(param, node)
# Then we compile the body into a graph which can depend
# on input variables.
output_ref = self.visit(func.body)
# Finally we retrieve the return value of the program, which will
# become our output node.
self.get_node(output_ref).is_output = True
def visit_let(self, let):
"""
Visit the let binding, by first traversing its value,
then setting the metadata on the returned NodeRef.
Finally visit the body, and return the NodeRef corresponding
to it.
Parameters
----------
let: tvm.relay.Expr
The let binding to transform.
Returns
-------
ref: NodeRef
The node reference to the body.
"""
ident = let.var
val = let.value
body = let.body
val_ref = self.visit(val)
dtype, shape = from_tensor(val.checked_type())
val_node = self.get_node(val_ref)
val_node.attrs["dtype"] = dtype
val_node.attrs["shape"] = shape
self.add_binding(ident, val_ref)
return self.visit(body)
def visit_var(self, var):
return self.lookup(var)
def visit_call(self, call):
"""Transform a ::tvm.relay.Call into an operator in the TVM graph."""
inputs = []
for arg in call.args:
inputs.append(self.visit(arg).to_json())
if isinstance(call.op, Op):
raise Exception(
"Operators should be transformed away; try applying" +
"the fuse_ops transformation to the expression.")
elif isinstance(call.op, GlobalVar):
func = self.env[call.op]
elif isinstance(call.op, Function):
func = call.op
else:
raise Exception(
"TVM runtime does not support calls to {0}".format(type(call.op)))
if int(func.attrs.Primitive) != 1:
raise Exception(
"TVM only support calls to primitive functions " +
"(i.e functions composed of fusable operator invocations)")
op_name = func.attrs.LoweredFunc.name
attrs = {'shape': shape_to_json(call.checked_type.shape),
'dtype': call.checked_type.dtype}
call_hash = str(ir_pass.structural_hash(call))
op_node = OpNode("call_" + call_hash, attrs, op_name, inputs, {})
return self.add_node(op_node)
def to_json(self):
"""
Convert the sequence of nodes stored by the compiler into the
TVM graph runtime format.
Returns
-------
graph_json : str
The generated JSON as a string.
"""
nodes = []
# First we compute "nodes" field.
for node in self.nodes:
nodes.append(node.to_json())
arg_nodes = []
heads = []
# Compute "arg_nodes" and "heads" fields.
for i, node in enumerate(self.nodes):
if isinstance(node, InputNode):
arg_nodes.append(i)
if node.is_output:
# Need to fix this.
heads.append(NodeRef(i).to_json())
def compute_node_row_ptr(nodes):
"""Calculate the node_row_ptr field by doing a DFS backwards
from the output and reversing the path.
"""
row_ptr = [len(nodes)]
discovered = set()
stack = []
stack.append(len(nodes) - 1)
while stack:
i = stack.pop()
if i not in discovered:
discovered.add(i)
row_ptr.append(i)
node = nodes[i]
if isinstance(node, OpNode):
for inp in node.inputs:
stack.append(inp[0])
row_ptr.reverse()
return row_ptr
# Compute "node_row_ptr".
node_row_ptr = compute_node_row_ptr(self.nodes)
# Compute "attrs" field.
attrs = {}
# These fields are mandatory.
shapes = []
storage_ids = []
dtype = []
dltype = []
for i, node in enumerate(self.nodes):
storage_ids.append(i)
shapes.append(node.attrs['shape'])
if node.attrs['dtype'] == 'float32':
dtype.append(0)
dltype.append('float32')
attrs["shape"] = ["list_shape", shapes]
attrs["storage_id"] = ["list_int", storage_ids]
attrs["dtype"] = ["list_int", dtype]
attrs["dltype"] = ["list_str", dltype]
json_dict = {
"nodes": nodes,
"arg_nodes": arg_nodes,
"heads": heads,
"attrs": attrs,
"node_row_ptr": node_row_ptr
}
return json.dumps(json_dict)
def build(env, func, target=None):
"""
Compile a single function to the components needed by the
TVM RTS.
Parameters
----------
env: relay.Environment
The global environment used during compilation.
func: relay.Expr
The function to build.
target: optional str
The target platform.
Returns
-------
(graph_json, mod, params): tuple of (str, tvm.Module, dict)
The outputs of building a Relay function for the TVM runtime.
"""
if target is None:
target = 'llvm'
comp = GraphRuntimeCodegen(env)
# NB(@jroesch) This creates lowered functions, and generates names for them
#
# We need these names to emit the correct graph as these are names of the
# functions contained in the module.
lowered_ops = ir_pass.lower_ops(env, func)
mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target)
# Therefore the call to compile must come after.
comp.codegen(func)
graph_json = comp.to_json()
return graph_json, mod, None # params currently isn't supported by API
def graph_evaluate(env, func, *args):
"""
Corresponding function to tvm.relay.eval.evaluate.
This function evaluates a Relay expression on the
TVM graph_runtime.
Parameters
----------
env: tvm.relay.Environment
The global environment used.
func: tvm.relay.Expr
The function to evaluate.
args: list of tvm.relay.Expr
The arguments to apply to the expression, only works
if the expression has a function type.
Returns
-------
value: tvm.NDArray
The output Tensor produced by evaluating the expression.
"""
func = infer_type(func, env)
func = ir_pass.fuse_ops(env, func)
func = infer_type(func, env)
graph_json, mod, params = build(env, func)
assert params is None
gmodule = graph_runtime.create(graph_json, mod, cpu(0))
# Create map of inputs.
inputs = {}
for i, arg in enumerate(args):
inputs[func.params[i].name_hint] = arg
# Set the inputs here.
gmodule.set_input(**inputs)
# Run the module, and fetch the output.
gmodule.run()
return gmodule.get_output(0)

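A short sketch of using build() directly with the graph runtime (an editor's illustration, not part of this diff); graph_evaluate above wraps exactly this flow, and the random input data is arbitrary.

import numpy as np
import tvm
from tvm import relay
from tvm.relay import ir_pass
from tvm.relay.op import add
from tvm.relay.env import Environment
from tvm.relay.graph_runtime_codegen import build
from tvm.contrib import graph_runtime

env = Environment({})
x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(10, 5))
func = relay.Function([x, y], add(x, y))
func = ir_pass.infer_type(func, env)
func = ir_pass.fuse_ops(env, func)
func = ir_pass.infer_type(func, env)

graph_json, mod, _ = build(env, func)                   # JSON graph + compiled tvm.Module
gmodule = graph_runtime.create(graph_json, mod, tvm.cpu(0))
gmodule.set_input(x=np.random.rand(10, 5).astype('float32'),
                  y=np.random.rand(10, 5).astype('float32'))
gmodule.run()
print(gmodule.get_output(0).asnumpy())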

@ -0,0 +1,130 @@
#pylint: disable=no-else-return
"""An interface to the Realy interpreter."""
from __future__ import absolute_import
import numpy as np
from .. import register_func, nd
from .base import NodeBase, register_relay_node
from . import _make
from . import _interpreter
from . import ir_pass
from .expr import Call, Constant, GlobalVar
from . import const
from .._ffi.base import integer_types
class Value(NodeBase):
"""Base class of all values.
"""
@staticmethod
@register_func("relay.from_scalar")
def from_scalar(i, dtype=None):
"""Convert a Python scalar to a Relay scalar."""
if dtype is None:
# Check bool before the integer types: bool is a subclass of int,
# so the integer branch would otherwise shadow it.
if isinstance(i, bool):
dtype = 'uint8'
elif isinstance(i, integer_types):
dtype = 'int32'
elif isinstance(i, float):
dtype = 'float32'
else:
raise Exception("unable to infer dtype {0}".format(type(i)))
return TensorValue(nd.array(np.array(i, dtype=dtype)))
@register_relay_node
class TupleValue(Value):
def __init__(self, *fields):
self.__init_handle_by_constructor__(
_make.TupleValue, fields)
def __getitem__(self, field_no):
return self.fields[field_no]
@register_relay_node
class Closure(Value):
pass
@register_relay_node
class TensorValue(Value):
"""A Tensor value produced by the evaluator."""
def __init__(self, data):
"""Allocate a new TensorValue and copy the data from `array` into
the new array.
"""
if isinstance(data, np.ndarray):
data = nd.array(data)
self.__init_handle_by_constructor__(
_make.TensorValue, data)
def as_ndarray(self):
"""Convert a Relay TensorValue into a tvm.ndarray."""
return self.data
def asnumpy(self):
"""Convert a Relay TensorValue into a numpy.ndarray."""
return self.data.asnumpy()
def __eq__(self, other):
return self.data == other.data
def _arg_to_ast(arg):
if isinstance(arg, TensorValue):
return Constant(arg.data)
elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg))
elif isinstance(arg, Constant):
return arg
else:
return const(arg)
def apply_passes(expr, env=None):
ck_expr = ir_pass.infer_type(expr, env=env)
fused_expr = ir_pass.fuse_ops(env, ck_expr)
return fused_expr
def evaluate(env, expr, *args):
"""
Evaluate a Relay expression on the interpreter.
Parameters
----------
env: tvm.relay.Environment
The global environment used.
expr: tvm.relay.Expr
The expression to evaluate.
args: list of tvm.relay.Expr
The arguments to apply to the expression, only works
if the expression has a function type.
Returns
-------
value: tvm.relay.eval.Value
The value produced by evaluating the expression.
"""
# assert len(args) == 0
relay_args = []
for arg in args:
relay_args.append(_arg_to_ast(arg))
# TODO: We need to move this optimization code into the optimizer/pass manager
if isinstance(expr, GlobalVar):
func = env[expr]
func = apply_passes(func, env)
env._add(expr, func, True)
opt_expr = Call(expr, relay_args)
# import pdb; pdb.set_trace()
return _interpreter.evaluate(env, opt_expr)
else:
expr = Call(expr, relay_args)
opt_expr = apply_passes(expr, env)
return _interpreter.evaluate(env, opt_expr)

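A minimal usage sketch for evaluate (an editor's illustration, not part of this diff), mirroring the tests added at the end of this change; numpy arguments are converted to constants by _arg_to_ast above.

import numpy as np
from tvm import relay
from tvm.relay.env import Environment
from tvm.relay.op import add
from tvm.relay.interpreter import evaluate

env = Environment({})
x = relay.var('x', shape=(2, 2))
y = relay.var('y', shape=(2, 2))
f = relay.Function([x, y], add(x, y))
result = evaluate(env, f, np.ones((2, 2), dtype='float32'),
                  np.ones((2, 2), dtype='float32'))
print(result.asnumpy())   # a 2x2 array of 2.0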

@ -240,3 +240,9 @@ def structural_hash(value):
msg = ("found value of type {0} expected" +
"relay.Expr or relay.Type").format(type(value))
raise TypeError(msg)
def fuse_ops(expr, env):
"""Fuse eligible operator calls in the expression under the given environment."""
return _ir_pass.FuseOps(env, expr)
def lower_ops(env, expr, target='llvm'):
"""Lower the primitive functions reachable from expr to TVM LoweredFuncs."""
return _ir_pass.LowerOps(env, expr, target)


@ -1,2 +1,49 @@
#pylint: disable=invalid-name
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
import tvm
import topi
from . import register
def add_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.add(inputs[0], inputs[1])]
def add_schedule(outputs, target):
assert len(outputs) == 1
return tvm.create_schedule(outputs[0].op)
register("add", "FTVMCompute", add_compute)
register("add", "FTVMSchedule", add_schedule)
def subtract_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.subtract(inputs[0], inputs[1])]
def subtract_schedule(outputs, target):
assert len(outputs) == 1
return tvm.create_schedule(outputs[0].op)
register("subtract", "FTVMCompute", subtract_compute)
register("subtract", "FTVMSchedule", subtract_schedule)
def multiply_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.multiply(inputs[0], inputs[1])]
def multiply_schedule(outputs, target):
assert len(outputs) == 1
return tvm.create_schedule(outputs[0].op)
register("multiply", "FTVMCompute", multiply_compute)
register("multiply", "FTVMSchedule", multiply_schedule)
def equal_compute(attrs, inputs, output_type, target):
assert len(inputs) == 2
return [topi.equal(inputs[0], inputs[1])]
def equal_schedule(outputs, target):
assert len(outputs) == 1
return tvm.create_schedule(outputs[0].op)
register("equal", "FTVMCompute", equal_compute)
register("equal", "FTVMSchedule", equal_schedule)


@ -0,0 +1,16 @@
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
import tvm
import topi
from .. import register
def dense_compiler(attrs, inputs, output_type):
assert len(inputs) == 2
return [topi.nn.dense(inputs[0], inputs[1])]
def dense_schedule(outputs, target):
assert len(outputs) == 1
return tvm.create_schedule(outputs[0].op)
register("nn.dense", "FTVMCompute", dense_compiler)
register("nn.dense", "FTVMSchedule", dense_schedule)


@ -3,7 +3,8 @@ from ..._ffi.function import _init_api
from ..base import register_relay_node
from ..expr import Expr
from ...api import register_func
from ...build_module import lower, build
@register_relay_node
class Op(Expr):
@ -75,3 +76,11 @@ def register(op_name, attr_key, value=None, level=10):
_init_api("relay.op", __name__)
@register_func("relay.op.compiler._lower")
def _lower(name, schedule, inputs, outputs):
return lower(schedule, list(inputs) + list(outputs), name=name)
@register_func("relay.op.compiler._build")
def _build(lowered_funcs):
return build(lowered_funcs, target="llvm")


@ -17,6 +17,7 @@
"""
a simple multilayer perceptron
"""
from __future__ import absolute_import
from tvm import relay
from .init import create_workload

src/relay/interpreter.cc (new file, 432 lines)

@ -0,0 +1,432 @@
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/interpreter.cc
* \brief An interpreter for the Relay IR.
*/
#include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/build_module.h>
#include "./ir/type_functor.h"
namespace tvm {
namespace relay {
using namespace runtime;
inline const PackedFunc& GetPackedFunc(const std::string& name) {
const PackedFunc* pf = tvm::runtime::Registry::Get(name);
CHECK(pf != nullptr) << "Cannot find function " << name << " in registry";
return *pf;
}
/* Value Implementation */
Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) {
NodePtr<ClosureNode> n = make_node<ClosureNode>();
n->env = std::move(env);
n->func = std::move(func);
return Closure(n);
}
TVM_REGISTER_API("relay._make.Closure")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ClosureNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) {
p->stream << "ClosureNode(" << node->func << ")";
});
TupleValue TupleValueNode::make(tvm::Array<Value> value) {
NodePtr<TupleValueNode> n = make_node<TupleValueNode>();
n->fields = value;
return TupleValue(n);
}
TVM_REGISTER_API("relay._make.TupleValue")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleValueNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleValueNode>([](const TupleValueNode* node,
tvm::IRPrinter* p) {
p->stream << "TupleValueNode(" << node->fields << ")";
});
TensorValue TensorValueNode::make(runtime::NDArray data) {
NodePtr<TensorValueNode> n = make_node<TensorValueNode>();
n->data = std::move(data);
return TensorValue(n);
}
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TensorValueNode>([](const TensorValueNode* node,
tvm::IRPrinter* p) {
auto to_str = GetPackedFunc("relay._tensor_value_repr");
std::string data_str = to_str(GetRef<TensorValue>(node));
p->stream << "TensorValueNode(" << data_str << ")";
});
TensorValue TensorValueNode::FromType(const Type& t) {
if (auto tt_node = t.as<TensorTypeNode>()) {
std::vector<int64_t> dims;
for (auto dim : tt_node->shape) {
auto int_node = dim.as<tvm::ir::IntImm>();
CHECK(int_node) << "expected concrete dimensions";
dims.push_back(int_node->value);
}
DLDataType dtype;
DLContext context;
switch (tt_node->dtype.code()) {
case halideir_type_int:
dtype.code = kDLInt;
break;
case halideir_type_uint:
dtype.code = kDLUInt;
break;
case halideir_type_float:
dtype.code = kDLFloat;
break;
default:
throw dmlc::Error("can not convert HalideIR type into DLTensor dtype");
}
dtype.bits = tt_node->dtype.bits();
dtype.lanes = tt_node->dtype.lanes();
// TODO(@jroesch): Is this the right place to place the tensor?
context.device_type = DLDeviceType::kDLCPU;
context.device_id = 0;
runtime::NDArray data = NDArray::Empty(dims, dtype, context);
return TensorValueNode::make(data);
} else {
LOG(FATAL) << "expected a tensor type";
return TensorValue();
}
}
TVM_REGISTER_API("relay._make.TensorValue")
.set_body([](TVMArgs args, TVMRetValue* ret) {
runtime::NDArray data = args[0];
*ret = TensorValueNode::make(data);
});
/* Evaluator Implementation. */
struct EvalError : dmlc::Error {
explicit EvalError(const std::string& msg) : Error(msg) {}
};
/*!
* \brief A stack frame in the Relay interpreter.
*
* Contains a mapping from relay::Var to relay::Value.
*/
struct Frame {
/*! \brief The set of local variables and arguments for the frame. */
tvm::Map<Var, Value> locals;
explicit Frame(tvm::Map<Var, Value> locals) : locals(locals) {}
};
/*!
* \brief The call stack in the Relay interpreter.
*
* Contains a stack of frames; each corresponding to
* a function call.
*/
struct Stack {
/*! \brief The stack frames. */
std::vector<Frame> frames;
Stack() : frames() { frames.push_back(Frame({})); }
Frame& current_frame() { return frames.back(); }
Value Lookup(const Var& local) {
for (auto frame = frames.rbegin(); frame != frames.rend(); frame++) {
auto elem = frame->locals.find(local);
if (elem != frame->locals.end()) {
return (*elem).second;
}
}
LOG(FATAL) << "could not find variable binding for " << local
<< "address= " << local.operator->();
return Value();
}
/*!
* A wrapper around Frame to add RAII semantics to pushing and popping
* stack frames.
*/
struct LocalFrame {
Stack& st;
explicit LocalFrame(Stack& st, const Frame& fr) : st(st) {
st.frames.push_back(fr);
}
~LocalFrame() { st.frames.pop_back(); }
};
};
/*! \brief The equal comparator for expressions. */
struct ExprEqual {
bool operator()(const Expr& a, const Expr& b) const {
return AlphaEqual(a, b);
}
};
struct Interpreter : ExprFunctor<Value(const Expr& n)> {
Environment env;
Stack stack;
using JitKey = Function;
using OpMap = std::unordered_map<JitKey, PackedFunc, StructuralHash, ExprEqual>;
OpMap operator_map_;
template <typename T>
T with_frame(const Frame& fr, const std::function<T()>& f) {
Stack::LocalFrame lf(stack, fr);
return f();
}
Interpreter(Environment env) : env(env), operator_map_() {}
Interpreter(Environment env, OpMap operator_map) : env(env), operator_map_(operator_map) {}
void extend(const Var& id, Value v) {
this->stack.current_frame().locals.Set(id, v);
}
inline Value Lookup(const Var& local) {
return this->stack.Lookup(local);
}
Value Eval(const Expr& expr) {
return (*this)(expr);
}
Value VisitExpr(const Expr& expr) override {
RELAY_LOG(INFO) << "VisitExpr: " << expr << std::endl;
auto ret = ExprFunctor<Value(const Expr& n)>::VisitExpr(expr);
return ret;
}
Value VisitExpr_(const VarNode* var_node) override {
return Lookup(GetRef<Var>(var_node));
}
Value VisitExpr_(const GlobalVarNode* op) override {
return Eval(this->env->Lookup(GetRef<GlobalVar>(op)));
}
Value VisitExpr_(const OpNode* id) override {
// TODO(@jroesch): Eta-expand and return in this case.
throw EvalError(
"internal error, need to wrap intrinsic into call synthetic call node "
"in "
"this case, eta expand");
}
Value VisitExpr_(const ConstantNode* op) override {
return TensorValueNode::make(op->data);
}
Value VisitExpr_(const TupleNode* op) override {
std::vector<Value> values;
for (const auto& field : op->fields) {
Value field_value = Eval(field);
values.push_back(field_value);
}
return TupleValueNode::make(values);
}
Value VisitExpr_(const FunctionNode* func_node) override {
auto func = GetRef<Function>(func_node);
tvm::Map<Var, Value> captured_env;
Array<Var> free_vars = FreeVars(func);
for (const auto& var : free_vars) {
captured_env.Set(var, Eval(var));
}
return ClosureNode::make(captured_env, func);
}
inline Value InvokeCompiledOp(PackedFunc func, const Array<Value>& args,
Type ret_type) {
// Marshal the arguments.
auto arg_len = args.size() + 1;
std::vector<TVMValue> values(arg_len);
std::vector<int> codes(arg_len);
TVMArgsSetter setter(values.data(), codes.data());
TVMRetValue ret;
// We need real type information to properly allocate the structure.
for (size_t i = 0; i < args.size(); i++) {
if (const TensorValueNode* tv = args[i].as<TensorValueNode>()) {
setter(i, tv->data);
}
}
// TVM's calling convention is that the final argument is the output
// buffer. To preserve the illusion of being a functional language
// we need to allocate space for the output buffer based on the
// return type.
CHECK(ret_type.as<TensorTypeNode>());
auto out_tensor = TensorValueNode::FromType(ret_type);
setter(arg_len - 1, out_tensor->data);
func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &ret);
return out_tensor;
}
Value Invoke(const Closure& closure, const tvm::Array<Value>& args) {
// Get a reference to the function inside the closure.
auto func = closure->func;
auto compiled = operator_map_.find(func);
tvm::Array<Function> funcs;
for (auto op : operator_map_) {
funcs.push_back(op.first);
}
// This case we know we have precompiled the operator.
if (compiled != operator_map_.end()) {
auto func_ty = func->func_type_annotation();
return InvokeCompiledOp(compiled->second, args, func_ty->ret_type);
}
// Allocate a frame with the parameters and free variables.
tvm::Map<Var, Value> locals;
CHECK_EQ(func->params.size(), args.size());
for (size_t i = 0; i < func->params.size(); i++) {
CHECK_EQ(locals.count(func->params[i]), 0);
locals.Set(func->params[i], args[i]);
}
// Add the var to value mappings from the Closure's environment.
for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
CHECK_EQ(locals.count((*it).first), 0);
locals.Set((*it).first, (*it).second);
}
return with_frame<Value>(Frame(locals), [&]() { return Eval(func->body); });
}
Value VisitExpr_(const CallNode* call) override {
tvm::Array<Value> args;
for (auto arg : call->args) {
args.push_back(Eval(arg));
}
// We should not find operators after running fusion,
// and operator lowering.
//
// We have some functions containing chunks of operators
// which will be loaded into the operator map.
if (auto op_node = call->op.as<OpNode>()) {
LOG(FATAL) << "found " << op_node->name
<< "; operators should be removed by future passes; try "
"fusing and lowering";
}
// Now we just evaluate and expect to find a closure.
Value fn_val = Eval(call->op);
if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
auto closure = GetRef<Closure>(closure_node);
return this->Invoke(closure, args);
} else {
throw EvalError(
"internal error: type error, expected function value in the call "
"position");
}
}
Value VisitExpr_(const LetNode* op) override {
auto value = Eval(op->value);
this->extend(op->var, value);
return Eval(op->body);
}
Value VisitExpr_(const TupleGetItemNode* op) override {
Value val = Eval(op->tuple);
auto product_node = val.as<TupleValueNode>();
CHECK(product_node)
<< "interal error: when evaluating TupleGetItem expected a tuple value";
CHECK_LT(static_cast<size_t>(op->index), product_node->fields.size())
<< "internal error: index out of bounds";
return product_node->fields[op->index];
}
Value VisitExpr_(const IfNode* op) override {
Value v = Eval(op->cond);
if (const TensorValueNode* bv = v.as<TensorValueNode>()) {
// TODO(@jroesch, @MK): Refactor code into helper from DCE.
if (reinterpret_cast<uint8_t*>(bv->data->data)[0]) {
return Eval(op->true_branch);
} else {
return Eval(op->false_branch);
}
} else {
throw EvalError("type error, type system should have caught this");
}
}
};
Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
Interpreter::OpMap op_map;
auto lowered_ops = LowerOps(env, e);
RELAY_LOG(INFO) << "LoweredFuncs: " << lowered_ops << std::endl;
if (lowered_ops.size()) {
const PackedFunc* fbuild_ptr = Registry::Get("relay.op.compiler._build");
CHECK(fbuild_ptr) << "Could not find registered function: relay.op.compiler._build";
auto fbuild = *fbuild_ptr;
// Collect the set of lowered functions to build a module.
Array<LoweredFunc> lowered_funcs;
for (auto lop : lowered_ops) {
lowered_funcs.push_back(lop->lowered_func);
}
Module module = fbuild(lowered_funcs);
// Loop over the lowered operations to map them into the operator map.
for (auto lop : lowered_ops) {
Function func = lop->func;
LoweredFunc lf = lop->lowered_func;
RELAY_LOG(INFO) << "LoweredFunc: " << lf->name << std::endl;
auto op_impl = module.GetFunction(lf->name);
op_map.insert({func, op_impl});
}
}
return op_map;
}
Value Evaluate(Environment env, Expr e) {
auto op_map = CompileOperators(env, e);
Interpreter interp(env, op_map);
return interp.Eval(e);
}
TVM_REGISTER_API("relay._interpreter.evaluate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Environment env = args[0];
Expr expr = args[1];
*ret = Evaluate(env, expr);
});
} // namespace relay
} // namespace tvm


@ -66,3 +66,5 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
} // namespace relay
} // namespace tvm


@ -49,9 +49,16 @@ void EnvironmentNode::Add(const GlobalVar& var,
<< "Environment#update changes type, not possible in this mode.";
}
this->functions.Set(var, checked_func);
// set gloval var map
CHECK(!global_var_map_.count(var->name_hint))
<< "Duplicate global function name " << var->name_hint;
auto it = global_var_map_.find(var->name_hint);
if (it != global_var_map_.end()) {
CHECK_EQ((*it).second, var);
} else {
// set global var map
CHECK(!global_var_map_.count(var->name_hint))
<< "Duplicate global function name " << var->name_hint;
}
global_var_map_.Set(var->name_hint, var);
}
@ -94,7 +101,7 @@ TVM_REGISTER_API("relay._make.Environment")
TVM_REGISTER_API("relay._env.Environment_Add")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
env->Add(args[1], args[2], false);
env->Add(args[1], args[2], args[3]);
});
TVM_REGISTER_API("relay._env.Environment_GetGlobalVar")


@ -26,7 +26,10 @@ TVM_REGISTER_API("relay._make.Constant")
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstantNode>([](const ConstantNode* node, tvm::IRPrinter* p) {
p->stream << "Constant(TODO)";
const PackedFunc* fprint = Registry::Get("relay._constant_repr");
CHECK(fprint) << "unable to find printing function for constants";
std::string data = (*fprint)(GetRef<Constant>(node));
p->stream << "Constant(" << data << ")";
});
TensorType ConstantNode::tensor_type() const {
@ -104,12 +107,14 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
Function FunctionNode::make(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> type_params) {
tvm::Array<TypeVar> type_params,
tvm::Attrs attrs) {
NodePtr<FunctionNode> n = make_node<FunctionNode>();
n->params = std::move(params);
n->body = std::move(body);
n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params);
n->attrs = std::move(attrs);
return Function(n);
}
@ -121,6 +126,39 @@ FuncType FunctionNode::func_type_annotation() const {
return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
}
NodeRef FunctionGetAttr(const Function& func, const std::string& key) {
if (!func->attrs.defined()) { return NodeRef(); }
const DictAttrsNode* dict_attrs = func->attrs.as<DictAttrsNode>();
CHECK(dict_attrs);
auto it = dict_attrs->dict.find(key);
if (it != dict_attrs->dict.end()) {
return (*it).second;
} else {
return NodeRef();
}
}
Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data) {
const DictAttrsNode* dattrs = func->attrs.as<DictAttrsNode>();
Attrs func_attrs;
if (dattrs) {
Map<std::string, NodeRef> dict = dattrs->dict;
dict.Set(key, data);
func_attrs = DictAttrsNode::make(dict);
} else {
Map<std::string, NodeRef> dict = {{key, data}};
func_attrs = DictAttrsNode::make(dict);
}
return FunctionNode::make(
func->params,
func->body,
func->ret_type,
func->type_params,
func_attrs);
}
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_API("relay._make.Function")
@ -132,7 +170,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionNode>([](const FunctionNode* node,
tvm::IRPrinter* p) {
p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
<< ", " << node->body << ", " << node->type_params << ")";
<< ", " << node->body << ", " << node->type_params << ", "
<< node->attrs << ")";
});
Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,


@ -92,7 +92,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return FunctionNode::make(params, body, ret_type, ty_params);
return FunctionNode::make(params, body, ret_type, ty_params, op->attrs);
}
}
@ -198,6 +198,7 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) {
void ExprVisitor::VisitExpr_(const CallNode* op) {
this->VisitExpr(op->op);
for (auto ty_arg : op->type_args) {
this->VisitType(ty_arg);
}


@ -285,11 +285,11 @@ class RelayHashHandler:
int var_counter = 0;
};
size_t StructuralHash(const Type& type) {
size_t StructuralHash::operator()(const Type& type) const {
return RelayHashHandler().TypeHash(type);
}
size_t StructuralHash(const Expr& expr) {
size_t StructuralHash::operator()(const Expr& expr) const {
return RelayHashHandler().ExprHash(expr);
}


@ -0,0 +1,86 @@
/*!
* Copyright (c) 2018 by Contributors
*
* \file src/tvm/relay/pass/fuse_ops.cc
*
* \brief Fuse eligible sequences of Relay operators into a single operator.
*
*/
#include <tvm/relay/pass.h>
#include <tvm/runtime/module.h>
#include <tvm/lowered_func.h>
#include <tvm/operation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/logging.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
using namespace runtime;
struct AbstractFusableOps : ExprMutator {
Environment env;
Array<GlobalVar> fusable_funcs;
int counter = 0;
size_t expr_hash;
AbstractFusableOps(Environment env, size_t expr_hash) : env(env), expr_hash(expr_hash) {}
Expr VisitExpr_(const CallNode* call) {
if (auto op_node = call->op.as<OpNode>()) {
// Placeholder fusion algorithm which abstracts
// single definitions into functions only.
Array<Var> params;
Array<Expr> inner_args;
Array<Expr> args;
int param_number = 0;
for (auto arg : call->args) {
auto name = std::string("p") + std::to_string(param_number++);
auto type = arg->checked_type();
auto var = VarNode::make(name, type);
params.push_back(var);
inner_args.push_back(var);
args.push_back(VisitExpr(arg));
}
auto body = CallNode::make(call->op, inner_args, call->attrs);
auto func = FunctionNode::make(params, body, call->checked_type(), {});
func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
std::string func_name = "fused_";
func_name += op_node->name;
func_name += "_";
func_name += std::to_string(counter++);
func_name += "_";
func_name += std::to_string(expr_hash);
auto gv = GlobalVarNode::make(func_name);
env->Add(gv, func);
fusable_funcs.push_back(gv);
return CallNode::make(gv, args, Attrs());
} else {
return ExprMutator::VisitExpr_(call);
}
}
};
Expr FuseOps(const Environment& env, const Expr& e) {
// First we convert all chains of fusable ops into
// abstracted functions which we mark as primitive;
// then we convert these primitive functions into
// new operators.
auto abstract = AbstractFusableOps(env, StructuralHash()(e));
auto abstracted_e = abstract.VisitExpr(e);
RELAY_LOG(INFO) << "FuseOps: before=" << e
<< "Fuse: after=" << abstracted_e;
return abstracted_e;
}
TVM_REGISTER_API("relay._ir_pass.FuseOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FuseOps(args[1], args[0]);
});
} // namespace relay
} // namespace tvm

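A Python-level sketch of what the placeholder fusion does (an editor's illustration, not part of this diff): each operator call is abstracted into a fresh global function marked Primitive, and the original call site is rewritten to call that global. The generated name is illustrative.

from tvm import relay
from tvm.relay import ir_pass
from tvm.relay.op import add
from tvm.relay.env import Environment

env = Environment({})
x = relay.var('x', shape=())
y = relay.var('y', shape=())
f = relay.Function([x, y], add(x, y))
f = ir_pass.infer_type(f, env)
fused = ir_pass.fuse_ops(env, f)
# fused now calls a generated global such as fused_add_0_<hash>, whose body is
# fn (p0, p1) { add(p0, p1) } with the "Primitive" attribute set to 1.
print(fused)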
src/relay/pass/lower_ops.cc (new file, 222 lines)

@ -0,0 +1,222 @@
/*!
* Copyright (c) 2018 by Contributors
*
* \file src/tvm/relay/pass/lower_ops.cc
*
* \brief Lower a Relay program to set of TVM operators.
*
*/
#include <tvm/lowered_func.h>
#include <tvm/operation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/logging.h>
#include <tvm/relay/pass.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/build_module.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
using namespace runtime;
LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) {
auto node = make_node<LoweredOpNode>();
node->func = func;
node->lowered_func = lowered_func;
return LoweredOp(node);
}
struct AbstractLocalFunctions : ExprMutator {
Environment env;
size_t expr_hash;
int counter = 0;
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
explicit AbstractLocalFunctions(Environment env)
: env(env), expr_hash(0), counter(0), visited_funcs() {}
Expr Abstract(const Expr& e) {
expr_hash = StructuralHash()(e);
return VisitExpr(e);
}
Expr VisitExpr_(const GlobalVarNode* gvar_node) final {
auto gvar = GetRef<GlobalVar>(gvar_node);
auto it = visited_funcs.find(gvar);
if (it == visited_funcs.end()) {
auto func = env->Lookup(gvar);
visited_funcs.insert(gvar);
auto new_func = FunctionNode::make(
func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
env->Update(gvar, new_func);
}
return gvar;
}
Expr VisitExpr_(const FunctionNode* func_node) final {
Function func = GetRef<Function>(func_node);
auto free_vars = FreeVars(func);
Array<Var> params;
for (auto free_var : free_vars) {
auto var = VarNode::make("free_var", free_var->checked_type());
params.push_back(var);
}
std::string abs_func = "abstracted_func_";
abs_func += std::to_string(counter++);
abs_func += std::to_string(expr_hash);
auto gv = GlobalVarNode::make(abs_func);
auto lifted_func = FunctionNode::make(params, func, Type(), {}, {});
env->Add(gv, lifted_func);
Array<Expr> args;
for (auto free_var : free_vars) {
args.push_back(free_var);
}
return CallNode::make(gv, args, {});
}
};
struct LiveFunctions : ExprVisitor {
Environment env;
explicit LiveFunctions(Environment env) : env(env), global_funcs() {}
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
std::unordered_set<GlobalVar, NodeHash, NodeEqual> global_funcs;
void Live(const Expr& e) {
CHECK(!e.as<FunctionNode>())
<< "functions should of been transformed away by previous pass";
VisitExpr(e);
}
void VisitExpr_(const FunctionNode* func_node) {
LOG(FATAL) << "functions should of been transformed away by previous pass";
}
void VisitExpr_(const GlobalVarNode* var_node) final {
GlobalVar var = GetRef<GlobalVar>(var_node);
auto it = visited_funcs.find(var);
if (it == visited_funcs.end()) {
auto func = env->Lookup(var);
visited_funcs.insert(var);
// The last pass has transformed functions of the form:
//
// let x = fn (p_1, ..., p_n) { ... };
// ...
//
// into, a top-level declaration:
//
// def abs_f(fv_1, ..., fv_n) {
// return (fn (p_1...,p_N) { ... };)
// }
//
// and:
//
// let x = abs_f(fv_1, ... fv_n);
//
// The only other case we can handle is
//
// fn foo(...) { body }
//
// We just search through the body in this case.
if (auto inner_func = func->body.as<FunctionNode>()) {
return VisitExpr(inner_func->body);
} else {
return VisitExpr(func->body);
}
}
}
void VisitExpr_(const CallNode* call) final {
RELAY_LOG(INFO) << "LiveOps: CallNode=" << GetRef<Call>(call);
if (auto gv_node = call->op.as<GlobalVarNode>()) {
GlobalVar gvar = GetRef<GlobalVar>(gv_node);
Function func = env->Lookup(gvar);
auto attr = FunctionGetAttr(func, "Primitive");
if (attr.defined() && Downcast<Integer>(attr)->value == 1) {
global_funcs.insert(gvar);
} else {
VisitExpr(gvar);
}
// Finally we need to ensure to visit all the args no matter what.
for (auto arg : call->args) {
VisitExpr(arg);
}
} else {
return ExprVisitor::VisitExpr_(call);
}
}
};
using FCompute = TypedPackedFunc<Array<Tensor>(
const Attrs&, const Array<Tensor>&, Type, std::string)>;
using FSchedule = TypedPackedFunc<Schedule(const Array<Tensor>&, std::string)>;
/*! \brief Return the set of operators in their TVM format. */
Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
const std::string& target) {
RELAY_LOG(INFO) << "LowerOps: e=" << e;
auto flower_ptr = Registry::Get("relay.op.compiler._lower");
CHECK(flower_ptr);
PackedFunc flower = *flower_ptr;
auto abstracted_e = AbstractLocalFunctions(env).Abstract(e);
auto live_funcs = LiveFunctions(env);
live_funcs.VisitExpr(abstracted_e);
auto schedule_reg = Op::GetAttr<FSchedule>("FTVMSchedule");
auto compute_reg = Op::GetAttr<FCompute>("FTVMCompute");
Array<LoweredOp> lowered_funcs;
for (auto func_name : live_funcs.global_funcs) {
auto func = env->Lookup(func_name);
auto call = Downcast<Call>(func->body);
auto op_node = call->op.as<OpNode>();
CHECK(op_node) << "violated invariant that primitive calls contain a single op call";
auto op = GetRef<Op>(op_node);
RELAY_LOG(INFO) << "LowerOps: Lowering " << op->name;
CHECK(IsPrimitiveOp(op)) << "failed to lower "
<< op->name << "can only lower primitve operations";
Array<Tensor> inputs;
std::string input_name = "in";
int i = 0;
for (auto type_arg : call->type_args) {
auto tt = Downcast<TensorType>(type_arg);
inputs.push_back(PlaceholderOpNode::make(input_name + std::to_string(i),
tt->shape, tt->dtype)
.output(0));
i++;
}
auto output_tt = op->op_type->ret_type;
Array<Tensor> outputs =
compute_reg[op](call->attrs, inputs, output_tt, target);
auto schedule = schedule_reg[op](outputs, target);
size_t hash = StructuralHash()(func);
LoweredFunc lf =
flower(op->name + std::to_string(hash), schedule, inputs, outputs);
func = FunctionSetAttr(func, "LoweredFunc", lf);
env->Add(func_name, func, true);
lowered_funcs.push_back(LoweredOpNode::make(func, lf));
}
return lowered_funcs;
}
TVM_REGISTER_API("relay._ir_pass.LowerOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = LowerOps(args[0], args[1], args[2]);
});
} // namespace relay
} // namespace tvm


@ -298,8 +298,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
auto* fn_ty_node = ftype.as<FuncTypeNode>();
CHECK(fn_ty_node != nullptr)
<< "only expressions with function types can be called, at "
<< call->span;
<< "only expressions with function types can be called, found "
<< ftype << " at " << call->span;
Array<Type> type_args;
FuncType fn_ty = Instantiate(fn_ty_node, &type_args);
@ -505,12 +505,16 @@ Expr TypeInferencer::Infer(Expr expr) {
// Step 1: Solve the constraints.
solver_.Solve();
// Step 2: Attach resolved types to checked_type field.
return Resolver(type_map_, &solver_).VisitExpr(expr);
auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr);
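// Sanity check that attaching resolved types kept the expression well formed.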
CHECK(WellFormed(resolved_expr));
return resolved_expr;
}
Expr InferType(const Expr& expr, const Environment& env) {
return TypeInferencer(env).Infer(expr);
auto e = TypeInferencer(env).Infer(expr);
CHECK(WellFormed(e));
return e;
}
Function InferType(const Function& func,
@ -522,6 +526,7 @@ Function InferType(const Function& func,
Expr func_ret = TypeInferencer(env).Infer(func_copy);
auto map_node = env->functions.CopyOnWrite();
map_node->data.erase(var.node_);
CHECK(WellFormed(func_ret));
return Downcast<Function>(func_ret);
}

View file

@ -3,7 +3,7 @@
*
* \file util.cc
*
* \brief simple util for relay.
* \brief Utility functions for Relay.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>

View file

@ -0,0 +1,80 @@
import numpy as np
from tvm import relay
from tvm.relay.ir_pass import infer_type
from tvm.relay.interpreter import evaluate
from tvm.relay.graph_runtime_codegen import graph_evaluate
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add
from tvm.relay.env import Environment
# @tq, @jr should we put this in testing ns?
def check_rts(env, expr, args, expected_result):
"""
Check that evaluating `expr` applied to `args` produces
`expected_result` on both the evaluator and the TVM graph runtime.
Parameters
----------
expr:
The expression to evaluate.
args: list of Expr
The arguments to supply to the expression.
expected_result:
The expected result of running the expression.
"""
eval_result = evaluate(env, expr, *args)
rts_result = graph_evaluate(env, expr, *args)
np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
def test_add_op_scalar():
"""
Program:
fn (x, y) {
return x + y;
}
"""
env = Environment()
x = relay.var('x', shape=())
y = relay.var('y', shape=())
func = relay.Function([x, y], add(x, y))
x_data = np.array(10.0, dtype='float32')
y_data = np.array(1.0, dtype='float32')
check_rts(env, func, [x_data, y_data], x_data + y_data)
def test_add_op_tensor():
"""
Program:
fn (x, y) {
return x + y;
}
"""
env = Environment()
x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(10, 5))
func = relay.Function([x, y], add(x, y))
x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(10, 5).astype('float32')
check_rts(env, func, [x_data, y_data], x_data + y_data)
def test_add_op_broadcast():
"""
Program:
fn (x, y) {
return x + y;
}
"""
env = Environment()
x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(1, 5))
func = relay.Function([x, y], add(x, y))
x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(1, 5).astype('float32')
check_rts(env, func, [x_data, y_data], x_data + y_data)
if __name__ == "__main__":
test_add_op_scalar()
test_add_op_tensor()
test_add_op_broadcast()

View file

@ -0,0 +1,142 @@
import numpy as np
import tvm
from tvm import relay
from tvm.relay.interpreter import Value, TupleValue, evaluate
from tvm.relay import op
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing
def check_eval(expr, args, expected_result, env=None, rtol=1e-07):
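"""Evaluate `expr` applied to `args` with the interpreter and compare
the result against `expected_result`."""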
if env is None:
env = relay.env.Environment({})
result = evaluate(env, expr, *args)
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
def test_from_scalar():
np.testing.assert_allclose(Value.from_scalar(1, 'int32').asnumpy(), 1)
np.testing.assert_allclose(Value.from_scalar(10.0, 'float32').asnumpy(), 10.0)
np.testing.assert_allclose(Value.from_scalar(True).asnumpy(), True)
def test_tuple_value():
tv = TupleValue(Value.from_scalar(
1), Value.from_scalar(2), Value.from_scalar(3))
np.testing.assert_allclose(tv[0].asnumpy(), 1)
np.testing.assert_allclose(tv[1].asnumpy(), 2)
np.testing.assert_allclose(tv[2].asnumpy(), 3)
def test_id():
x = relay.var('x', 'float32')
ident = relay.Function([x], x)
env = relay.env.Environment({})
res = evaluate(env, ident, 1.0)
check_eval(ident, [1.0], 1.0)
def test_add_const():
two = op.add(relay.const(1), relay.const(1))
func = relay.Function([], two)
check_eval(func, [], 2)
def test_mul_param():
x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(1, 10))
func = relay.Function([x, y], op.multiply(x, y))
x_data = np.random.rand(10, 10).astype('float32')
y_data = np.random.rand(1, 10).astype('float32')
check_eval(func, [x_data, y_data], x_data * y_data)
# failing due to numeric issues
# def test_dense():
# x = relay.var('x', shape=(10, 10))
# w = relay.var('w', shape=(10, 10))
# y = op.nn.dense(x, w)
# func = relay.Function([x, w], y)
# x_data = np.random.rand(10, 10).astype('float32')
# w_data = np.random.rand(10, 10).astype('float32')
# check_eval(func, [x_data, w_data], x_data @ w_data, rtol=0.1)
# def test_linear():
# x = relay.var('x', shape=(10, 10))
# w = relay.var('w', shape=(10, 10))
# b = relay.var('b', shape=(10,))
# y = op.add(op.nn.dense(x, w), b)
# func = relay.Function([x, w, b], y)
# x_data = np.random.rand(10, 10).astype('float32')
# w_data = np.random.rand(10, 10).astype('float32')
# b_data = np.random.rand(10).astype('float32')
# check_eval(func, [x_data, w_data, b_data], x_data @ w_data + b_data)
def test_equal():
i = relay.var('i', shape=[], dtype='int32')
j = relay.var('j', shape=[], dtype='int32')
z = op.equal(i, j)
func = relay.Function([i, j], z, ret_type=relay.TensorType([], 'bool'))
i_data = relay.const(0)
j_data = relay.const(0)
check_eval(func, [i_data, j_data], True)
def test_subtract():
i = relay.var('i', shape=[], dtype='int32')
sub = op.subtract(i, relay.const(1, dtype='int32'))
func = relay.Function([i], sub, ret_type=relay.TensorType([], 'int32'))
i_data = np.array(1, dtype='int32')
check_eval(func, [i_data], 0)
def test_simple_loop():
env = relay.env.Environment({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
sb = ScopeBuilder()
with sb.if_scope(op.equal(i, relay.const(0, dtype='int32'))):
sb.ret(i)
with sb.else_scope():
one_less = op.subtract(i, relay.const(1, dtype='int32'))
rec_call = relay.Call(sum_up, [one_less])
sb.ret(op.add(rec_call, i))
func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
env[sum_up] = func
i_data = np.array(10, dtype='int32')
check_eval(sum_up, [i_data], sum(range(1, 11)), env=env)
def test_loop():
env = relay.env.Environment({})
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
accum = relay.var('accum', shape=[], dtype='int32')
sb = ScopeBuilder()
with sb.if_scope(op.equal(i, relay.const(0))):
sb.ret(accum)
with sb.else_scope():
one_less = op.subtract(i, relay.const(1))
new_accum = op.add(accum, i)
sb.ret(relay.Call(sum_up, [one_less, new_accum]))
func = relay.Function([i, accum], sb.get())
env[sum_up] = func
i_data = np.array(10, dtype='int32')
accum_data = np.array(0, dtype='int32')
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), env=env)
def test_mlp():
pass
# net = testing.mlp.get_workload(1)
# import pdb; pdb.set_trace()
if __name__ == "__main__":
test_id()
test_add_const()
# test_dense()
# test_linear()
test_equal()
test_subtract()
test_simple_loop()
test_loop()
test_mlp()

View file

@ -5,6 +5,16 @@ import tvm
import numpy as np
from tvm.relay.ir_pass import infer_type
from tvm import relay
from tvm.relay import op
from tvm.relay.scope_builder import ScopeBuilder
def assert_has_type(expr, typ, env=relay.env.Environment({})):
checked_expr = infer_type(expr, env)
checked_type = checked_expr.checked_type
if checked_type != typ:
raise RuntimeError("Type mismatch %s vs %s" % (
checked_type, typ))
def test_monomorphic_let():
@ -16,6 +26,31 @@ def test_monomorphic_let():
assert xchecked.checked_type == relay.scalar_type("float64")
def test_single_op():
"Program: fn (x : float32) { let t1 = f(x); t1 }"
x = relay.var('x', shape=[])
func = relay.Function([x], op.log(x))
ttype = relay.TensorType([], dtype='float32')
assert_has_type(func, relay.FuncType([ttype], ttype))
def test_add_broadcast_op():
"""
Program:
fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] {
return x + y;
}
"""
pass
# x = relay.var('x', shape=(10, 4))
# y = relay.var('y', shape=(5, 10, 1))
# z = x + y
# func = relay.Function([x, y], z)
# ttype = relay.TensorType((5, 5, 5), 'float32')
# expected_ty = relay.FuncType([ttype, ttype], ttype)
# assert_has_type(func.to_func(), expected_ty)
def test_dual_op():
"""Program:
fn (x : Tensor[f32, (10, 10)]) {
@ -41,7 +76,6 @@ def test_decl():
return log(x);
}
"""
sb = relay.ScopeBuilder()
tp = relay.TensorType((10, 10))
x = relay.var("x", tp)
f = relay.Function([x], relay.log(x))
@ -76,6 +110,24 @@ def test_recursion():
assert "%3 = @f(%1, %2)" in env.astext()
assert env[f].checked_type == relay.FuncType([ti32, tf32], tf32)
# This currently fails and should pass under the type system.
#
# This test is to illustrate a problem with our weak form of
# unification.
#
def test_incomplete_call():
sb = ScopeBuilder()
x = relay.var('x', dtype='int32')
f = relay.var('f')
func = relay.Function([x, f], relay.Call(f, [x]))
try:
relay.ir_pass.infer_type(func)
assert False
except tvm.TVMError as e:
assert True
def test_tuple():
tp = relay.TensorType((10,))
@ -84,13 +136,13 @@ def test_tuple():
assert (relay.ir_pass.infer_type(res).checked_type ==
relay.TupleType([tp, tp]))
def test_free_expr():
x = relay.var("x", "float32")
y = relay.add(x, x)
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.scalar_type("float32")
def test_type_args():
x = relay.var("x", shape=(10, 10))
y = relay.var("y", shape=(1, 10))
@ -107,6 +159,7 @@ def test_type_args():
assert sh2[0].value == 1
assert sh2[1].value == 10
def test_self_reference():
"""
Program:
@ -117,31 +170,41 @@ def test_self_reference():
a = relay.TypeVar("a")
x = relay.var("x", a)
sb = relay.ScopeBuilder()
f = relay.Function([x], x)
fx = relay.Call(f, [x])
assert relay.ir_pass.infer_type(x).checked_type == a
assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a)
assert relay.ir_pass.infer_type(fx).checked_type == a
def test_global_var_cow_issue():
env = relay.env.Environment({})
gv = relay.GlobalVar("foo")
x = relay.var('x', shape=[])
func = relay.Function([x], relay.Call(gv, [x]), relay.TensorType([], 'float32'))
func = relay.Function([x], relay.Call(gv, [x]),
relay.TensorType([], 'float32'))
env[gv] = func
# They should both point to the same global variable if global variables are
# stable across type checking.
assert gv == func.body.op
def test_equal():
i = relay.var('i', shape=[], dtype='int32')
eq = op.equal(i, relay.const(0, dtype='int32'))
# This should fail ....
func = relay.Function([i], eq, ret_type=relay.TensorType([], 'int32'))
if __name__ == "__main__":
test_free_expr()
test_dual_op()
test_single_op()
test_recursion()
test_monomorphic_let()
test_decl()
test_recursion()
test_tuple()
test_incomplete_call()
test_free_expr()
test_type_args()
test_self_reference()
test_global_var_cow_issue()

View file

@ -1,5 +1,5 @@
#!/bin/bash
export PYTHONPATH=python:apps/extension/python
export PYTHONPATH=python:topi/python:apps/extension/python
export LD_LIBRARY_PATH=build:${LD_LIBRARY_PATH}
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc