Rename relay::Environment to relay::Module (#2054)
This commit is contained in:
Родитель
420ec786e9
Коммит
ead3ac6c23
|
@ -165,7 +165,7 @@ class RelayNode : public Node {
|
|||
TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
|
||||
};
|
||||
|
||||
struct Environment;
|
||||
struct Module;
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
#define TVM_RELAY_BUILD_MODULE_H_
|
||||
|
||||
#include <tvm/lowered_func.h>
|
||||
#include <tvm/relay/environment.h>
|
||||
#include <tvm/relay/module.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <string>
|
||||
|
||||
|
@ -61,13 +61,13 @@ RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef);
|
|||
* \note This will do a reachability analysis and lower all definitions
|
||||
* reachable from the provided expression.
|
||||
*
|
||||
* \param env The environment.
|
||||
* \param mod The module.
|
||||
* \param expr The expression with operations to be lowered.
|
||||
* \param target The target to lower the functions to.
|
||||
*
|
||||
* \return The set of lowered operations.
|
||||
*/
|
||||
Array<LoweredOp> LowerOps(const Environment& env, const Expr& expr,
|
||||
Array<LoweredOp> LowerOps(const Module& mod, const Expr& expr,
|
||||
const std::string& target = "llvm");
|
||||
|
||||
} // namespace relay
|
||||
|
|
|
@ -160,7 +160,7 @@ class VarNode : public ExprNode {
|
|||
RELAY_DEFINE_NODE_REF(Var, VarNode, Expr);
|
||||
|
||||
/*!
|
||||
* \brief Global variable that leaves in the top-level environment.
|
||||
* \brief Global variable that leaves in the top-level module.
|
||||
* This is used to enable recursive calls between function.
|
||||
*
|
||||
* \note A GlobalVar may only point to functions.
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
* \brief An interpreter for Relay.
|
||||
*
|
||||
* This file implements a simple reference interpreter for Relay programs.
|
||||
* Given a Relay environment, and a Relay expression it produces a value.
|
||||
* Given a Relay module, and a Relay expression it produces a value.
|
||||
*
|
||||
* The interpreter's values are a naive representation of the values that
|
||||
* can be produced by a Relay program and are exposed via tvm::Node's
|
||||
|
@ -16,7 +16,7 @@
|
|||
#ifndef TVM_RELAY_INTERPRETER_H_
|
||||
#define TVM_RELAY_INTERPRETER_H_
|
||||
|
||||
#include <tvm/relay/environment.h>
|
||||
#include <tvm/relay/module.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
|
||||
namespace tvm {
|
||||
|
@ -39,7 +39,7 @@ class Value;
|
|||
* Our intent is that this will never be the most efficient implementation of
|
||||
* Relay's semantics, but a readable and clear one.
|
||||
*/
|
||||
Value Evaluate(Environment env, Expr e);
|
||||
Value Evaluate(Module mod, Expr e);
|
||||
|
||||
/*! \brief The base container type of Relay values. */
|
||||
class ValueNode : public RelayNode {
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file tvm/relay/environment.h
|
||||
* \file tvm/relay/module.h
|
||||
* \brief The global environment: contains information needed to
|
||||
* compile & optimize Relay programs.
|
||||
*/
|
||||
#ifndef TVM_RELAY_ENVIRONMENT_H_
|
||||
#define TVM_RELAY_ENVIRONMENT_H_
|
||||
#ifndef TVM_RELAY_MODULE_H_
|
||||
#define TVM_RELAY_MODULE_H_
|
||||
|
||||
#include <tvm/relay/error.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
|
@ -17,7 +17,7 @@
|
|||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
struct Environment;
|
||||
struct Module;
|
||||
|
||||
/*! \brief The global environment of Relay programs.
|
||||
*
|
||||
|
@ -28,29 +28,29 @@ struct Environment;
|
|||
* options.
|
||||
*
|
||||
* Many operations require access to the global
|
||||
* Environment. We pass the Environment by value
|
||||
* Module. We pass the Module by value
|
||||
* in a functional style as an explicit argument,
|
||||
* but we mutate the Environment while optimizing
|
||||
* but we mutate the Module while optimizing
|
||||
* Relay programs.
|
||||
*
|
||||
* The functional style allows users to construct custom
|
||||
* environments easily, for example each thread can store
|
||||
* an Environment while auto-tuning.
|
||||
* an Module while auto-tuning.
|
||||
* */
|
||||
|
||||
class EnvironmentNode : public RelayNode {
|
||||
class ModuleNode : public RelayNode {
|
||||
public:
|
||||
/*! \brief A map from ids to all global functions. */
|
||||
tvm::Map<GlobalVar, Function> functions;
|
||||
|
||||
EnvironmentNode() {}
|
||||
ModuleNode() {}
|
||||
|
||||
void VisitAttrs(tvm::AttrVisitor* v) final {
|
||||
v->Visit("functions", &functions);
|
||||
v->Visit("global_var_map_", &global_var_map_);
|
||||
}
|
||||
|
||||
TVM_DLL static Environment make(tvm::Map<GlobalVar, Function> global_funcs);
|
||||
TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs);
|
||||
|
||||
/*!
|
||||
* \brief Add a function to the global environment.
|
||||
|
@ -100,10 +100,10 @@ class EnvironmentNode : public RelayNode {
|
|||
* functions in another environment.
|
||||
* \param other The other environment.
|
||||
*/
|
||||
void Update(const Environment& other);
|
||||
void Update(const Module& other);
|
||||
|
||||
static constexpr const char* _type_key = "relay.Environment";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
|
||||
static constexpr const char* _type_key = "relay.Module";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
|
||||
|
||||
private:
|
||||
/*! \brief A map from string names to global variables that
|
||||
|
@ -112,18 +112,18 @@ class EnvironmentNode : public RelayNode {
|
|||
tvm::Map<std::string, GlobalVar> global_var_map_;
|
||||
};
|
||||
|
||||
struct Environment : public NodeRef {
|
||||
Environment() {}
|
||||
explicit Environment(NodePtr<tvm::Node> p) : NodeRef(p) {}
|
||||
struct Module : public NodeRef {
|
||||
Module() {}
|
||||
explicit Module(NodePtr<tvm::Node> p) : NodeRef(p) {}
|
||||
|
||||
inline EnvironmentNode* operator->() const {
|
||||
return static_cast<EnvironmentNode*>(node_.get());
|
||||
inline ModuleNode* operator->() const {
|
||||
return static_cast<ModuleNode*>(node_.get());
|
||||
}
|
||||
|
||||
using ContainerType = EnvironmentNode;
|
||||
using ContainerType = ModuleNode;
|
||||
};
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_RELAY_ENVIRONMENT_H_
|
||||
#endif // TVM_RELAY_MODULE_H_
|
|
@ -6,7 +6,7 @@
|
|||
#ifndef TVM_RELAY_PASS_H_
|
||||
#define TVM_RELAY_PASS_H_
|
||||
|
||||
#include <tvm/relay/environment.h>
|
||||
#include <tvm/relay/module.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <string>
|
||||
|
||||
|
@ -21,23 +21,23 @@ namespace relay {
|
|||
* populated with the result type.
|
||||
*
|
||||
* \param expr The expression to type check.
|
||||
* \param env The environment used for referencing global functions, can be
|
||||
* \param mod The module used for referencing global functions, can be
|
||||
* None.
|
||||
*
|
||||
* \return A type checked expression with its checked_type field populated.
|
||||
*/
|
||||
Expr InferType(const Expr& expr, const Environment& env);
|
||||
Expr InferType(const Expr& expr, const Module& mod);
|
||||
/*!
|
||||
* \brief Infer the type of a function as if it is mapped to var in the env.
|
||||
* \brief Infer the type of a function as if it is mapped to var in the mod.
|
||||
*
|
||||
* \param f the function.
|
||||
* \param env The environment used for referencing global functions.
|
||||
* \param mod The module used for referencing global functions.
|
||||
* \param var The global variable corresponding to the function.
|
||||
*
|
||||
* \return A type checked Function with its checked_type field populated.
|
||||
* \note this function mutates env and is not thread-safe.
|
||||
* \note this function mutates mod and is not thread-safe.
|
||||
*/
|
||||
Function InferType(const Function& f, const Environment& env,
|
||||
Function InferType(const Function& f, const Module& mod,
|
||||
const GlobalVar& var);
|
||||
|
||||
/*!
|
||||
|
@ -52,11 +52,11 @@ Function InferType(const Function& f, const Environment& env,
|
|||
* a data type such as `int`, `float`, `uint`.
|
||||
*
|
||||
* \param t The type to check.
|
||||
* \param env The global environment.
|
||||
* \param mod The global module.
|
||||
*
|
||||
* \return true if the rules are satisified otherwise false
|
||||
*/
|
||||
bool KindCheck(const Type& t, const Environment& env);
|
||||
bool KindCheck(const Type& t, const Module& mod);
|
||||
|
||||
/*! \brief Compare two expressions for structural equivalence.
|
||||
*
|
||||
|
|
|
@ -349,14 +349,14 @@ class TypeRelation;
|
|||
/*!
|
||||
* \brief TypeRelation container.
|
||||
* \note This node is not directly serializable.
|
||||
* The type function need to be lookedup in the environment.
|
||||
* The type function need to be lookedup in the module.
|
||||
*/
|
||||
class TypeRelationNode : public TypeConstraintNode {
|
||||
public:
|
||||
/*!
|
||||
* \brief The function on input and output variables which
|
||||
* this is not directly serializable,
|
||||
* need to be looked-up in the environment.
|
||||
* need to be looked-up in the module.
|
||||
*/
|
||||
TypeRelationFn func;
|
||||
/*! \brief The type arguments to the type function. */
|
||||
|
|
|
@ -5,7 +5,7 @@ from ..api import register_func
|
|||
from . import base
|
||||
from . import ty
|
||||
from . import expr
|
||||
from . import env
|
||||
from . import module
|
||||
from . import ir_pass
|
||||
from .build_module import build
|
||||
from .interpreter import create_executor
|
||||
|
@ -26,7 +26,7 @@ from .scope_builder import ScopeBuilder
|
|||
Span = base.Span
|
||||
|
||||
# Env
|
||||
Environment = env.Environment
|
||||
Module = module.Module
|
||||
|
||||
# Type
|
||||
Type = ty.Type
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from .env import Environment
|
||||
from .env import Module
|
||||
from . import ir
|
||||
|
||||
def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ...
|
||||
def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ...
|
||||
def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ...
|
||||
def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ...
|
||||
def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
|
||||
def well_formed(expr: ir.Expr) -> bool: ...
|
||||
def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ...
|
|
@ -1,5 +1,5 @@
|
|||
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
|
||||
"""The interface to the Environment exposed from C++."""
|
||||
"""The interface to the Module exposed from C++."""
|
||||
from tvm._ffi.function import _init_api
|
||||
|
||||
_init_api("relay._env", __name__)
|
||||
_init_api("relay._module", __name__)
|
|
@ -2,4 +2,4 @@ from typing import Union, Tuple, Dict, List
|
|||
from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId
|
||||
from relay.ir import ShapeExtension, Operator, Defn
|
||||
|
||||
class Environment(NodeBase): ...
|
||||
class Module(NodeBase): ...
|
|
@ -5,9 +5,9 @@ from a Relay expression.
|
|||
from ..build_module import build as tvm_build_module
|
||||
from . graph_runtime_codegen import GraphRuntimeCodegen
|
||||
from . import ir_pass
|
||||
from .env import Environment
|
||||
from .module import Module
|
||||
|
||||
def build(func, params=None, target=None, env=None):
|
||||
def build(func, params=None, target=None, mod=None):
|
||||
"""
|
||||
Compile a single function to the components needed by the
|
||||
TVM RTS.
|
||||
|
@ -29,15 +29,15 @@ def build(func, params=None, target=None, env=None):
|
|||
if target is None:
|
||||
target = 'llvm'
|
||||
|
||||
if env is None:
|
||||
env = Environment({})
|
||||
if mod is None:
|
||||
mod = Module({})
|
||||
|
||||
comp = GraphRuntimeCodegen(env)
|
||||
comp = GraphRuntimeCodegen(mod)
|
||||
# NB(@jroesch) This creates lowered functions, and generates names for them
|
||||
#
|
||||
# We need these names to emit the correct graph as these are names of the
|
||||
# functions contained in the module.
|
||||
lowered_ops = ir_pass.lower_ops(env, func)
|
||||
lowered_ops = ir_pass.lower_ops(mod, func)
|
||||
mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target)
|
||||
|
||||
# Therefore the call to compile must come after.
|
||||
|
|
|
@ -172,7 +172,7 @@ class GlobalVar(Expr):
|
|||
"""A global variable in Tvm.Relay.
|
||||
|
||||
GlobalVar is used to refer to the global functions
|
||||
stored in the environment.
|
||||
stored in the module.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
|
@ -8,7 +8,7 @@ from . import build_module
|
|||
from . import _make
|
||||
from . import _interpreter
|
||||
from . import ir_pass
|
||||
from .env import Environment
|
||||
from .module import Module
|
||||
from .expr import Call, Constant, GlobalVar, Function, const
|
||||
from .scope_builder import ScopeBuilder
|
||||
from .._ffi.base import integer_types
|
||||
|
@ -90,24 +90,24 @@ def _arg_to_ast(arg):
|
|||
class Executor(object):
|
||||
"""An abstract interface for executing Relay programs."""
|
||||
|
||||
def __init__(self, env=None):
|
||||
def __init__(self, mod=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
env: relay.Environment
|
||||
The environment.
|
||||
mod: relay.Module
|
||||
The module.
|
||||
"""
|
||||
if env is None:
|
||||
self.env = Environment({})
|
||||
if mod is None:
|
||||
self.mod = Module({})
|
||||
else:
|
||||
self.env = env
|
||||
self.mod = mod
|
||||
|
||||
|
||||
def optimize(self, expr):
|
||||
# TODO: We need to move this optimization code into the optimizer/pass manager
|
||||
ck_expr = ir_pass.infer_type(expr, env=self.env)
|
||||
fused_expr = ir_pass.fuse_ops(self.env, ck_expr)
|
||||
ck_fused = ir_pass.infer_type(fused_expr, env=self.env)
|
||||
ck_expr = ir_pass.infer_type(expr, mod=self.mod)
|
||||
fused_expr = ir_pass.fuse_ops(self.mod, ck_expr)
|
||||
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
|
||||
return ck_fused
|
||||
|
||||
def _make_executor(self, _):
|
||||
|
@ -153,8 +153,8 @@ class Interpreter(Executor):
|
|||
"""
|
||||
A wrapper around the Relay interpreter, implements the excecutor interface.
|
||||
"""
|
||||
def __init__(self, env=None):
|
||||
Executor.__init__(self, env)
|
||||
def __init__(self, mod=None):
|
||||
Executor.__init__(self, mod)
|
||||
|
||||
def _make_executor(self, expr):
|
||||
def _interp_wrapper(*args):
|
||||
|
@ -163,28 +163,28 @@ class Interpreter(Executor):
|
|||
relay_args.append(_arg_to_ast(arg))
|
||||
|
||||
if isinstance(expr, GlobalVar):
|
||||
func = self.env[expr]
|
||||
func = self.mod[expr]
|
||||
func = self.optimize(func)
|
||||
self.env._add(expr, func, True)
|
||||
self.mod._add(expr, func, True)
|
||||
opt_expr = Call(expr, relay_args)
|
||||
return _interpreter.evaluate(self.env, opt_expr)
|
||||
return _interpreter.evaluate(self.mod, opt_expr)
|
||||
else:
|
||||
call = Call(expr, relay_args)
|
||||
opt_expr = self.optimize(call)
|
||||
return _interpreter.evaluate(self.env, opt_expr)
|
||||
return _interpreter.evaluate(self.mod, opt_expr)
|
||||
|
||||
return _interp_wrapper
|
||||
|
||||
|
||||
class GraphRuntime(Executor):
|
||||
"""A wrapper around the TVM graph runtime, implements the Executor interface."""
|
||||
def __init__(self, env=None):
|
||||
Executor.__init__(self, env)
|
||||
def __init__(self, mod=None):
|
||||
Executor.__init__(self, mod)
|
||||
|
||||
def _make_executor(self, expr):
|
||||
def _graph_wrapper(*args):
|
||||
func = self.optimize(expr)
|
||||
graph_json, mod, params = build_module.build(func, env=self.env)
|
||||
graph_json, mod, params = build_module.build(func, mod=self.mod)
|
||||
assert params is None
|
||||
gmodule = tvm_runtime.create(graph_json, mod, cpu(0))
|
||||
# Create map of inputs.
|
||||
|
@ -199,10 +199,10 @@ class GraphRuntime(Executor):
|
|||
|
||||
return _graph_wrapper
|
||||
|
||||
def create_executor(mode='debug', env=None):
|
||||
def create_executor(mode='debug', mod=None):
|
||||
if mode == 'debug':
|
||||
return Interpreter(env)
|
||||
return Interpreter(mod)
|
||||
elif mode == 'graph':
|
||||
return GraphRuntime(env)
|
||||
return GraphRuntime(mod)
|
||||
else:
|
||||
raise Exception("unknown mode {0}".format(mode))
|
||||
|
|
|
@ -11,16 +11,16 @@ from .expr import Expr
|
|||
from .ty import Type
|
||||
|
||||
|
||||
def infer_type(expr, env=None):
|
||||
"""Infer the type of expr under the context of env.
|
||||
def infer_type(expr, mod=None):
|
||||
"""Infer the type of expr under the context of mod.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr: tvm.relay.Expr
|
||||
The input expression.
|
||||
|
||||
env: Optional[tvm.relay.Environment]
|
||||
The global environment.
|
||||
mod: Optional[tvm.relay.Module]
|
||||
The global module.
|
||||
|
||||
|
||||
Returns
|
||||
|
@ -28,7 +28,7 @@ def infer_type(expr, env=None):
|
|||
checked_expr : tvm.relay.Expr
|
||||
The checked expression.
|
||||
"""
|
||||
return _ir_pass.infer_type(expr, env)
|
||||
return _ir_pass.infer_type(expr, mod)
|
||||
|
||||
|
||||
def backward_fold_scale_axis(expr):
|
||||
|
@ -93,7 +93,7 @@ def well_formed(expr):
|
|||
return _ir_pass.well_formed(expr)
|
||||
|
||||
|
||||
def check_kind(t, env=None):
|
||||
def check_kind(t, mod=None):
|
||||
"""Check that the type is well kinded.
|
||||
For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes.
|
||||
|
||||
|
@ -102,8 +102,8 @@ def check_kind(t, env=None):
|
|||
t: tvm.relay.Type
|
||||
The type to check
|
||||
|
||||
env: tvm.relay.Environment, optional
|
||||
The global environment
|
||||
mod: tvm.relay.Module, optional
|
||||
The global module
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -117,8 +117,8 @@ def check_kind(t, env=None):
|
|||
assert not check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)]))
|
||||
assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)]))
|
||||
"""
|
||||
if env is not None:
|
||||
return _ir_pass.check_kind(t, env)
|
||||
if mod is not None:
|
||||
return _ir_pass.check_kind(t, mod)
|
||||
else:
|
||||
return _ir_pass.check_kind(t)
|
||||
|
||||
|
@ -256,8 +256,8 @@ def structural_hash(value):
|
|||
"relay.Expr or relay.Type").format(type(value))
|
||||
raise TypeError(msg)
|
||||
|
||||
def fuse_ops(expr, env):
|
||||
return _ir_pass.FuseOps(env, expr)
|
||||
def fuse_ops(expr, mod):
|
||||
return _ir_pass.FuseOps(mod, expr)
|
||||
|
||||
def lower_ops(env, expr, target='llvm'):
|
||||
return _ir_pass.LowerOps(env, expr, target)
|
||||
def lower_ops(mod, expr, target='llvm'):
|
||||
return _ir_pass.LowerOps(mod, expr, target)
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
|
||||
"""A global environment storing everything needed to interpret or compile a Relay program."""
|
||||
"""A global module storing everything needed to interpret or compile a Relay program."""
|
||||
from .base import register_relay_node, RelayNode
|
||||
from .._ffi import base as _base
|
||||
from . import _make
|
||||
from . import _env
|
||||
from . import _module
|
||||
from . import expr as _expr
|
||||
|
||||
|
||||
@register_relay_node
|
||||
class Environment(RelayNode):
|
||||
"""The global Relay environment containing collection of functions.
|
||||
class Module(RelayNode):
|
||||
"""The global Relay module containing collection of functions.
|
||||
|
||||
Each global function is identified by an unique tvm.relay.GlobalVar.
|
||||
tvm.relay.GlobalVar and Environment is necessary in order to enable
|
||||
tvm.relay.GlobalVar and Module is necessary in order to enable
|
||||
recursions in function to avoid cyclic reference in the function.x
|
||||
|
||||
Parameters
|
||||
|
@ -32,10 +32,10 @@ class Environment(RelayNode):
|
|||
raise TypeError("Expect functions to be Dict[GlobalVar, Function]")
|
||||
mapped_funcs[k] = v
|
||||
functions = mapped_funcs
|
||||
self.__init_handle_by_constructor__(_make.Environment, functions)
|
||||
self.__init_handle_by_constructor__(_make.Module, functions)
|
||||
|
||||
def __setitem__(self, var, func):
|
||||
"""Add a function to the environment.
|
||||
"""Add a function to the module.
|
||||
|
||||
Parameters
|
||||
---------
|
||||
|
@ -50,7 +50,7 @@ class Environment(RelayNode):
|
|||
def _add(self, var, func, update=False):
|
||||
if isinstance(var, _base.string_types):
|
||||
var = _expr.GlobalVar(var)
|
||||
return _env.Environment_Add(self, var, func, update)
|
||||
return _module.Module_Add(self, var, func, update)
|
||||
|
||||
def __getitem__(self, var):
|
||||
"""Lookup a global function by name or by variable.
|
||||
|
@ -66,21 +66,21 @@ class Environment(RelayNode):
|
|||
The function referenced by :code:`var`.
|
||||
"""
|
||||
if isinstance(var, _base.string_types):
|
||||
return _env.Environment_Lookup_str(self, var)
|
||||
return _module.Module_Lookup_str(self, var)
|
||||
else:
|
||||
return _env.Environment_Lookup(self, var)
|
||||
return _module.Module_Lookup(self, var)
|
||||
|
||||
def update(self, other):
|
||||
"""Insert functions in another Environment to current one.
|
||||
"""Insert functions in another Module to current one.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
other: Environment
|
||||
The environment to merge into the current Environment.
|
||||
other: Module
|
||||
The module to merge into the current Module.
|
||||
"""
|
||||
if isinstance(other, dict):
|
||||
other = Environment(other)
|
||||
return _env.Environment_Update(self, other)
|
||||
other = Module(other)
|
||||
return _module.Module_Update(self, other)
|
||||
|
||||
def get_global_var(self, name):
|
||||
"""Get a global variable in the function by name.
|
||||
|
@ -99,4 +99,4 @@ class Environment(RelayNode):
|
|||
------
|
||||
tvm.TVMError if we cannot find corresponding global var.
|
||||
"""
|
||||
return _env.Environment_GetGlobalVar(self, name)
|
||||
return _module.Module_GetGlobalVar(self, name)
|
|
@ -183,7 +183,7 @@ struct ExprEqual {
|
|||
};
|
||||
|
||||
struct Interpreter : ExprFunctor<Value(const Expr& n)> {
|
||||
Environment env;
|
||||
Module mod;
|
||||
Stack stack;
|
||||
using JitKey = Function;
|
||||
|
||||
|
@ -197,8 +197,8 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
|
|||
return f();
|
||||
}
|
||||
|
||||
Interpreter(Environment env) : env(env), operator_map_() {}
|
||||
Interpreter(Environment env, OpMap operator_map) : env(env), operator_map_(operator_map) {}
|
||||
Interpreter(Module mod) : mod(mod), operator_map_() {}
|
||||
Interpreter(Module mod, OpMap operator_map) : mod(mod), operator_map_(operator_map) {}
|
||||
|
||||
void extend(const Var& id, Value v) {
|
||||
this->stack.current_frame().locals.Set(id, v);
|
||||
|
@ -223,7 +223,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
|
|||
}
|
||||
|
||||
Value VisitExpr_(const GlobalVarNode* op) override {
|
||||
return Eval(this->env->Lookup(GetRef<GlobalVar>(op)));
|
||||
return Eval(this->mod->Lookup(GetRef<GlobalVar>(op)));
|
||||
}
|
||||
|
||||
Value VisitExpr_(const OpNode* id) override {
|
||||
|
@ -251,14 +251,14 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
|
|||
|
||||
Value VisitExpr_(const FunctionNode* func_node) override {
|
||||
auto func = GetRef<Function>(func_node);
|
||||
tvm::Map<Var, Value> captured_env;
|
||||
tvm::Map<Var, Value> captured_mod;
|
||||
Array<Var> free_vars = FreeVars(func);
|
||||
|
||||
for (const auto& var : free_vars) {
|
||||
captured_env.Set(var, Eval(var));
|
||||
captured_mod.Set(var, Eval(var));
|
||||
}
|
||||
|
||||
return ClosureNode::make(captured_env, func);
|
||||
return ClosureNode::make(captured_mod, func);
|
||||
}
|
||||
|
||||
inline Value InvokeCompiledOp(PackedFunc func, const Array<Value>& args,
|
||||
|
@ -315,7 +315,7 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
|
|||
locals.Set(func->params[i], args[i]);
|
||||
}
|
||||
|
||||
// Add the var to value mappings from the Closure's environment.
|
||||
// Add the var to value mappings from the Closure's modironment.
|
||||
for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
|
||||
CHECK_EQ(locals.count((*it).first), 0);
|
||||
locals.Set((*it).first, (*it).second);
|
||||
|
@ -384,9 +384,9 @@ struct Interpreter : ExprFunctor<Value(const Expr& n)> {
|
|||
}
|
||||
};
|
||||
|
||||
Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
|
||||
Interpreter::OpMap CompileOperators(const Module& mod, const Expr& e) {
|
||||
Interpreter::OpMap op_map;
|
||||
auto lowered_ops = LowerOps(env, e);
|
||||
auto lowered_ops = LowerOps(mod, e);
|
||||
RELAY_LOG(INFO) << "LoweredFuncs: " << lowered_ops << std::endl;
|
||||
if (lowered_ops.size()) {
|
||||
const PackedFunc* fbuild_ptr = Registry::Get("relay.op.compiler._build");
|
||||
|
@ -399,7 +399,7 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
|
|||
lowered_funcs.push_back(lop->lowered_func);
|
||||
}
|
||||
|
||||
Module module = fbuild(lowered_funcs);
|
||||
runtime::Module module = fbuild(lowered_funcs);
|
||||
|
||||
// Loop over the lowered operations to map them into the operator map.
|
||||
for (auto lop : lowered_ops) {
|
||||
|
@ -415,17 +415,17 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) {
|
|||
return op_map;
|
||||
}
|
||||
|
||||
Value Evaluate(Environment env, Expr e) {
|
||||
auto op_map = CompileOperators(env, e);
|
||||
Interpreter interp(env, op_map);
|
||||
Value Evaluate(Module mod, Expr e) {
|
||||
auto op_map = CompileOperators(mod, e);
|
||||
Interpreter interp(mod, op_map);
|
||||
return interp.Eval(e);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._interpreter.evaluate")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
Environment env = args[0];
|
||||
Module mod = args[0];
|
||||
Expr expr = args[1];
|
||||
*ret = Evaluate(env, expr);
|
||||
*ret = Evaluate(mod, expr);
|
||||
});
|
||||
|
||||
} // namespace relay
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file environment.cc
|
||||
* \brief The global environment in Relay.
|
||||
* \file module.cc
|
||||
* \brief The global module in Relay.
|
||||
*/
|
||||
#include <tvm/relay/environment.h>
|
||||
#include <tvm/relay/module.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <sstream>
|
||||
|
||||
|
@ -13,8 +13,8 @@ namespace relay {
|
|||
using tvm::IRPrinter;
|
||||
using namespace runtime;
|
||||
|
||||
Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
|
||||
auto n = make_node<EnvironmentNode>();
|
||||
Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
|
||||
auto n = make_node<ModuleNode>();
|
||||
n->functions = std::move(global_funcs);
|
||||
|
||||
for (const auto& kv : n->functions) {
|
||||
|
@ -23,22 +23,22 @@ Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
|
|||
<< "Duplicate global function name " << kv.first->name_hint;
|
||||
n->global_var_map_.Set(kv.first->name_hint, kv.first);
|
||||
}
|
||||
return Environment(n);
|
||||
return Module(n);
|
||||
}
|
||||
|
||||
GlobalVar EnvironmentNode::GetGlobalVar(const std::string& name) {
|
||||
GlobalVar ModuleNode::GetGlobalVar(const std::string& name) {
|
||||
auto it = global_var_map_.find(name);
|
||||
CHECK(it != global_var_map_.end())
|
||||
<< "Cannot find global var " << name << " in the Environment";
|
||||
<< "Cannot find global var " << name << " in the Module";
|
||||
return (*it).second;
|
||||
}
|
||||
|
||||
void EnvironmentNode::Add(const GlobalVar& var,
|
||||
void ModuleNode::Add(const GlobalVar& var,
|
||||
const Function& func,
|
||||
bool update) {
|
||||
// Type check the item before we add it to the environment.
|
||||
auto env = GetRef<Environment>(this);
|
||||
Function checked_func = InferType(func, env, var);
|
||||
// Type check the item before we add it to the modironment.
|
||||
auto mod = GetRef<Module>(this);
|
||||
Function checked_func = InferType(func, mod, var);
|
||||
auto type = checked_func->checked_type();
|
||||
CHECK(type.as<IncompleteTypeNode>() == nullptr);
|
||||
if (functions.find(var) != functions.end()) {
|
||||
|
@ -46,7 +46,7 @@ void EnvironmentNode::Add(const GlobalVar& var,
|
|||
<< "Already have definition for " << var->name_hint;
|
||||
auto old_type = functions[var].as<FunctionNode>()->checked_type();
|
||||
CHECK(AlphaEqual(type, old_type))
|
||||
<< "Environment#update changes type, not possible in this mode.";
|
||||
<< "Module#update changes type, not possible in this mode.";
|
||||
}
|
||||
this->functions.Set(var, checked_func);
|
||||
|
||||
|
@ -62,79 +62,79 @@ void EnvironmentNode::Add(const GlobalVar& var,
|
|||
global_var_map_.Set(var->name_hint, var);
|
||||
}
|
||||
|
||||
void EnvironmentNode::Update(const GlobalVar& var, const Function& func) {
|
||||
void ModuleNode::Update(const GlobalVar& var, const Function& func) {
|
||||
this->Add(var, func, true);
|
||||
}
|
||||
|
||||
void EnvironmentNode::Remove(const GlobalVar& var) {
|
||||
void ModuleNode::Remove(const GlobalVar& var) {
|
||||
auto functions_node = this->functions.CopyOnWrite();
|
||||
functions_node->data.erase(var.node_);
|
||||
auto gvar_node = global_var_map_.CopyOnWrite();
|
||||
gvar_node->data.erase(var->name_hint);
|
||||
}
|
||||
|
||||
Function EnvironmentNode::Lookup(const GlobalVar& var) {
|
||||
Function ModuleNode::Lookup(const GlobalVar& var) {
|
||||
auto it = functions.find(var);
|
||||
CHECK(it != functions.end())
|
||||
<< "There is no definition of " << var->name_hint;
|
||||
return (*it).second;
|
||||
}
|
||||
|
||||
Function EnvironmentNode::Lookup(const std::string& name) {
|
||||
Function ModuleNode::Lookup(const std::string& name) {
|
||||
GlobalVar id = this->GetGlobalVar(name);
|
||||
return this->Lookup(id);
|
||||
}
|
||||
|
||||
void EnvironmentNode::Update(const Environment& env) {
|
||||
for (auto pair : env->functions) {
|
||||
void ModuleNode::Update(const Module& mod) {
|
||||
for (auto pair : mod->functions) {
|
||||
this->Update(pair.first, pair.second);
|
||||
}
|
||||
}
|
||||
|
||||
TVM_REGISTER_NODE_TYPE(EnvironmentNode);
|
||||
TVM_REGISTER_NODE_TYPE(ModuleNode);
|
||||
|
||||
TVM_REGISTER_API("relay._make.Environment")
|
||||
TVM_REGISTER_API("relay._make.Module")
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = EnvironmentNode::make(args[0]);
|
||||
*ret = ModuleNode::make(args[0]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._env.Environment_Add")
|
||||
TVM_REGISTER_API("relay._module.Module_Add")
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
Environment env = args[0];
|
||||
env->Add(args[1], args[2], args[3]);
|
||||
Module mod = args[0];
|
||||
mod->Add(args[1], args[2], args[3]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._env.Environment_GetGlobalVar")
|
||||
TVM_REGISTER_API("relay._module.Module_GetGlobalVar")
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
Environment env = args[0];
|
||||
*ret = env->GetGlobalVar(args[1]);
|
||||
Module mod = args[0];
|
||||
*ret = mod->GetGlobalVar(args[1]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._env.Environment_Lookup")
|
||||
TVM_REGISTER_API("relay._module.Module_Lookup")
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
Environment env = args[0];
|
||||
Module mod = args[0];
|
||||
GlobalVar var = args[1];
|
||||
*ret = env->Lookup(var);
|
||||
*ret = mod->Lookup(var);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._env.Environment_Lookup_str")
|
||||
TVM_REGISTER_API("relay._module.Module_Lookup_str")
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
Environment env = args[0];
|
||||
Module mod = args[0];
|
||||
std::string var_name = args[1];
|
||||
auto var = env->GetGlobalVar(var_name);
|
||||
*ret = env->Lookup(var);
|
||||
auto var = mod->GetGlobalVar(var_name);
|
||||
*ret = mod->Lookup(var);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._env.Environment_Update")
|
||||
TVM_REGISTER_API("relay._module.Module_Update")
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
Environment env = args[0];
|
||||
env->Update(args[1]);
|
||||
Module mod = args[0];
|
||||
mod->Update(args[1]);
|
||||
});
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||
.set_dispatch<EnvironmentNode>(
|
||||
[](const EnvironmentNode *node, tvm::IRPrinter *p) {
|
||||
p->stream << "EnvironmentNode( " << node->functions << ")";
|
||||
.set_dispatch<ModuleNode>(
|
||||
[](const ModuleNode *node, tvm::IRPrinter *p) {
|
||||
p->stream << "ModuleNode( " << node->functions << ")";
|
||||
});
|
||||
|
||||
} // namespace relay
|
|
@ -3,7 +3,7 @@
|
|||
* \file text_printer.cc
|
||||
* \brief Text printer to print relay in text form.
|
||||
*/
|
||||
#include <tvm/relay/environment.h>
|
||||
#include <tvm/relay/module.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <sstream>
|
||||
#include "type_functor.h"
|
||||
|
@ -133,8 +133,8 @@ class TextPrinter :
|
|||
std::string Print(const NodeRef& node) {
|
||||
if (node.as<FunctionNode>()) {
|
||||
this->PrintFunc(Downcast<Function>(node));
|
||||
} else if (node.as<EnvironmentNode>()) {
|
||||
this->PrintEnv(Downcast<Environment>(node));
|
||||
} else if (node.as<ModuleNode>()) {
|
||||
this->PrintEnv(Downcast<Module>(node));
|
||||
} else if (node.as_derived<TypeNode>()) {
|
||||
this->PrintType(Downcast<Type>(node), stream_);
|
||||
} else if (node.as_derived<ExprNode>()) {
|
||||
|
@ -158,9 +158,9 @@ class TextPrinter :
|
|||
stream_ << "\n";
|
||||
}
|
||||
|
||||
void PrintEnv(const Environment& env) {
|
||||
void PrintEnv(const Module& mod) {
|
||||
int counter = 0;
|
||||
for (const auto& kv : env->functions) {
|
||||
for (const auto& kv : mod->functions) {
|
||||
std::ostringstream os;
|
||||
if (counter++ != 0) {
|
||||
stream_ << "\n";
|
||||
|
|
|
@ -20,12 +20,12 @@ namespace relay {
|
|||
using namespace runtime;
|
||||
|
||||
struct AbstractFusableOps : ExprMutator {
|
||||
Environment env;
|
||||
Module mod;
|
||||
Array<GlobalVar> fusable_funcs;
|
||||
int counter = 0;
|
||||
size_t expr_hash;
|
||||
|
||||
AbstractFusableOps(Environment env, size_t expr_hash) : env(env), expr_hash(expr_hash) {}
|
||||
AbstractFusableOps(Module mod, size_t expr_hash) : mod(mod), expr_hash(expr_hash) {}
|
||||
|
||||
Expr VisitExpr_(const CallNode* call) {
|
||||
if (auto op_node = call->op.as<OpNode>()) {
|
||||
|
@ -55,7 +55,7 @@ struct AbstractFusableOps : ExprMutator {
|
|||
func_name += "_";
|
||||
func_name += std::to_string(expr_hash);
|
||||
auto gv = GlobalVarNode::make(func_name);
|
||||
env->Add(gv, func);
|
||||
mod->Add(gv, func);
|
||||
fusable_funcs.push_back(gv);
|
||||
return CallNode::make(gv, args, Attrs());
|
||||
} else {
|
||||
|
@ -64,12 +64,12 @@ struct AbstractFusableOps : ExprMutator {
|
|||
}
|
||||
};
|
||||
|
||||
Expr FuseOps(const Environment& env, const Expr& e) {
|
||||
Expr FuseOps(const Module& mod, const Expr& e) {
|
||||
// First we convert all chains of fusable ops into
|
||||
// abstracted functions which we mark as primtive
|
||||
// then we convert these primtive functions into
|
||||
// new operators.
|
||||
auto abstract = AbstractFusableOps(env, StructuralHash()(e));
|
||||
auto abstract = AbstractFusableOps(mod, StructuralHash()(e));
|
||||
auto abstracted_e = abstract.VisitExpr(e);
|
||||
RELAY_LOG(INFO) << "FuseOps: before=" << e
|
||||
<< "Fuse: after=" << abstracted_e;
|
||||
|
|
|
@ -99,7 +99,7 @@ struct KindChecker : TypeVisitor {
|
|||
}
|
||||
};
|
||||
|
||||
bool KindCheck(const Type& t, const Environment& env) {
|
||||
bool KindCheck(const Type& t, const Module& mod) {
|
||||
KindChecker kc;
|
||||
return kc.Check(t);
|
||||
}
|
||||
|
@ -107,7 +107,7 @@ bool KindCheck(const Type& t, const Environment& env) {
|
|||
TVM_REGISTER_API("relay._ir_pass.check_kind")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
if (args.size() == 1) {
|
||||
*ret = KindCheck(args[0], EnvironmentNode::make({}));
|
||||
*ret = KindCheck(args[0], ModuleNode::make({}));
|
||||
} else {
|
||||
*ret = KindCheck(args[0], args[1]);
|
||||
}
|
||||
|
|
|
@ -28,12 +28,12 @@ LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) {
|
|||
}
|
||||
|
||||
struct AbstractLocalFunctions : ExprMutator {
|
||||
Environment env;
|
||||
Module mod;
|
||||
size_t expr_hash;
|
||||
int counter = 0;
|
||||
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
|
||||
explicit AbstractLocalFunctions(Environment env)
|
||||
: env(env), expr_hash(0), counter(0), visited_funcs() {}
|
||||
explicit AbstractLocalFunctions(Module mod)
|
||||
: mod(mod), expr_hash(0), counter(0), visited_funcs() {}
|
||||
|
||||
Expr Abstract(const Expr& e) {
|
||||
expr_hash = StructuralHash()(e);
|
||||
|
@ -44,7 +44,7 @@ struct AbstractLocalFunctions : ExprMutator {
|
|||
auto gvar = GetRef<GlobalVar>(gvar_node);
|
||||
auto it = visited_funcs.find(gvar);
|
||||
if (it == visited_funcs.end()) {
|
||||
auto func = env->Lookup(gvar);
|
||||
auto func = mod->Lookup(gvar);
|
||||
visited_funcs.insert(gvar);
|
||||
auto new_func = FunctionNode::make(
|
||||
func->params,
|
||||
|
@ -52,7 +52,7 @@ struct AbstractLocalFunctions : ExprMutator {
|
|||
func->ret_type,
|
||||
func->type_params,
|
||||
func->attrs);
|
||||
env->Update(gvar, new_func);
|
||||
mod->Update(gvar, new_func);
|
||||
}
|
||||
return gvar;
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ struct AbstractLocalFunctions : ExprMutator {
|
|||
abs_func += std::to_string(expr_hash);
|
||||
auto gv = GlobalVarNode::make(abs_func);
|
||||
auto lifted_func = FunctionNode::make(params, func, Type(), {}, {});
|
||||
env->Add(gv, lifted_func);
|
||||
mod->Add(gv, lifted_func);
|
||||
Array<Expr> args;
|
||||
for (auto free_var : free_vars) {
|
||||
args.push_back(free_var);
|
||||
|
@ -80,8 +80,8 @@ struct AbstractLocalFunctions : ExprMutator {
|
|||
};
|
||||
|
||||
struct LiveFunctions : ExprVisitor {
|
||||
Environment env;
|
||||
explicit LiveFunctions(Environment env) : env(env), global_funcs() {}
|
||||
Module mod;
|
||||
explicit LiveFunctions(Module mod) : mod(mod), global_funcs() {}
|
||||
|
||||
std::unordered_set<GlobalVar, NodeHash, NodeEqual> visited_funcs;
|
||||
std::unordered_set<GlobalVar, NodeHash, NodeEqual> global_funcs;
|
||||
|
@ -100,7 +100,7 @@ struct LiveFunctions : ExprVisitor {
|
|||
GlobalVar var = GetRef<GlobalVar>(var_node);
|
||||
auto it = visited_funcs.find(var);
|
||||
if (it == visited_funcs.end()) {
|
||||
auto func = env->Lookup(var);
|
||||
auto func = mod->Lookup(var);
|
||||
visited_funcs.insert(var);
|
||||
// The last pass has trasnformed functions of the form:
|
||||
//
|
||||
|
@ -134,7 +134,7 @@ struct LiveFunctions : ExprVisitor {
|
|||
RELAY_LOG(INFO) << "LiveOps: CallNode=" << GetRef<Call>(call);
|
||||
if (auto gv_node = call->op.as<GlobalVarNode>()) {
|
||||
GlobalVar gvar = GetRef<GlobalVar>(gv_node);
|
||||
Function func = env->Lookup(gvar);
|
||||
Function func = mod->Lookup(gvar);
|
||||
|
||||
auto attr = FunctionGetAttr(func, "Primitive");
|
||||
|
||||
|
@ -159,15 +159,15 @@ using FCompute = TypedPackedFunc<Array<Tensor>(
|
|||
using FSchedule = TypedPackedFunc<Schedule(const Array<Tensor>&, std::string)>;
|
||||
|
||||
/*! \brief Return the set of operators in their TVM format. */
|
||||
Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
|
||||
Array<LoweredOp> LowerOps(const Module& mod, const Expr& e,
|
||||
const std::string& target) {
|
||||
RELAY_LOG(INFO) << "LowerOps: e=" << e;
|
||||
auto flower_ptr = Registry::Get("relay.op.compiler._lower");
|
||||
CHECK(flower_ptr);
|
||||
PackedFunc flower = *flower_ptr;
|
||||
|
||||
auto abstracted_e = AbstractLocalFunctions(env).Abstract(e);
|
||||
auto live_funcs = LiveFunctions(env);
|
||||
auto abstracted_e = AbstractLocalFunctions(mod).Abstract(e);
|
||||
auto live_funcs = LiveFunctions(mod);
|
||||
live_funcs.VisitExpr(abstracted_e);
|
||||
|
||||
auto schedule_reg = Op::GetAttr<FSchedule>("FTVMSchedule");
|
||||
|
@ -176,7 +176,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
|
|||
Array<LoweredOp> lowered_funcs;
|
||||
|
||||
for (auto func_name : live_funcs.global_funcs) {
|
||||
auto func = env->Lookup(func_name);
|
||||
auto func = mod->Lookup(func_name);
|
||||
auto call = Downcast<Call>(func->body);
|
||||
auto op_node = call->op.as<OpNode>();
|
||||
CHECK(op_node) << "violated invariant that primtiive calls contain a single op call";
|
||||
|
@ -205,7 +205,7 @@ Array<LoweredOp> LowerOps(const Environment& env, const Expr& e,
|
|||
LoweredFunc lf =
|
||||
flower(op->name + std::to_string(hash), schedule, inputs, outputs);
|
||||
func = FunctionSetAttr(func, "LoweredFunc", lf);
|
||||
env->Add(func_name, func, true);
|
||||
mod->Add(func_name, func, true);
|
||||
lowered_funcs.push_back(LoweredOpNode::make(func, lf));
|
||||
}
|
||||
|
||||
|
|
|
@ -104,8 +104,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
|||
// constructors
|
||||
TypeInferencer() {
|
||||
}
|
||||
explicit TypeInferencer(Environment env)
|
||||
: env_(env) {
|
||||
explicit TypeInferencer(Module mod)
|
||||
: mod_(mod) {
|
||||
}
|
||||
|
||||
// inference the type of expr.
|
||||
|
@ -115,7 +115,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
|||
// type resolver that maps back to type
|
||||
class Resolver;
|
||||
// internal environment
|
||||
Environment env_;
|
||||
Module mod_;
|
||||
// map from expression to checked type
|
||||
// type inferencer will populate it up
|
||||
std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_;
|
||||
|
@ -164,9 +164,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
|||
|
||||
Type VisitExpr_(const GlobalVarNode* op) final {
|
||||
GlobalVar var = GetRef<GlobalVar>(op);
|
||||
CHECK(env_.defined())
|
||||
CHECK(mod_.defined())
|
||||
<< "Cannot do type inference without a global variable";
|
||||
Expr e = env_->Lookup(var);
|
||||
Expr e = mod_->Lookup(var);
|
||||
return e->checked_type();
|
||||
}
|
||||
|
||||
|
@ -511,20 +511,20 @@ Expr TypeInferencer::Infer(Expr expr) {
|
|||
}
|
||||
|
||||
|
||||
Expr InferType(const Expr& expr, const Environment& env) {
|
||||
auto e = TypeInferencer(env).Infer(expr);
|
||||
Expr InferType(const Expr& expr, const Module& mod) {
|
||||
auto e = TypeInferencer(mod).Infer(expr);
|
||||
CHECK(WellFormed(e));
|
||||
return e;
|
||||
}
|
||||
|
||||
Function InferType(const Function& func,
|
||||
const Environment& env,
|
||||
const Module& mod,
|
||||
const GlobalVar& var) {
|
||||
Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
|
||||
func_copy->checked_type_ = func_copy->func_type_annotation();
|
||||
env->functions.Set(var, func_copy);
|
||||
Expr func_ret = TypeInferencer(env).Infer(func_copy);
|
||||
auto map_node = env->functions.CopyOnWrite();
|
||||
mod->functions.Set(var, func_copy);
|
||||
Expr func_ret = TypeInferencer(mod).Infer(func_copy);
|
||||
auto map_node = mod->functions.CopyOnWrite();
|
||||
map_node->data.erase(var.node_);
|
||||
CHECK(WellFormed(func_ret));
|
||||
return Downcast<Function>(func_ret);
|
||||
|
|
|
@ -11,7 +11,7 @@ TEST(Relay, SelfReference) {
|
|||
auto x = relay::VarNode::make("x", type_a);
|
||||
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, type_b, Array<relay::TypeVar>{});
|
||||
auto fx = relay::CallNode::make(f, Array<relay::Expr>{ x });
|
||||
auto type_fx = relay::InferType(fx, relay::EnvironmentNode::make(Map<relay::GlobalVar, relay::Function>{}));
|
||||
auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map<relay::GlobalVar, relay::Function>{}));
|
||||
CHECK_EQ(type_fx->checked_type(), type_a);
|
||||
}
|
||||
|
||||
|
|
|
@ -6,10 +6,10 @@ from tvm.relay.ir_pass import infer_type
|
|||
from tvm.relay.interpreter import Interpreter
|
||||
from tvm.relay.scope_builder import ScopeBuilder
|
||||
from tvm.relay.op import add
|
||||
from tvm.relay.env import Environment
|
||||
from tvm.relay.module import Module
|
||||
|
||||
# @tq, @jr should we put this in testing ns?
|
||||
def check_rts(expr, args, expected_result, env=None):
|
||||
def check_rts(expr, args, expected_result, mod=None):
|
||||
"""
|
||||
Check that evaluating `expr` applied to the arguments produces
|
||||
`result` on both the evaluator and TVM runtime.
|
||||
|
@ -25,8 +25,8 @@ def check_rts(expr, args, expected_result, env=None):
|
|||
expected_result:
|
||||
The expected result of running the expression.
|
||||
"""
|
||||
intrp = create_executor('graph', env=env)
|
||||
graph = create_executor('graph', env=env)
|
||||
intrp = create_executor('graph', mod=mod)
|
||||
graph = create_executor('graph', mod=mod)
|
||||
eval_result = intrp.evaluate(expr)(*args)
|
||||
rts_result = graph.evaluate(expr)(*args)
|
||||
np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
|
||||
|
|
|
@ -7,8 +7,8 @@ from tvm.relay.scope_builder import ScopeBuilder
|
|||
from tvm.relay import testing, create_executor
|
||||
|
||||
|
||||
def check_eval(expr, args, expected_result, env=None, rtol=1e-07):
|
||||
intrp = create_executor(env=env)
|
||||
def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
|
||||
intrp = create_executor(mod=mod)
|
||||
result = intrp.evaluate(expr)(*args)
|
||||
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
|
||||
|
||||
|
@ -87,7 +87,7 @@ def test_subtract():
|
|||
check_eval(func, [i_data], 0)
|
||||
|
||||
def test_simple_loop():
|
||||
env = relay.env.Environment({})
|
||||
mod = relay.module.Module({})
|
||||
sum_up = relay.GlobalVar('sum_up')
|
||||
i = relay.var('i', shape=[], dtype='int32')
|
||||
sb = ScopeBuilder()
|
||||
|
@ -98,12 +98,12 @@ def test_simple_loop():
|
|||
rec_call = relay.Call(sum_up, [one_less])
|
||||
sb.ret(op.add(rec_call, i))
|
||||
func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
|
||||
env[sum_up] = func
|
||||
mod[sum_up] = func
|
||||
i_data = np.array(10, dtype='int32')
|
||||
check_eval(sum_up, [i_data], sum(range(1, 11)), env=env)
|
||||
check_eval(sum_up, [i_data], sum(range(1, 11)), mod=mod)
|
||||
|
||||
def test_loop():
|
||||
env = relay.env.Environment({})
|
||||
mod = relay.module.Module({})
|
||||
sum_up = relay.GlobalVar('sum_up')
|
||||
i = relay.var('i', shape=[], dtype='int32')
|
||||
accum = relay.var('accum', shape=[], dtype='int32')
|
||||
|
@ -115,10 +115,10 @@ def test_loop():
|
|||
new_accum = op.add(accum, i)
|
||||
sb.ret(relay.Call(sum_up, [one_less, new_accum]))
|
||||
func = relay.Function([i, accum], sb.get())
|
||||
env[sum_up] = func
|
||||
mod[sum_up] = func
|
||||
i_data = np.array(10, dtype='int32')
|
||||
accum_data = np.array(0, dtype='int32')
|
||||
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), env=env)
|
||||
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod)
|
||||
|
||||
def test_mlp():
|
||||
pass
|
||||
|
|
|
@ -28,7 +28,7 @@ def test_env():
|
|||
z = relay.add(x, y)
|
||||
z = relay.add(z, z)
|
||||
f = relay.Function([x, y], z)
|
||||
env = relay.Environment()
|
||||
env = relay.Module()
|
||||
env["myf"] = f
|
||||
text = env.astext()
|
||||
assert "def @myf" in text
|
||||
|
|
|
@ -9,8 +9,8 @@ from tvm.relay import op
|
|||
from tvm.relay.scope_builder import ScopeBuilder
|
||||
|
||||
|
||||
def assert_has_type(expr, typ, env=relay.env.Environment({})):
|
||||
checked_expr = infer_type(expr, env)
|
||||
def assert_has_type(expr, typ, mod=relay.module.Module({})):
|
||||
checked_expr = infer_type(expr, mod)
|
||||
checked_type = checked_expr.checked_type
|
||||
if checked_type != typ:
|
||||
raise RuntimeError("Type mismatch %s vs %s" % (
|
||||
|
@ -105,10 +105,10 @@ def test_recursion():
|
|||
sb.ret(data)
|
||||
with sb.else_scope():
|
||||
sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
|
||||
env = relay.Environment()
|
||||
env[f] = relay.Function([n, data], sb.get())
|
||||
assert "%3 = @f(%1, %2)" in env.astext()
|
||||
assert env[f].checked_type == relay.FuncType([ti32, tf32], tf32)
|
||||
mod = relay.Module()
|
||||
mod[f] = relay.Function([n, data], sb.get())
|
||||
assert "%3 = @f(%1, %2)" in mod.astext()
|
||||
assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32)
|
||||
|
||||
# This currently fails and should pass under the type system.
|
||||
#
|
||||
|
@ -179,12 +179,12 @@ def test_self_reference():
|
|||
|
||||
|
||||
def test_global_var_cow_issue():
|
||||
env = relay.env.Environment({})
|
||||
mod = relay.Module({})
|
||||
gv = relay.GlobalVar("foo")
|
||||
x = relay.var('x', shape=[])
|
||||
func = relay.Function([x], relay.Call(gv, [x]),
|
||||
relay.TensorType([], 'float32'))
|
||||
env[gv] = func
|
||||
mod[gv] = func
|
||||
|
||||
|
||||
def test_equal():
|
||||
|
|
Загрузка…
Ссылка в новой задаче