Rename relay::Environment to relay::Module (#2054)
This commit is contained in:
Родитель
420ec786e9
Коммит
ead3ac6c23
|
@ -165,7 +165,7 @@ class RelayNode : public Node {
|
||||||
TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
|
TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Environment;
|
struct Module;
|
||||||
|
|
||||||
} // namespace relay
|
} // namespace relay
|
||||||
} // namespace tvm
|
} // namespace tvm
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
#define TVM_RELAY_BUILD_MODULE_H_
|
#define TVM_RELAY_BUILD_MODULE_H_
|
||||||
|
|
||||||
#include <tvm/lowered_func.h>
|
#include <tvm/lowered_func.h>
|
||||||
#include <tvm/relay/environment.h>
|
#include <tvm/relay/module.h>
|
||||||
#include <tvm/relay/expr.h>
|
#include <tvm/relay/expr.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
@ -61,13 +61,13 @@ RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef);
|
||||||
* \note This will do a reachability analysis and lower all definitions
|
* \note This will do a reachability analysis and lower all definitions
|
||||||
* reachable from the provided expression.
|
* reachable from the provided expression.
|
||||||
*
|
*
|
||||||
* \param env The environment.
|
* \param mod The module.
|
||||||
* \param expr The expression with operations to be lowered.
|
* \param expr The expression with operations to be lowered.
|
||||||
* \param target The target to lower the functions to.
|
* \param target The target to lower the functions to.
|
||||||
*
|
*
|
||||||
* \return The set of lowered operations.
|
* \return The set of lowered operations.
|
||||||
*/
|
*/
|
||||||
Array<LoweredOp> LowerOps(const Environment& env, const Expr& expr,
|
Array<LoweredOp> LowerOps(const Module& mod, const Expr& expr,
|
||||||
const std::string& target = "llvm");
|
const std::string& target = "llvm");
|
||||||
|
|
||||||
} // namespace relay
|
} // namespace relay
|
||||||
|
|
|
@ -160,7 +160,7 @@ class VarNode : public ExprNode {
|
||||||
RELAY_DEFINE_NODE_REF(Var, VarNode, Expr);
|
RELAY_DEFINE_NODE_REF(Var, VarNode, Expr);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Global variable that leaves in the top-level environment.
|
* \brief Global variable that leaves in the top-level module.
|
||||||
* This is used to enable recursive calls between function.
|
* This is used to enable recursive calls between function.
|
||||||
*
|
*
|
||||||
* \note A GlobalVar may only point to functions.
|
* \note A GlobalVar may only point to functions.
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
* \brief An interpreter for Relay.
|
* \brief An interpreter for Relay.
|
||||||
*
|
*
|
||||||
* This file implements a simple reference interpreter for Relay programs.
|
* This file implements a simple reference interpreter for Relay programs.
|
||||||
* Given a Relay environment, and a Relay expression it produces a value.
|
* Given a Relay module, and a Relay expression it produces a value.
|
||||||
*
|
*
|
||||||
* The interpreter's values are a naive representation of the values that
|
* The interpreter's values are a naive representation of the values that
|
||||||
* can be produced by a Relay program and are exposed via tvm::Node's
|
* can be produced by a Relay program and are exposed via tvm::Node's
|
||||||
|
@ -16,7 +16,7 @@
|
||||||
#ifndef TVM_RELAY_INTERPRETER_H_
|
#ifndef TVM_RELAY_INTERPRETER_H_
|
||||||
#define TVM_RELAY_INTERPRETER_H_
|
#define TVM_RELAY_INTERPRETER_H_
|
||||||
|
|
||||||
#include <tvm/relay/environment.h>
|
#include <tvm/relay/module.h>
|
||||||
#include <tvm/relay/expr.h>
|
#include <tvm/relay/expr.h>
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
|
@ -39,7 +39,7 @@ class Value;
|
||||||
* Our intent is that this will never be the most efficient implementation of
|
* Our intent is that this will never be the most efficient implementation of
|
||||||
* Relay's semantics, but a readable and clear one.
|
* Relay's semantics, but a readable and clear one.
|
||||||
*/
|
*/
|
||||||
Value Evaluate(Environment env, Expr e);
|
Value Evaluate(Module mod, Expr e);
|
||||||
|
|
||||||
/*! \brief The base container type of Relay values. */
|
/*! \brief The base container type of Relay values. */
|
||||||
class ValueNode : public RelayNode {
|
class ValueNode : public RelayNode {
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
/*!
|
/*!
|
||||||
* Copyright (c) 2018 by Contributors
|
* Copyright (c) 2018 by Contributors
|
||||||
* \file tvm/relay/environment.h
|
* \file tvm/relay/module.h
|
||||||
* \brief The global environment: contains information needed to
|
* \brief The global environment: contains information needed to
|
||||||
* compile & optimize Relay programs.
|
* compile & optimize Relay programs.
|
||||||
*/
|
*/
|
||||||
#ifndef TVM_RELAY_ENVIRONMENT_H_
|
#ifndef TVM_RELAY_MODULE_H_
|
||||||
#define TVM_RELAY_ENVIRONMENT_H_
|
#define TVM_RELAY_MODULE_H_
|
||||||
|
|
||||||
#include <tvm/relay/error.h>
|
#include <tvm/relay/error.h>
|
||||||
#include <tvm/relay/expr.h>
|
#include <tvm/relay/expr.h>
|
||||||
|
@ -17,7 +17,7 @@
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
namespace relay {
|
namespace relay {
|
||||||
|
|
||||||
struct Environment;
|
struct Module;
|
||||||
|
|
||||||
/*! \brief The global environment of Relay programs.
|
/*! \brief The global environment of Relay programs.
|
||||||
*
|
*
|
||||||
|
@ -28,29 +28,29 @@ struct Environment;
|
||||||
* options.
|
* options.
|
||||||
*
|
*
|
||||||
* Many operations require access to the global
|
* Many operations require access to the global
|
||||||
* Environment. We pass the Environment by value
|
* Module. We pass the Module by value
|
||||||
* in a functional style as an explicit argument,
|
* in a functional style as an explicit argument,
|
||||||
* but we mutate the Environment while optimizing
|
* but we mutate the Module while optimizing
|
||||||
* Relay programs.
|
* Relay programs.
|
||||||
*
|
*
|
||||||
* The functional style allows users to construct custom
|
* The functional style allows users to construct custom
|
||||||
* environments easily, for example each thread can store
|
* environments easily, for example each thread can store
|
||||||
* an Environment while auto-tuning.
|
* an Module while auto-tuning.
|
||||||
* */
|
* */
|
||||||
|
|
||||||
class EnvironmentNode : public RelayNode {
|
class ModuleNode : public RelayNode {
|
||||||
public:
|
public:
|
||||||
/*! \brief A map from ids to all global functions. */
|
/*! \brief A map from ids to all global functions. */
|
||||||
tvm::Map<GlobalVar, Function> functions;
|
tvm::Map<GlobalVar, Function> functions;
|
||||||
|
|
||||||
EnvironmentNode() {}
|
ModuleNode() {}
|
||||||
|
|
||||||
void VisitAttrs(tvm::AttrVisitor* v) final {
|
void VisitAttrs(tvm::AttrVisitor* v) final {
|
||||||
v->Visit("functions", &functions);
|
v->Visit("functions", &functions);
|
||||||
v->Visit("global_var_map_", &global_var_map_);
|
v->Visit("global_var_map_", &global_var_map_);
|
||||||
}
|
}
|
||||||
|
|
||||||
TVM_DLL static Environment make(tvm::Map<GlobalVar, Function> global_funcs);
|
TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Add a function to the global environment.
|
* \brief Add a function to the global environment.
|
||||||
|
@ -100,10 +100,10 @@ class EnvironmentNode : public RelayNode {
|
||||||
* functions in another environment.
|
* functions in another environment.
|
||||||
* \param other The other environment.
|
* \param other The other environment.
|
||||||
*/
|
*/
|
||||||
void Update(const Environment& other);
|
void Update(const Module& other);
|
||||||
|
|
||||||
static constexpr const char* _type_key = "relay.Environment";
|
static constexpr const char* _type_key = "relay.Module";
|
||||||
TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
|
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/*! \brief A map from string names to global variables that
|
/*! \brief A map from string names to global variables that
|
||||||
|
@ -112,18 +112,18 @@ class EnvironmentNode : public RelayNode {
|
||||||
tvm::Map<std::string, GlobalVar> global_var_map_;
|
tvm::Map<std::string, GlobalVar> global_var_map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Environment : public NodeRef {
|
struct Module : public NodeRef {
|
||||||
Environment() {}
|
Module() {}
|
||||||
explicit Environment(NodePtr<tvm::Node> p) : NodeRef(p) {}
|
explicit Module(NodePtr<tvm::Node> p) : NodeRef(p) {}
|
||||||
|
|
||||||
inline EnvironmentNode* operator->() const {
|
inline ModuleNode* operator->() const {
|
||||||
return static_cast<EnvironmentNode*>(node_.get());
|
return static_cast<ModuleNode*>(node_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
using ContainerType = EnvironmentNode;
|
using ContainerType = ModuleNode;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace relay
|
} // namespace relay
|
||||||
} // namespace tvm
|
} // namespace tvm
|
||||||
|
|
||||||
#endif // TVM_RELAY_ENVIRONMENT_H_
|
#endif // TVM_RELAY_MODULE_H_
|
|
@ -6,7 +6,7 @@
|
||||||
#ifndef TVM_RELAY_PASS_H_
|
#ifndef TVM_RELAY_PASS_H_
|
||||||
#define TVM_RELAY_PASS_H_
|
#define TVM_RELAY_PASS_H_
|
||||||
|
|
||||||
#include <tvm/relay/environment.h>
|
#include <tvm/relay/module.h>
|
||||||
#include <tvm/relay/expr.h>
|
#include <tvm/relay/expr.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
@ -21,23 +21,23 @@ namespace relay {
|
||||||
* populated with the result type.
|
* populated with the result type.
|
||||||
*
|
*
|
||||||
* \param expr The expression to type check.
|
* \param expr The expression to type check.
|
||||||
* \param env The environment used for referencing global functions, can be
|
* \param mod The module used for referencing global functions, can be
|
||||||
* None.
|
* None.
|
||||||
*
|
*
|
||||||
* \return A type checked expression with its checked_type field populated.
|
* \return A type checked expression with its checked_type field populated.
|
||||||
*/
|
*/
|
||||||
Expr InferType(const Expr& expr, const Environment& env);
|
Expr InferType(const Expr& expr, const Module& mod);
|
||||||
/*!
|
/*!
|
||||||
* \brief Infer the type of a function as if it is mapped to var in the env.
|
* \brief Infer the type of a function as if it is mapped to var in the mod.
|
||||||
*
|
*
|
||||||
* \param f the function.
|
* \param f the function.
|
||||||
* \param env The environment used for referencing global functions.
|
* \param mod The module used for referencing global functions.
|
||||||
* \param var The global variable corresponding to the function.
|
* \param var The global variable corresponding to the function.
|
||||||
*
|
*
|
||||||
* \return A type checked Function with its checked_type field populated.
|
* \return A type checked Function with its checked_type field populated.
|
||||||
* \note this function mutates env and is not thread-safe.
|
* \note this function mutates mod and is not thread-safe.
|
||||||
*/
|
*/
|
||||||
Function InferType(const Function& f, const Environment& env,
|
Function InferType(const Function& f, const Module& mod,
|
||||||
const GlobalVar& var);
|
const GlobalVar& var);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
|
@ -52,11 +52,11 @@ Function InferType(const Function& f, const Environment& env,
|
||||||
* a data type such as `int`, `float`, `uint`.
|
* a data type such as `int`, `float`, `uint`.
|
||||||
*
|
*
|
||||||
* \param t The type to check.
|
* \param t The type to check.
|
||||||
* \param env The global environment.
|
* \param mod The global module.
|
||||||
*
|
*
|
||||||
* \return true if the rules are satisified otherwise false
|
* \return true if the rules are satisified otherwise false
|
||||||
*/
|
*/
|
||||||
bool KindCheck(const Type& t, const Environment& env);
|
bool KindCheck(const Type& t, const Module& mod);
|
||||||
|
|
||||||
/*! \brief Compare two expressions for structural equivalence.
|
/*! \brief Compare two expressions for structural equivalence.
|
||||||
*
|
*
|
||||||
|
|
|
@ -349,14 +349,14 @@ class TypeRelation;
|
||||||
/*!
|
/*!
|
||||||
* \brief TypeRelation container.
|
* \brief TypeRelation container.
|
||||||
* \note This node is not directly serializable.
|
* \note This node is not directly serializable.
|
||||||
* The type function need to be lookedup in the environment.
|
* The type function need to be lookedup in the module.
|
||||||
*/
|
*/
|
||||||
class TypeRelationNode : public TypeConstraintNode {
|
class TypeRelationNode : public TypeConstraintNode {
|
||||||
public:
|
public:
|
||||||
/*!
|
/*!
|
||||||
* \brief The function on input and output variables which
|
* \brief The function on input and output variables which
|
||||||
* this is not directly serializable,
|
* this is not directly serializable,
|
||||||
* need to be looked-up in the environment.
|
* need to be looked-up in the module.
|
||||||
*/
|
*/
|
||||||
TypeRelationFn func;
|
TypeRelationFn func;
|
||||||
/*! \brief The type arguments to the type function. */
|
/*! \brief The type arguments to the type function. */
|
||||||
|
|
|
@ -5,7 +5,7 @@ from ..api import register_func
|
||||||
from . import base
|
from . import base
|
||||||
from . import ty
|
from . import ty
|
||||||
from . import expr
|
from . import expr
|
||||||
from . import env
|
from . import module
|
||||||
from . import ir_pass
|
from . import ir_pass
|
||||||
from .build_module import build
|
from .build_module import build
|
||||||
from .interpreter import create_executor
|
from .interpreter import create_executor
|
||||||
|
@ -26,7 +26,7 @@ from .scope_builder import ScopeBuilder
|
||||||
Span = base.Span
|
Span = base.Span
|
||||||
|
|
||||||
# Env
|
# Env
|
||||||
Environment = env.Environment
|
Module = module.Module
|
||||||
|
|
||||||
# Type
|
# Type
|
||||||
Type = ty.Type
|
Type = ty.Type
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from .env import Environment
|
from .env import Module
|
||||||
from . import ir
|
from . import ir
|
||||||
|
|
||||||
def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ...
|
def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ...
|
||||||
def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ...
|
def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ...
|
||||||
def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
|
def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
|
||||||
def well_formed(expr: ir.Expr) -> bool: ...
|
def well_formed(expr: ir.Expr) -> bool: ...
|
||||||
def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ...
|
def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ...
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
|
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
|
||||||
"""The interface to the Environment exposed from C++."""
|
"""The interface to the Module exposed from C++."""
|
||||||
from tvm._ffi.function import _init_api
|
from tvm._ffi.function import _init_api
|
||||||
|
|
||||||
_init_api("relay._env", __name__)
|
_init_api("relay._module", __name__)
|
|
@ -2,4 +2,4 @@ from typing import Union, Tuple, Dict, List
|
||||||
from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId
|
from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId
|
||||||
from relay.ir import ShapeExtension, Operator, Defn
|
from relay.ir import ShapeExtension, Operator, Defn
|
||||||
|
|
||||||
class Environment(NodeBase): ...
|
class Module(NodeBase): ...
|
|
@ -5,9 +5,9 @@ from a Relay expression.
|
||||||
from ..build_module import build as tvm_build_module
|
from ..build_module import build as tvm_build_module
|
||||||
from . graph_runtime_codegen import GraphRuntimeCodegen
|
from . graph_runtime_codegen import GraphRuntimeCodegen
|
||||||
from . import ir_pass
|
from . import ir_pass
|
||||||
from .env import Environment
|
from .module import Module
|
||||||
|
|
||||||
def build(func, params=None, target=None, env=None):
|
def build(func, params=None, target=None, mod=None):
|
||||||
"""
|
"""
|
||||||
Compile a single function to the components needed by the
|
Compile a single function to the components needed by the
|
||||||
TVM RTS.
|
TVM RTS.
|
||||||
|
@ -29,15 +29,15 @@ def build(func, params=None, target=None, env=None):
|
||||||
if target is None:
|
if target is None:
|
||||||
target = 'llvm'
|
target = 'llvm'
|
||||||
|
|
||||||
if env is None:
|
if mod is None:
|
||||||
env = Environment({})
|
mod = Module({})
|
||||||
|
|
||||||
comp = GraphRuntimeCodegen(env)
|
comp = GraphRuntimeCodegen(mod)
|
||||||
# NB(@jroesch) This creates lowered functions, and generates names for them
|
# NB(@jroesch) This creates lowered functions, and generates names for them
|
||||||
#
|
#
|
||||||
# We need these names to emit the correct graph as these are names of the
|
# We need these names to emit the correct graph as these are names of the
|
||||||
# functions contained in the module.
|
# functions contained in the module.
|
||||||
lowered_ops = ir_pass.lower_ops(env, func)
|
lowered_ops = ir_pass.lower_ops(mod, func)
|
||||||
mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target)
|
mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target)
|
||||||
|
|
||||||
# Therefore the call to compile must come after.
|
# Therefore the call to compile must come after.
|
||||||
|
|
|
@ -172,7 +172,7 @@ class GlobalVar(Expr):
|
||||||
"""A global variable in Tvm.Relay.
|
"""A global variable in Tvm.Relay.
|
||||||
|
|
||||||
GlobalVar is used to refer to the global functions
|
GlobalVar is used to refer to the global functions
|
||||||
stored in the environment.
|
stored in the module.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
|
|
@ -8,7 +8,7 @@ from . import build_module
|
||||||
from . import _make
|
from . import _make
|
||||||
from . import _interpreter
|
from . import _interpreter
|
||||||
from . import ir_pass
|
from . import ir_pass
|
||||||
from .env import Environment
|
from .module import Module
|
||||||
from .expr import Call, Constant, GlobalVar, Function, const
|
from .expr import Call, Constant, GlobalVar, Function, const
|
||||||
from .scope_builder import ScopeBuilder
|
from .scope_builder import ScopeBuilder
|
||||||
from .._ffi.base import integer_types
|
from .._ffi.base import integer_types
|
||||||
|
@ -90,24 +90,24 @@ def _arg_to_ast(arg):
|
||||||
class Executor(object):
|
class Executor(object):
|
||||||
"""An abstract interface for executing Relay programs."""
|
"""An abstract interface for executing Relay programs."""
|
||||||
|
|
||||||
def __init__(self, env=None):
|
def __init__(self, mod=None):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
env: relay.Environment
|
mod: relay.Module
|
||||||
The environment.
|
The module.
|
||||||
"""
|
"""
|
||||||
if env is None:
|
if mod is None:
|
||||||
self.env = Environment({})
|
self.mod = Module({})
|
||||||
else:
|
else:
|
||||||
self.env = env
|
self.mod = mod
|
||||||
|
|
||||||
|
|
||||||
def optimize(self, expr):
|
def optimize(self, expr):
|
||||||
# TODO: We need to move this optimization code into the optimizer/pass manager
|
# TODO: We need to move this optimization code into the optimizer/pass manager
|
||||||
ck_expr = ir_pass.infer_type(expr, env=self.env)
|
ck_expr = ir_pass.infer_type(expr, mod=self.mod)
|
||||||
fused_expr = ir_pass.fuse_ops(self.env, ck_expr)
|
fused_expr = ir_pass.fuse_ops(self.mod, ck_expr)
|
||||||
ck_fused = ir_pass.infer_type(fused_expr, env=self.env)
|
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
|
||||||
return ck_fused
|
return ck_fused
|
||||||
|
|
||||||
def _make_executor(self, _):
|
def _make_executor(self, _):
|
||||||
|
@ -153,8 +153,8 @@ class Interpreter(Executor):
|
||||||
"""
|
"""
|
||||||
A wrapper around the Relay interpreter, implements the excecutor interface.
|
A wrapper around the Relay interpreter, implements the excecutor interface.
|
||||||
"""
|
"""
|
||||||
def __init__(self, env=None):
|
def __init__(self, mod=None):
|
||||||
Executor.__init__(self, env)
|
Executor.__init__(self, mod)
|
||||||
|
|
||||||
def _make_executor(self, expr):
|
def _make_executor(self, expr):
|
||||||
def _interp_wrapper(*args):
|
def _interp_wrapper(*args):
|
||||||
|
@ -163,28 +163,28 @@ class Interpreter(Executor):
|
||||||
relay_args.append(_arg_to_ast(arg))
|
relay_args.append(_arg_to_ast(arg))
|
||||||
|
|
||||||
if isinstance(expr, GlobalVar):
|
if isinstance(expr, GlobalVar):
|
||||||
func = self.env[expr]
|
func = self.mod[expr]
|
||||||
func = self.optimize(func)
|
func = self.optimize(func)
|
||||||
self.env._add(expr, func, True)
|
self.mod._add(expr, func, True)
|
||||||
opt_expr = Call(expr, relay_args)
|
opt_expr = Call(expr, relay_args)
|
||||||
return _interpreter.evaluate(self.env, opt_expr)
|
return _interpreter.evaluate(self.mod, opt_expr)
|
||||||
else:
|
else:
|
||||||
call = Call(expr, relay_args)
|
call = Call(expr, relay_args)
|
||||||
opt_expr = self.optimize(call)
|
opt_expr = self.optimize(call)
|
||||||
return _interpreter.evaluate(self.env, opt_expr)
|
return _interpreter.evaluate(self.mod, opt_expr)
|
||||||
|
|
||||||
return _interp_wrapper
|
return _interp_wrapper
|
||||||
|
|
||||||
|
|
||||||
class GraphRuntime(Executor):
|
class GraphRuntime(Executor):
|
||||||
"""A wrapper around the TVM graph runtime, implements the Executor interface."""
|
"""A wrapper around the TVM graph runtime, implements the Executor interface."""
|
||||||
def __init__(self, env=None):
|
def __init__(self, mod=None):
|
||||||
Executor.__init__(self, env)
|
Executor.__init__(self, mod)
|
||||||
|
|
||||||
def _make_executor(self, expr):
|
def _make_executor(self, expr):
|
||||||
def _graph_wrapper(*args):
|
def _graph_wrapper(*args):
|
||||||
func = self.optimize(expr)
|
func = self.optimize(expr)
|
||||||
graph_json, mod, params = build_module.build(func, env=self.env)
|
graph_json, mod, params = build_module.build(func, mod=self.mod)
|
||||||
assert params is None
|
assert params is None
|
||||||
gmodule = tvm_runtime.create(graph_json, mod, cpu(0))
|
gmodule = tvm_runtime.create(graph_json, mod, cpu(0))
|
||||||
# Create map of inputs.
|
# Create map of inputs.
|
||||||
|
@ -199,10 +199,10 @@ class GraphRuntime(Executor):
|
||||||
|
|
||||||
return _graph_wrapper
|
return _graph_wrapper
|
||||||
|
|
||||||
def create_executor(mode='debug', env=None):
|
def create_executor(mode='debug', mod=None):
|
||||||
if mode == 'debug':
|
if mode == 'debug':
|
||||||
return Interpreter(env)
|
return Interpreter(mod)
|
||||||
elif mode == 'graph':
|
elif mode == 'graph':
|
||||||
return GraphRuntime(env)
|
return GraphRuntime(mod)
|
||||||
else:
|
else:
|
||||||
raise Exception("unknown mode {0}".format(mode))
|
raise Exception("unknown mode {0}".format(mode))
|
||||||
|
|
|
@ -11,16 +11,16 @@ from .expr import Expr
|
||||||
from .ty import Type
|
from .ty import Type
|
||||||
|
|
||||||
|
|
||||||
def infer_type(expr, env=None):
|
def infer_type(expr, mod=None):
|
||||||
"""Infer the type of expr under the context of env.
|
"""Infer the type of expr under the context of mod.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
expr: tvm.relay.Expr
|
expr: tvm.relay.Expr
|
||||||
The input expression.
|
The input expression.
|
||||||
|
|
||||||
env: Optional[tvm.relay.Environment]
|
mod: Optional[tvm.relay.Module]
|
||||||
The global environment.
|
The global module.
|
||||||
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
|
@ -28,7 +28,7 @@ def infer_type(expr, env=None):
|
||||||
checked_expr : tvm.relay.Expr
|
checked_expr : tvm.relay.Expr
|
||||||
The checked expression.
|
The checked expression.
|
||||||
"""
|
"""
|
||||||
return _ir_pass.infer_type(expr, env)
|
return _ir_pass.infer_type(expr, mod)
|
||||||
|
|
||||||
|
|
||||||
def backward_fold_scale_axis(expr):
|
def backward_fold_scale_axis(expr):
|
||||||
|
@ -93,7 +93,7 @@ def well_formed(expr):
|
||||||
return _ir_pass.well_formed(expr)
|
return _ir_pass.well_formed(expr)
|
||||||
|
|
||||||
|
|
||||||
def check_kind(t, env=None):
|
def check_kind(t, mod=None):
|
||||||
"""Check that the type is well kinded.
|
"""Check that the type is well kinded.
|
||||||
For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes.
|
For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes.
|
||||||
|
|
||||||
|
@ -102,8 +102,8 @@ def check_kind(t, env=None):
|
||||||
t: tvm.relay.Type
|
t: tvm.relay.Type
|
||||||
The type to check
|
The type to check
|
||||||
|
|
||||||
env: tvm.relay.Environment, optional
|
mod: tvm.relay.Module, optional
|
||||||
The global environment
|
The global module
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
@ -117,8 +117,8 @@ def check_kind(t, env=None):
|
||||||
assert not check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)]))
|
assert not check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)]))
|
||||||
assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)]))
|
assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)]))
|
||||||
"""
|
"""
|
||||||
if env is not None:
|
if mod is not None:
|
||||||
return _ir_pass.check_kind(t, env)
|
return _ir_pass.check_kind(t, mod)
|
||||||
else:
|
else:
|
||||||
return _ir_pass.check_kind(t)
|
return _ir_pass.check_kind(t)
|
||||||
|
|
||||||
|
@ -256,8 +256,8 @@ def structural_hash(value):
|
||||||
"relay.Expr or relay.Type").format(type(value))
|
"relay.Expr or relay.Type").format(type(value))
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
|
|
||||||
def fuse_ops(expr, env):
|
def fuse_ops(expr, mod):
|
||||||
return _ir_pass.FuseOps(env, expr)
|
return _ir_pass.FuseOps(mod, expr)
|
||||||
|
|
||||||
def lower_ops(env, expr, target='llvm'):
|
def lower_ops(mod, expr, target='llvm'):
|
||||||
return _ir_pass.LowerOps(env, expr, target)
|
return _ir_pass.LowerOps(mod, expr, target)
|
||||||
|
|
|
@ -1,18 +1,18 @@
|
||||||
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
|
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
|
||||||
"""A global environment storing everything needed to interpret or compile a Relay program."""
|
"""A global module storing everything needed to interpret or compile a Relay program."""
|
||||||
from .base import register_relay_node, RelayNode
|
from .base import register_relay_node, RelayNode
|
||||||
from .._ffi import base as _base
|
from .._ffi import base as _base
|
||||||
from . import _make
|
from . import _make
|
||||||
from . import _env
|
from . import _module
|
||||||
from . import expr as _expr
|
from . import expr as _expr
|
||||||
|
|
||||||
|
|
||||||
@register_relay_node
|
@register_relay_node
|
||||||
class Environment(RelayNode):
|
class Module(RelayNode):
|
||||||
"""The global Relay environment containing collection of functions.
|
"""The global Relay module containing collection of functions.
|
||||||
|
|
||||||
Each global function is identified by an unique tvm.relay.GlobalVar.
|
Each global function is identified by an unique tvm.relay.GlobalVar.
|
||||||
tvm.relay.GlobalVar and Environment is necessary in order to enable
|
tvm.relay.GlobalVar and Module is necessary in order to enable
|
||||||
recursions in function to avoid cyclic reference in the function.x
|
recursions in function to avoid cyclic reference in the function.x
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
@ -32,10 +32,10 @@ class Environment(RelayNode):
|
||||||
raise TypeError("Expect functions to be Dict[GlobalVar, Function]")
|
raise TypeError("Expect functions to be Dict[GlobalVar, Function]")
|
||||||
mapped_funcs[k] = v
|
mapped_funcs[k] = v
|
||||||
functions = mapped_funcs
|
functions = mapped_funcs
|
||||||
self.__init_handle_by_constructor__(_make.Environment, functions)
|
self.__init_handle_by_constructor__(_make.Module, functions)
|
||||||
|
|
||||||
def __setitem__(self, var, func):
|
def __setitem__(self, var, func):
|
||||||
"""Add a function to the environment.
|
"""Add a function to the module.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
---------
|
---------
|
||||||
|
@ -50,7 +50,7 @@ class Environment(RelayNode):
|
||||||
def _add(self, var, func, update=False):
|
def _add(self, var, func, update=False):
|
||||||
if isinstance(var, _base.string_types):
|
if isinstance(var, _base.string_types):
|
||||||
var = _expr.GlobalVar(var)
|
var = _expr.GlobalVar(var)
|
||||||
return _env.Environment_Add(self, var, func, update)
|
return _module.Module_Add(self, var, func, update)
|
||||||
|
|
||||||
def __getitem__(self, var):
|
def __getitem__(self, var):
|
||||||
"""Lookup a global function by name or by variable.
|
"""Lookup a global function by name or by variable.
|
||||||
|
@ -66,21 +66,21 @@ class Environment(RelayNode):
|
||||||
The function referenced by :code:`var`.
|
The function referenced by :code:`var`.
|
||||||
"""
|
"""
|
||||||
if isinstance(var, _base.string_types):
|
if isinstance(var, _base.string_types):
|
||||||
return _env.Environment_Lookup_str(self, var)
|
return _module.Module_Lookup_str(self, var)
|
||||||
else:
|
else:
|
||||||
return _env.Environment_Lookup(self, var)
|
return _module.Module_Lookup(self, var)
|
||||||
|
|
||||||
def update(self, other):
|
def update(self, other):
|
||||||
"""Insert functions in another Environment to current one.
|
"""Insert functions in another Module to current one.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
other: Environment
|
other: Module
|
||||||
The environment to merge into the current Environment.
|
The module to merge into the current Module.
|
||||||
"""
|
"""
|
||||||
if isinstance(other, dict):
|
if isinstance(other, dict):
|
||||||
other = Environment(other)
|
other = Module(other)
|
||||||
return _env.Environment_Update(self, other)
|
return _module.Module_Update(self, other)
|
||||||
|
|
||||||
def get_global_var(self, name):
|
def get_global_var(self, name):
|
||||||
"""Get a global variable in the function by name.
|
"""Get a global variable in the function by name.
|
||||||
|
@ -99,4 +99,4 @@ class Environment(RelayNode):
|
||||||
------
|
------
|
||||||
tvm.TVMError if we cannot find corresponding global var.
|
tvm.TVMError if we cannot find corresponding global var.
|
||||||
"""
|
"""
|
||||||
return _env.Environment_GetGlobalVar(self, name)
|
return _module.Module_GetGlobalVar(self, name)
|
|
@ -183,7 +183,7 @@ struct ExprEqual {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Interpreter : ExprFunctor<Value(const Expr& n)> {
|
struct Interpreter : ExprFunctor<Value(const Expr& n)> {
|
||||||
Environment env;
|
Module mod;
|
||||||
Stack stack;
|
Stack stack;
|
||||||
using JitKey = Function;
|
using JitKey = Function;
|
||||||
|
|
||||||
|
@ -197,8 +197,8 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
|
||||||
return f();
|
return f();
|
||||||
}
|
}
|
||||||
|
|
||||||
Interpreter(Environment env) : env(env), operator_map_() {}
|
Interpreter(Module mod) : mod(mod), operator_map_() {}
|
||||||
Interpreter(Environment env, OpMap operator_map) : env(env), operator_map_(operator_map) {}
|
Interpreter(Module mod, OpMap operator_map) : mod(mod), operator_map_(operator_map) {}
|
||||||
|
|
||||||
void extend(const Var& id, Value v) {
|
void extend(const Var& id, Value v) {
|
||||||
this->stack.current_frame().locals.Set(id, v);
|
this->stack.current_frame().locals.Set(id, v);
|
||||||
|
@ -223,7 +223,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
|
||||||
}
|
}
|
||||||
|
|
||||||
Value VisitExpr_(const GlobalVarNode* op) override {
|
Value VisitExpr_(const GlobalVarNode* op) override {
|
||||||
return Eval(this->env->Lookup(GetRef<GlobalVar>(op)));
|
return Eval(this->mod->Lookup(GetRef<GlobalVar>(op)));
|
||||||
}
|
}
|
||||||
|
|
||||||
Value VisitExpr_(const OpNode* id) override {
|
Value VisitExpr_(const OpNode* id) override {
|
||||||
|
@ -251,14 +251,14 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
|
||||||
|
|
||||||
Value VisitExpr_(const FunctionNode* func_node) override {
|
Value VisitExpr_(const FunctionNode* func_node) override {
|
||||||
auto func = GetRef<Function>(func_node);
|
auto func = GetRef<Function>(func_node);
|
||||||
tvm::Map<Var, Value> captured_env;
|
tvm::Map<Var, Value> captured_mod;
|
||||||
Array<Var> free_vars = FreeVars(func);
|
Array<Var> free_vars = FreeVars(func);
|
||||||
|
|
||||||
for (const auto& var : free_vars) {
|
for (const auto& var : free_vars) {
|
||||||
captured_env.Set(var, Eval(var));
|
captured_mod.Set(var, Eval(var));
|
||||||
}
|
}
|
||||||
|
|
||||||
return ClosureNode::make(captured_env, func);
|
return ClosureNode::make(captured_mod, func);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline Value InvokeCompiledOp(PackedFunc func, const Array<Value>& args,
|
inline Value InvokeCompiledOp(PackedFunc func, const Array<Value>& args,
|
||||||
|
@ -315,7 +315,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
|
||||||
locals.Set(func->params[i], args[i]);
|
locals.Set(func->params[i], args[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the var to value mappings from the Closure's environment.
|
// Add the var to value mappings from the Closure's modironment.
|
||||||
for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
|
for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
|
||||||
CHECK_EQ(locals.count((*it).first), 0);
|
CHECK_EQ(locals.count((*it).first), 0);
|
||||||
locals.Set((*it).first, (*it).second);
|
locals.Set((*it).first, (*it).second);
|
||||||
|
@ -384,9 +384,9 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
|
Interpreter::OpMap CompileOperators(const Module& mod, const Expr& e) {
|
||||||
Interpreter::OpMap op_map;
|
Interpreter::OpMap op_map;
|
||||||
auto lowered_ops = LowerOps(env, e);
|
auto lowered_ops = LowerOps(mod, e);
|
||||||
RELAY_LOG(INFO) << "LoweredFuncs: " << lowered_ops << std::endl;
|
RELAY_LOG(INFO) << "LoweredFuncs: " << lowered_ops << std::endl;
|
||||||
if (lowered_ops.size()) {
|
if (lowered_ops.size()) {
|
||||||
const PackedFunc* fbuild_ptr = Registry::Get("relay.op.compiler._build");
|
const PackedFunc* fbuild_ptr = Registry::Get("relay.op.compiler._build");
|
||||||
|
@ -399,7 +399,7 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
|
||||||
lowered_funcs.push_back(lop->lowered_func);
|
lowered_funcs.push_back(lop->lowered_func);
|
||||||
}
|
}
|
||||||
|
|
||||||
Module module = fbuild(lowered_funcs);
|
runtime::Module module = fbuild(lowered_funcs);
|
||||||
|
|
||||||
// Loop over the lowered operations to map them into the operator map.
|
// Loop over the lowered operations to map them into the operator map.
|
||||||
for (auto lop : lowered_ops) {
|
for (auto lop : lowered_ops) {
|
||||||
|
@ -415,17 +415,17 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
|
||||||
return op_map;
|
return op_map;
|
||||||
}
|
}
|
||||||
|
|
||||||
Value Evaluate(Environment env, Expr e) {
|
Value Evaluate(Module mod, Expr e) {
|
||||||
auto op_map = CompileOperators(env, e);
|
auto op_map = CompileOperators(mod, e);
|
||||||
Interpreter interp(env, op_map);
|
Interpreter interp(mod, op_map);
|
||||||
return interp.Eval(e);
|
return interp.Eval(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
TVM_REGISTER_API("relay._interpreter.evaluate")
|
TVM_REGISTER_API("relay._interpreter.evaluate")
|
||||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
Environment env = args[0];
|
Module mod = args[0];
|
||||||
Expr expr = args[1];
|
Expr expr = args[1];
|
||||||
*ret = Evaluate(env, expr);
|
*ret = Evaluate(mod, expr);
|
||||||
});
|
});
|
||||||
|
|
||||||
} // namespace relay
|
} // namespace relay
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
/*!
|
/*!
|
||||||
* Copyright (c) 2018 by Contributors
|
* Copyright (c) 2018 by Contributors
|
||||||
* \file environment.cc
|
* \file module.cc
|
||||||
* \brief The global environment in Relay.
|
* \brief The global module in Relay.
|
||||||
*/
|
*/
|
||||||
#include <tvm/relay/environment.h>
|
#include <tvm/relay/module.h>
|
||||||
#include <tvm/relay/pass.h>
|
#include <tvm/relay/pass.h>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
@ -13,8 +13,8 @@ namespace relay {
|
||||||
using tvm::IRPrinter;
|
using tvm::IRPrinter;
|
||||||
using namespace runtime;
|
using namespace runtime;
|
||||||
|
|
||||||
Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
|
Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
|
||||||
auto n = make_node<EnvironmentNode>();
|
auto n = make_node<ModuleNode>();
|
||||||
n->functions = std::move(global_funcs);
|
n->functions = std::move(global_funcs);
|
||||||
|
|
||||||
for (const auto& kv : n->functions) {
|
for (const auto& kv : n->functions) {
|
||||||
|
@ -23,22 +23,22 @@ Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
|
||||||
<< "Duplicate global function name " << kv.first->name_hint;
|
<< "Duplicate global function name " << kv.first->name_hint;
|
||||||
n->global_var_map_.Set(kv.first->name_hint, kv.first);
|
n->global_var_map_.Set(kv.first->name_hint, kv.first);
|
||||||
}
|
}
|
||||||
return Environment(n);
|
return Module(n);
|
||||||
}
|
}
|
||||||
|
|
||||||
GlobalVar EnvironmentNode::GetGlobalVar(const std::string& name) {
|
GlobalVar ModuleNode::GetGlobalVar(const std::string& name) {
|
||||||
auto it = global_var_map_.find(name);
|
auto it = global_var_map_.find(name);
|
||||||
CHECK(it != global_var_map_.end())
|
CHECK(it != global_var_map_.end())
|
||||||
<< "Cannot find global var " << name << " in the Environment";
|
<< "Cannot find global var " << name << " in the Module";
|
||||||
return (*it).second;
|
return (*it).second;
|
||||||
}
|
}
|
||||||
|
|
||||||
void EnvironmentNode::Add(const GlobalVar& var,
|
void ModuleNode::Add(const GlobalVar& var,
|
||||||
const Function& func,
|
const Function& func,
|
||||||
bool update) {
|
bool update) {
|
||||||
// Type check the item before we add it to the environment.
|
// Type check the item before we add it to the modironment.
|
||||||
auto env = GetRef<Environment>(this);
|
auto mod = GetRef<Module>(this);
|
||||||
Function checked_func = InferType(func, env, var);
|
Function checked_func = InferType(func, mod, var);
|
||||||
auto type = checked_func->checked_type();
|
auto type = checked_func->checked_type();
|
||||||
CHECK(type.as<IncompleteTypeNode>() == nullptr);
|
CHECK(type.as<IncompleteTypeNode>() == nullptr);
|
||||||
if (functions.find(var) != functions.end()) {
|
if (functions.find(var) != functions.end()) {
|
||||||
|
@ -46,7 +46,7 @@ void EnvironmentNode::Add(const GlobalVar& var,
|
||||||
<< "Already have definition for " << var->name_hint;
|
<< "Already have definition for " << var->name_hint;
|
||||||
auto old_type = functions[var].as<FunctionNode>()->checked_type();
|
auto old_type = functions[var].as<FunctionNode>()->checked_type();
|
||||||
CHECK(AlphaEqual(type, old_type))
|
CHECK(AlphaEqual(type, old_type))
|
||||||
<< "Environment#update changes type, not possible in this mode.";
|
<< "Module#update changes type, not possible in this mode.";
|
||||||
}
|
}
|
||||||
this->functions.Set(var, checked_func);
|
this->functions.Set(var, checked_func);
|
||||||
|
|
||||||
|
@ -62,79 +62,79 @@ void EnvironmentNode::Add(const GlobalVar& var,
|
||||||
global_var_map_.Set(var->name_hint, var);
|
global_var_map_.Set(var->name_hint, var);
|
||||||
}
|
}
|
||||||
|
|
||||||
void EnvironmentNode::Update(const GlobalVar& var, const Function& func) {
|
void ModuleNode::Update(const GlobalVar& var, const Function& func) {
|
||||||
this->Add(var, func, true);
|
this->Add(var, func, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void EnvironmentNode::Remove(const GlobalVar& var) {
|
void ModuleNode::Remove(const GlobalVar& var) {
|
||||||
auto functions_node = this->functions.CopyOnWrite();
|
auto functions_node = this->functions.CopyOnWrite();
|
||||||
functions_node->data.erase(var.node_);
|
functions_node->data.erase(var.node_);
|
||||||
auto gvar_node = global_var_map_.CopyOnWrite();
|
auto gvar_node = global_var_map_.CopyOnWrite();
|
||||||
gvar_node->data.erase(var->name_hint);
|
gvar_node->data.erase(var->name_hint);
|
||||||
}
|
}
|
||||||
|
|
||||||
Function EnvironmentNode::Lookup(const GlobalVar& var) {
|
Function ModuleNode::Lookup(const GlobalVar& var) {
|
||||||
auto it = functions.find(var);
|
auto it = functions.find(var);
|
||||||
CHECK(it != functions.end())
|
CHECK(it != functions.end())
|
||||||
<< "There is no definition of " << var->name_hint;
|
<< "There is no definition of " << var->name_hint;
|
||||||
return (*it).second;
|
return (*it).second;
|
||||||
}
|
}
|
||||||
|
|
||||||
Function EnvironmentNode::Lookup(const std::string& name) {
|
Function ModuleNode::Lookup(const std::string& name) {
|
||||||
GlobalVar id = this->GetGlobalVar(name);
|
GlobalVar id = this->GetGlobalVar(name);
|
||||||
return this->Lookup(id);
|
return this->Lookup(id);
|
||||||
}
|
}
|
||||||
|
|
||||||
void EnvironmentNode::Update(const Environment& env) {
|
void ModuleNode::Update(const Module& mod) {
|
||||||
for (auto pair : env->functions) {
|
for (auto pair : mod->functions) {
|
||||||
this->Update(pair.first, pair.second);
|
this->Update(pair.first, pair.second);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TVM_REGISTER_NODE_TYPE(EnvironmentNode);
|
TVM_REGISTER_NODE_TYPE(ModuleNode);
|
||||||
|
|
||||||
TVM_REGISTER_API("relay._make.Environment")
|
TVM_REGISTER_API("relay._make.Module")
|
||||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
*ret = EnvironmentNode::make(args[0]);
|
*ret = ModuleNode::make(args[0]);
|
||||||
});
|
});
|
||||||
|
|
||||||
TVM_REGISTER_API("relay._env.Environment_Add")
|
TVM_REGISTER_API("relay._module.Module_Add")
|
||||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
Environment env = args[0];
|
Module mod = args[0];
|
||||||
env->Add(args[1], args[2], args[3]);
|
mod->Add(args[1], args[2], args[3]);
|
||||||
});
|
});
|
||||||
|
|
||||||
TVM_REGISTER_API("relay._env.Environment_GetGlobalVar")
|
TVM_REGISTER_API("relay._module.Module_GetGlobalVar")
|
||||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
Environment env = args[0];
|
Module mod = args[0];
|
||||||
*ret = env->GetGlobalVar(args[1]);
|
*ret = mod->GetGlobalVar(args[1]);
|
||||||
});
|
});
|
||||||
|
|
||||||
TVM_REGISTER_API("relay._env.Environment_Lookup")
|
TVM_REGISTER_API("relay._module.Module_Lookup")
|
||||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
Environment env = args[0];
|
Module mod = args[0];
|
||||||
GlobalVar var = args[1];
|
GlobalVar var = args[1];
|
||||||
*ret = env->Lookup(var);
|
*ret = mod->Lookup(var);
|
||||||
});
|
});
|
||||||
|
|
||||||
TVM_REGISTER_API("relay._env.Environment_Lookup_str")
|
TVM_REGISTER_API("relay._module.Module_Lookup_str")
|
||||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
Environment env = args[0];
|
Module mod = args[0];
|
||||||
std::string var_name = args[1];
|
std::string var_name = args[1];
|
||||||
auto var = env->GetGlobalVar(var_name);
|
auto var = mod->GetGlobalVar(var_name);
|
||||||
*ret = env->Lookup(var);
|
*ret = mod->Lookup(var);
|
||||||
});
|
});
|
||||||
|
|
||||||
TVM_REGISTER_API("relay._env.Environment_Update")
|
TVM_REGISTER_API("relay._module.Module_Update")
|
||||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||||
Environment env = args[0];
|
Module mod = args[0];
|
||||||
env->Update(args[1]);
|
mod->Update(args[1]);
|
||||||
});
|
});
|
||||||
|
|
||||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||||
.set_dispatch<EnvironmentNode>(
|
.set_dispatch<ModuleNode>(
|
||||||
[](const EnvironmentNode *node, tvm::IRPrinter *p) {
|
[](const ModuleNode *node, tvm::IRPrinter *p) {
|
||||||
p->stream << "EnvironmentNode( " << node->functions << ")";
|
p->stream << "ModuleNode( " << node->functions << ")";
|
||||||
});
|
});
|
||||||
|
|
||||||
} // namespace relay
|
} // namespace relay
|
|
@ -3,7 +3,7 @@
|
||||||
* \file text_printer.cc
|
* \file text_printer.cc
|
||||||
* \brief Text printer to print relay in text form.
|
* \brief Text printer to print relay in text form.
|
||||||
*/
|
*/
|
||||||
#include <tvm/relay/environment.h>
|
#include <tvm/relay/module.h>
|
||||||
#include <tvm/relay/expr_functor.h>
|
#include <tvm/relay/expr_functor.h>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include "type_functor.h"
|
#include "type_functor.h"
|
||||||
|
@ -133,8 +133,8 @@ class TextPrinter :
|
||||||
std::string Print(const NodeRef& node) {
|
std::string Print(const NodeRef& node) {
|
||||||
if (node.as<FunctionNode>()) {
|
if (node.as<FunctionNode>()) {
|
||||||
this->PrintFunc(Downcast<Function>(node));
|
this->PrintFunc(Downcast<Function>(node));
|
||||||
} else if (node.as<EnvironmentNode>()) {
|
} else if (node.as<ModuleNode>()) {
|
||||||
this->PrintEnv(Downcast<Environment>(node));
|
this->PrintEnv(Downcast<Module>(node));
|
||||||
} else if (node.as_derived<TypeNode>()) {
|
} else if (node.as_derived<TypeNode>()) {
|
||||||
this->PrintType(Downcast<Type>(node), stream_);
|
this->PrintType(Downcast<Type>(node), stream_);
|
||||||
} else if (node.as_derived<ExprNode>()) {
|
} else if (node.as_derived<ExprNode>()) {
|
||||||
|
@ -158,9 +158,9 @@ class TextPrinter :
|
||||||
stream_ << "\n";
|
stream_ << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void PrintEnv(const Environment& env) {
|
void PrintEnv(const Module& mod) {
|
||||||
int counter = 0;
|
int counter = 0;
|
||||||
for (const auto& kv : env->functions) {
|
for (const auto& kv : mod->functions) {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
if (counter++ != 0) {
|
if (counter++ != 0) {
|
||||||
stream_ << "\n";
|
stream_ << "\n";
|
||||||
|
|
|
@ -20,12 +20,12 @@ namespace relay {
|
||||||
using namespace runtime;
|
using namespace runtime;
|
||||||
|
|
||||||
struct AbstractFusableOps : ExprMutator {
|
struct AbstractFusableOps : ExprMutator {
|
||||||
Environment env;
|
Module mod;
|
||||||
Array<GlobalVar> fusable_funcs;
|
Array<GlobalVar> fusable_funcs;
|
||||||
int counter = 0;
|
int counter = 0;
|
||||||
size_t expr_hash;
|
size_t expr_hash;
|
||||||
|
|
||||||
AbstractFusableOps(Environment env, size_t expr_hash) : env(env), expr_hash(expr_hash) {}
|
AbstractFusableOps(Module mod, size_t expr_hash) : mod(mod), expr_hash(expr_hash) {}
|
||||||
|
|
||||||
Expr VisitExpr_(const CallNode* call) {
|
Expr VisitExpr_(const CallNode* call) {
|
||||||
if (auto op_node = call->op.as<OpNode>()) {
|
if (auto op_node = call->op.as<OpNode>()) {
|
||||||
|
@ -55,7 +55,7 @@ struct AbstractFusableOps : ExprMutator {
|
||||||
func_name += "_";
|
func_name += "_";
|
||||||
func_name += std::to_string(expr_hash);
|
func_name += std::to_string(expr_hash);
|
||||||
auto gv = GlobalVarNode::make(func_name);
|
auto gv = GlobalVarNode::make(func_name);
|
||||||
env->Add(gv, func);
|
mod->Add(gv, func);
|
||||||
fusable_funcs.push_back(gv);
|
fusable_funcs.push_back(gv);
|
||||||
return CallNode::make(gv, args, Attrs());
|
return CallNode::make(gv, args, Attrs());
|
||||||
} else {
|
} else {
|
||||||
|
@ -64,12 +64,12 @@ struct AbstractFusableOps : ExprMutator {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Expr FuseOps(const Environment& env, const Expr& e) {
|
Expr FuseOps(const Module& mod, const Expr& e) {
|
||||||
// First we convert all chains of fusable ops into
|
// First we convert all chains of fusable ops into
|
||||||
// abstracted functions which we mark as primtive
|
// abstracted functions which we mark as primtive
|
||||||
// then we convert these primtive functions into
|
// then we convert these primtive functions into
|
||||||
// new operators.
|
// new operators.
|
||||||
auto abstract = AbstractFusableOps(env, StructuralHash()(e));
|
auto abstract = AbstractFusableOps(mod, StructuralHash()(e));
|
||||||
auto abstracted_e = abstract.VisitExpr(e);
|
auto abstracted_e = abstract.VisitExpr(e);
|
||||||
RELAY_LOG(INFO) << "FuseOps: before=" << e
|
RELAY_LOG(INFO) << "FuseOps: before=" << e
|
||||||
<< "Fuse: after=" << abstracted_e;
|
<< "Fuse: after=" << abstracted_e;
|
||||||
|
|
|
@ -99,7 +99,7 @@ struct KindChecker : TypeVisitor {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
bool KindCheck(const Type& t, const Environment& env) {
|
bool KindCheck(const Type& t, const Module& mod) {
|
||||||
KindChecker kc;
|
KindChecker kc;
|
||||||
return kc.Check(t);
|
return kc.Check(t);
|
||||||
}
|
}
|
||||||
|
@ -107,7 +107,7 @@ bool KindCheck(const Type& t, const Environment& env) {
|
||||||
TVM_REGISTER_API("relay._ir_pass.check_kind")
|
TVM_REGISTER_API("relay._ir_pass.check_kind")
|
||||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
if (args.size() == 1) {
|
if (args.size() == 1) {
|
||||||
*ret = KindCheck(args[0], EnvironmentNode::make({}));
|
*ret = KindCheck(args[0], ModuleNode::make({}));
|
||||||
} else {
|
} else {
|
||||||
*ret = KindCheck(args[0], args[1]);
|
*ret = KindCheck(args[0], args[1]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,12 +28,12 @@ LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct AbstractLocalFunctions : ExprMutator {
|
struct AbstractLocalFunctions : ExprMutator {
|
||||||
Environment env;
|
Module mod;
|
||||||
size_t expr_hash;
|
size_t expr_hash;
|
||||||
int counter = 0;
|
int counter = 0;
|
||||||
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
|
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
|
||||||
explicit AbstractLocalFunctions(Environment env)
|
explicit AbstractLocalFunctions(Module mod)
|
||||||
: env(env), expr_hash(0), counter(0), visited_funcs() {}
|
: mod(mod), expr_hash(0), counter(0), visited_funcs() {}
|
||||||
|
|
||||||
Expr Abstract(const Expr& e) {
|
Expr Abstract(const Expr& e) {
|
||||||
expr_hash = StructuralHash()(e);
|
expr_hash = StructuralHash()(e);
|
||||||
|
@ -44,7 +44,7 @@ struct AbstractLocalFunctions : ExprMutator {
|
||||||
auto gvar = GetRef<GlobalVar>(gvar_node);
|
auto gvar = GetRef<GlobalVar>(gvar_node);
|
||||||
auto it = visited_funcs.find(gvar);
|
auto it = visited_funcs.find(gvar);
|
||||||
if (it == visited_funcs.end()) {
|
if (it == visited_funcs.end()) {
|
||||||
auto func = env->Lookup(gvar);
|
auto func = mod->Lookup(gvar);
|
||||||
visited_funcs.insert(gvar);
|
visited_funcs.insert(gvar);
|
||||||
auto new_func = FunctionNode::make(
|
auto new_func = FunctionNode::make(
|
||||||
func->params,
|
func->params,
|
||||||
|
@ -52,7 +52,7 @@ struct AbstractLocalFunctions : ExprMutator {
|
||||||
func->ret_type,
|
func->ret_type,
|
||||||
func->type_params,
|
func->type_params,
|
||||||
func->attrs);
|
func->attrs);
|
||||||
env->Update(gvar, new_func);
|
mod->Update(gvar, new_func);
|
||||||
}
|
}
|
||||||
return gvar;
|
return gvar;
|
||||||
}
|
}
|
||||||
|
@ -70,7 +70,7 @@ struct AbstractLocalFunctions : ExprMutator {
|
||||||
abs_func += std::to_string(expr_hash);
|
abs_func += std::to_string(expr_hash);
|
||||||
auto gv = GlobalVarNode::make(abs_func);
|
auto gv = GlobalVarNode::make(abs_func);
|
||||||
auto lifted_func = FunctionNode::make(params, func, Type(), {}, {});
|
auto lifted_func = FunctionNode::make(params, func, Type(), {}, {});
|
||||||
env->Add(gv, lifted_func);
|
mod->Add(gv, lifted_func);
|
||||||
Array<Expr> args;
|
Array<Expr> args;
|
||||||
for (auto free_var : free_vars) {
|
for (auto free_var : free_vars) {
|
||||||
args.push_back(free_var);
|
args.push_back(free_var);
|
||||||
|
@ -80,8 +80,8 @@ struct AbstractLocalFunctions : ExprMutator {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LiveFunctions : ExprVisitor {
|
struct LiveFunctions : ExprVisitor {
|
||||||
Environment env;
|
Module mod;
|
||||||
explicit LiveFunctions(Environment env) : env(env), global_funcs() {}
|
explicit LiveFunctions(Module mod) : mod(mod), global_funcs() {}
|
||||||
|
|
||||||
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
|
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
|
||||||
std::unordered_set<GlobalVar, NodeHash, NodeEqual> global_funcs;
|
std::unordered_set<GlobalVar, NodeHash, NodeEqual> global_funcs;
|
||||||
|
@ -100,7 +100,7 @@ struct LiveFunctions : ExprVisitor {
|
||||||
GlobalVar var = GetRef<GlobalVar>(var_node);
|
GlobalVar var = GetRef<GlobalVar>(var_node);
|
||||||
auto it = visited_funcs.find(var);
|
auto it = visited_funcs.find(var);
|
||||||
if (it == visited_funcs.end()) {
|
if (it == visited_funcs.end()) {
|
||||||
auto func = env->Lookup(var);
|
auto func = mod->Lookup(var);
|
||||||
visited_funcs.insert(var);
|
visited_funcs.insert(var);
|
||||||
// The last pass has trasnformed functions of the form:
|
// The last pass has trasnformed functions of the form:
|
||||||
//
|
//
|
||||||
|
@ -134,7 +134,7 @@ struct LiveFunctions : ExprVisitor {
|
||||||
RELAY_LOG(INFO) << "LiveOps: CallNode=" << GetRef<Call>(call);
|
RELAY_LOG(INFO) << "LiveOps: CallNode=" << GetRef<Call>(call);
|
||||||
if (auto gv_node = call->op.as<GlobalVarNode>()) {
|
if (auto gv_node = call->op.as<GlobalVarNode>()) {
|
||||||
GlobalVar gvar = GetRef<GlobalVar>(gv_node);
|
GlobalVar gvar = GetRef<GlobalVar>(gv_node);
|
||||||
Function func = env->Lookup(gvar);
|
Function func = mod->Lookup(gvar);
|
||||||
|
|
||||||
auto attr = FunctionGetAttr(func, "Primitive");
|
auto attr = FunctionGetAttr(func, "Primitive");
|
||||||
|
|
||||||
|
@ -159,15 +159,15 @@ using FCompute = TypedPackedFunc<Array<Tensor>(
|
||||||
using FSchedule = TypedPackedFunc<Schedule(const Array<Tensor>&, std::string)>;
|
using FSchedule = TypedPackedFunc<Schedule(const Array<Tensor>&, std::string)>;
|
||||||
|
|
||||||
/*! \brief Return the set of operators in their TVM format. */
|
/*! \brief Return the set of operators in their TVM format. */
|
||||||
Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
|
Array<LoweredOp> LowerOps(const Module& mod, const Expr& e,
|
||||||
const std::string& target) {
|
const std::string& target) {
|
||||||
RELAY_LOG(INFO) << "LowerOps: e=" << e;
|
RELAY_LOG(INFO) << "LowerOps: e=" << e;
|
||||||
auto flower_ptr = Registry::Get("relay.op.compiler._lower");
|
auto flower_ptr = Registry::Get("relay.op.compiler._lower");
|
||||||
CHECK(flower_ptr);
|
CHECK(flower_ptr);
|
||||||
PackedFunc flower = *flower_ptr;
|
PackedFunc flower = *flower_ptr;
|
||||||
|
|
||||||
auto abstracted_e = AbstractLocalFunctions(env).Abstract(e);
|
auto abstracted_e = AbstractLocalFunctions(mod).Abstract(e);
|
||||||
auto live_funcs = LiveFunctions(env);
|
auto live_funcs = LiveFunctions(mod);
|
||||||
live_funcs.VisitExpr(abstracted_e);
|
live_funcs.VisitExpr(abstracted_e);
|
||||||
|
|
||||||
auto schedule_reg = Op::GetAttr<FSchedule>("FTVMSchedule");
|
auto schedule_reg = Op::GetAttr<FSchedule>("FTVMSchedule");
|
||||||
|
@ -176,7 +176,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
|
||||||
Array<LoweredOp> lowered_funcs;
|
Array<LoweredOp> lowered_funcs;
|
||||||
|
|
||||||
for (auto func_name : live_funcs.global_funcs) {
|
for (auto func_name : live_funcs.global_funcs) {
|
||||||
auto func = env->Lookup(func_name);
|
auto func = mod->Lookup(func_name);
|
||||||
auto call = Downcast<Call>(func->body);
|
auto call = Downcast<Call>(func->body);
|
||||||
auto op_node = call->op.as<OpNode>();
|
auto op_node = call->op.as<OpNode>();
|
||||||
CHECK(op_node) << "violated invariant that primtiive calls contain a single op call";
|
CHECK(op_node) << "violated invariant that primtiive calls contain a single op call";
|
||||||
|
@ -205,7 +205,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
|
||||||
LoweredFunc lf =
|
LoweredFunc lf =
|
||||||
flower(op->name + std::to_string(hash), schedule, inputs, outputs);
|
flower(op->name + std::to_string(hash), schedule, inputs, outputs);
|
||||||
func = FunctionSetAttr(func, "LoweredFunc", lf);
|
func = FunctionSetAttr(func, "LoweredFunc", lf);
|
||||||
env->Add(func_name, func, true);
|
mod->Add(func_name, func, true);
|
||||||
lowered_funcs.push_back(LoweredOpNode::make(func, lf));
|
lowered_funcs.push_back(LoweredOpNode::make(func, lf));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -104,8 +104,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
||||||
// constructors
|
// constructors
|
||||||
TypeInferencer() {
|
TypeInferencer() {
|
||||||
}
|
}
|
||||||
explicit TypeInferencer(Environment env)
|
explicit TypeInferencer(Module mod)
|
||||||
: env_(env) {
|
: mod_(mod) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// inference the type of expr.
|
// inference the type of expr.
|
||||||
|
@ -115,7 +115,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
||||||
// type resolver that maps back to type
|
// type resolver that maps back to type
|
||||||
class Resolver;
|
class Resolver;
|
||||||
// internal environment
|
// internal environment
|
||||||
Environment env_;
|
Module mod_;
|
||||||
// map from expression to checked type
|
// map from expression to checked type
|
||||||
// type inferencer will populate it up
|
// type inferencer will populate it up
|
||||||
std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_;
|
std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_;
|
||||||
|
@ -164,9 +164,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
||||||
|
|
||||||
Type VisitExpr_(const GlobalVarNode* op) final {
|
Type VisitExpr_(const GlobalVarNode* op) final {
|
||||||
GlobalVar var = GetRef<GlobalVar>(op);
|
GlobalVar var = GetRef<GlobalVar>(op);
|
||||||
CHECK(env_.defined())
|
CHECK(mod_.defined())
|
||||||
<< "Cannot do type inference without a global variable";
|
<< "Cannot do type inference without a global variable";
|
||||||
Expr e = env_->Lookup(var);
|
Expr e = mod_->Lookup(var);
|
||||||
return e->checked_type();
|
return e->checked_type();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -511,20 +511,20 @@ Expr TypeInferencer::Infer(Expr expr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
Expr InferType(const Expr& expr, const Environment& env) {
|
Expr InferType(const Expr& expr, const Module& mod) {
|
||||||
auto e = TypeInferencer(env).Infer(expr);
|
auto e = TypeInferencer(mod).Infer(expr);
|
||||||
CHECK(WellFormed(e));
|
CHECK(WellFormed(e));
|
||||||
return e;
|
return e;
|
||||||
}
|
}
|
||||||
|
|
||||||
Function InferType(const Function& func,
|
Function InferType(const Function& func,
|
||||||
const Environment& env,
|
const Module& mod,
|
||||||
const GlobalVar& var) {
|
const GlobalVar& var) {
|
||||||
Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
|
Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
|
||||||
func_copy->checked_type_ = func_copy->func_type_annotation();
|
func_copy->checked_type_ = func_copy->func_type_annotation();
|
||||||
env->functions.Set(var, func_copy);
|
mod->functions.Set(var, func_copy);
|
||||||
Expr func_ret = TypeInferencer(env).Infer(func_copy);
|
Expr func_ret = TypeInferencer(mod).Infer(func_copy);
|
||||||
auto map_node = env->functions.CopyOnWrite();
|
auto map_node = mod->functions.CopyOnWrite();
|
||||||
map_node->data.erase(var.node_);
|
map_node->data.erase(var.node_);
|
||||||
CHECK(WellFormed(func_ret));
|
CHECK(WellFormed(func_ret));
|
||||||
return Downcast<Function>(func_ret);
|
return Downcast<Function>(func_ret);
|
||||||
|
|
|
@ -11,7 +11,7 @@ TEST(Relay, SelfReference) {
|
||||||
auto x = relay::VarNode::make("x", type_a);
|
auto x = relay::VarNode::make("x", type_a);
|
||||||
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, type_b, Array<relay::TypeVar>{});
|
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, type_b, Array<relay::TypeVar>{});
|
||||||
auto fx = relay::CallNode::make(f, Array<relay::Expr>{ x });
|
auto fx = relay::CallNode::make(f, Array<relay::Expr>{ x });
|
||||||
auto type_fx = relay::InferType(fx, relay::EnvironmentNode::make(Map<relay::GlobalVar, relay::Function>{}));
|
auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map<relay::GlobalVar, relay::Function>{}));
|
||||||
CHECK_EQ(type_fx->checked_type(), type_a);
|
CHECK_EQ(type_fx->checked_type(), type_a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,10 +6,10 @@ from tvm.relay.ir_pass import infer_type
|
||||||
from tvm.relay.interpreter import Interpreter
|
from tvm.relay.interpreter import Interpreter
|
||||||
from tvm.relay.scope_builder import ScopeBuilder
|
from tvm.relay.scope_builder import ScopeBuilder
|
||||||
from tvm.relay.op import add
|
from tvm.relay.op import add
|
||||||
from tvm.relay.env import Environment
|
from tvm.relay.module import Module
|
||||||
|
|
||||||
# @tq, @jr should we put this in testing ns?
|
# @tq, @jr should we put this in testing ns?
|
||||||
def check_rts(expr, args, expected_result, env=None):
|
def check_rts(expr, args, expected_result, mod=None):
|
||||||
"""
|
"""
|
||||||
Check that evaluating `expr` applied to the arguments produces
|
Check that evaluating `expr` applied to the arguments produces
|
||||||
`result` on both the evaluator and TVM runtime.
|
`result` on both the evaluator and TVM runtime.
|
||||||
|
@ -25,8 +25,8 @@ def check_rts(expr, args, expected_result, env=None):
|
||||||
expected_result:
|
expected_result:
|
||||||
The expected result of running the expression.
|
The expected result of running the expression.
|
||||||
"""
|
"""
|
||||||
intrp = create_executor('graph', env=env)
|
intrp = create_executor('graph', mod=mod)
|
||||||
graph = create_executor('graph', env=env)
|
graph = create_executor('graph', mod=mod)
|
||||||
eval_result = intrp.evaluate(expr)(*args)
|
eval_result = intrp.evaluate(expr)(*args)
|
||||||
rts_result = graph.evaluate(expr)(*args)
|
rts_result = graph.evaluate(expr)(*args)
|
||||||
np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
|
np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
|
||||||
|
|
|
@ -7,8 +7,8 @@ from tvm.relay.scope_builder import ScopeBuilder
|
||||||
from tvm.relay import testing, create_executor
|
from tvm.relay import testing, create_executor
|
||||||
|
|
||||||
|
|
||||||
def check_eval(expr, args, expected_result, env=None, rtol=1e-07):
|
def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
|
||||||
intrp = create_executor(env=env)
|
intrp = create_executor(mod=mod)
|
||||||
result = intrp.evaluate(expr)(*args)
|
result = intrp.evaluate(expr)(*args)
|
||||||
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
|
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ def test_subtract():
|
||||||
check_eval(func, [i_data], 0)
|
check_eval(func, [i_data], 0)
|
||||||
|
|
||||||
def test_simple_loop():
|
def test_simple_loop():
|
||||||
env = relay.env.Environment({})
|
mod = relay.module.Module({})
|
||||||
sum_up = relay.GlobalVar('sum_up')
|
sum_up = relay.GlobalVar('sum_up')
|
||||||
i = relay.var('i', shape=[], dtype='int32')
|
i = relay.var('i', shape=[], dtype='int32')
|
||||||
sb = ScopeBuilder()
|
sb = ScopeBuilder()
|
||||||
|
@ -98,12 +98,12 @@ def test_simple_loop():
|
||||||
rec_call = relay.Call(sum_up, [one_less])
|
rec_call = relay.Call(sum_up, [one_less])
|
||||||
sb.ret(op.add(rec_call, i))
|
sb.ret(op.add(rec_call, i))
|
||||||
func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
|
func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
|
||||||
env[sum_up] = func
|
mod[sum_up] = func
|
||||||
i_data = np.array(10, dtype='int32')
|
i_data = np.array(10, dtype='int32')
|
||||||
check_eval(sum_up, [i_data], sum(range(1, 11)), env=env)
|
check_eval(sum_up, [i_data], sum(range(1, 11)), mod=mod)
|
||||||
|
|
||||||
def test_loop():
|
def test_loop():
|
||||||
env = relay.env.Environment({})
|
mod = relay.module.Module({})
|
||||||
sum_up = relay.GlobalVar('sum_up')
|
sum_up = relay.GlobalVar('sum_up')
|
||||||
i = relay.var('i', shape=[], dtype='int32')
|
i = relay.var('i', shape=[], dtype='int32')
|
||||||
accum = relay.var('accum', shape=[], dtype='int32')
|
accum = relay.var('accum', shape=[], dtype='int32')
|
||||||
|
@ -115,10 +115,10 @@ def test_loop():
|
||||||
new_accum = op.add(accum, i)
|
new_accum = op.add(accum, i)
|
||||||
sb.ret(relay.Call(sum_up, [one_less, new_accum]))
|
sb.ret(relay.Call(sum_up, [one_less, new_accum]))
|
||||||
func = relay.Function([i, accum], sb.get())
|
func = relay.Function([i, accum], sb.get())
|
||||||
env[sum_up] = func
|
mod[sum_up] = func
|
||||||
i_data = np.array(10, dtype='int32')
|
i_data = np.array(10, dtype='int32')
|
||||||
accum_data = np.array(0, dtype='int32')
|
accum_data = np.array(0, dtype='int32')
|
||||||
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), env=env)
|
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod)
|
||||||
|
|
||||||
def test_mlp():
|
def test_mlp():
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -28,7 +28,7 @@ def test_env():
|
||||||
z = relay.add(x, y)
|
z = relay.add(x, y)
|
||||||
z = relay.add(z, z)
|
z = relay.add(z, z)
|
||||||
f = relay.Function([x, y], z)
|
f = relay.Function([x, y], z)
|
||||||
env = relay.Environment()
|
env = relay.Module()
|
||||||
env["myf"] = f
|
env["myf"] = f
|
||||||
text = env.astext()
|
text = env.astext()
|
||||||
assert "def @myf" in text
|
assert "def @myf" in text
|
||||||
|
|
|
@ -9,8 +9,8 @@ from tvm.relay import op
|
||||||
from tvm.relay.scope_builder import ScopeBuilder
|
from tvm.relay.scope_builder import ScopeBuilder
|
||||||
|
|
||||||
|
|
||||||
def assert_has_type(expr, typ, env=relay.env.Environment({})):
|
def assert_has_type(expr, typ, mod=relay.module.Module({})):
|
||||||
checked_expr = infer_type(expr, env)
|
checked_expr = infer_type(expr, mod)
|
||||||
checked_type = checked_expr.checked_type
|
checked_type = checked_expr.checked_type
|
||||||
if checked_type != typ:
|
if checked_type != typ:
|
||||||
raise RuntimeError("Type mismatch %s vs %s" % (
|
raise RuntimeError("Type mismatch %s vs %s" % (
|
||||||
|
@ -105,10 +105,10 @@ def test_recursion():
|
||||||
sb.ret(data)
|
sb.ret(data)
|
||||||
with sb.else_scope():
|
with sb.else_scope():
|
||||||
sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
|
sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
|
||||||
env = relay.Environment()
|
mod = relay.Module()
|
||||||
env[f] = relay.Function([n, data], sb.get())
|
mod[f] = relay.Function([n, data], sb.get())
|
||||||
assert "%3 = @f(%1, %2)" in env.astext()
|
assert "%3 = @f(%1, %2)" in mod.astext()
|
||||||
assert env[f].checked_type == relay.FuncType([ti32, tf32], tf32)
|
assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32)
|
||||||
|
|
||||||
# This currently fails and should pass under the type system.
|
# This currently fails and should pass under the type system.
|
||||||
#
|
#
|
||||||
|
@ -179,12 +179,12 @@ def test_self_reference():
|
||||||
|
|
||||||
|
|
||||||
def test_global_var_cow_issue():
|
def test_global_var_cow_issue():
|
||||||
env = relay.env.Environment({})
|
mod = relay.Module({})
|
||||||
gv = relay.GlobalVar("foo")
|
gv = relay.GlobalVar("foo")
|
||||||
x = relay.var('x', shape=[])
|
x = relay.var('x', shape=[])
|
||||||
func = relay.Function([x], relay.Call(gv, [x]),
|
func = relay.Function([x], relay.Call(gv, [x]),
|
||||||
relay.TensorType([], 'float32'))
|
relay.TensorType([], 'float32'))
|
||||||
env[gv] = func
|
mod[gv] = func
|
||||||
|
|
||||||
|
|
||||||
def test_equal():
|
def test_equal():
|
||||||
|
|
Загрузка…
Ссылка в новой задаче