From ead3ac6c23c713d395724037b6a4d7ce08eea392 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 2 Nov 2018 14:27:34 -0700 Subject: [PATCH] Rename relay::Environment to relay::Module (#2054) --- include/tvm/relay/base.h | 2 +- include/tvm/relay/build_module.h | 6 +- include/tvm/relay/expr.h | 2 +- include/tvm/relay/interpreter.h | 6 +- include/tvm/relay/{environment.h => module.h} | 40 ++++----- include/tvm/relay/pass.h | 18 ++-- include/tvm/relay/type.h | 4 +- python/tvm/relay/__init__.py | 4 +- python/tvm/relay/_ir_pass.pyi | 8 +- python/tvm/relay/{_env.py => _module.py} | 4 +- python/tvm/relay/{_env.pyi => _module.pyi} | 2 +- python/tvm/relay/build_module.py | 12 +-- python/tvm/relay/expr.py | 2 +- python/tvm/relay/interpreter.py | 44 +++++----- python/tvm/relay/ir_pass.py | 28 +++---- python/tvm/relay/{env.py => module.py} | 32 ++++---- src/relay/interpreter.cc | 32 ++++---- src/relay/ir/{environment.cc => module.cc} | 82 +++++++++---------- src/relay/ir/text_printer.cc | 10 +-- src/relay/pass/fuse_ops.cc | 10 +-- src/relay/pass/kind_check.cc | 4 +- src/relay/pass/lower_ops.cc | 30 +++---- src/relay/pass/type_infer.cc | 22 ++--- tests/cpp/relay_pass_type_infer_test.cc | 2 +- tests/python/relay/test_graph_runtime.py | 8 +- tests/python/relay/test_interpreter.py | 16 ++-- tests/python/relay/test_ir_text_printer.py | 2 +- tests/python/relay/test_type_infer.py | 16 ++-- 28 files changed, 224 insertions(+), 224 deletions(-) rename include/tvm/relay/{environment.h => module.h} (76%) rename python/tvm/relay/{_env.py => _module.py} (56%) rename python/tvm/relay/{_env.pyi => _module.pyi} (84%) rename python/tvm/relay/{env.py => module.py} (72%) rename src/relay/ir/{environment.cc => module.cc} (55%) diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index b7621e20..49e276b0 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -165,7 +165,7 @@ class RelayNode : public Node { TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node); }; -struct Environment; +struct Module; } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/build_module.h b/include/tvm/relay/build_module.h index ed889eba..35402d65 100644 --- a/include/tvm/relay/build_module.h +++ b/include/tvm/relay/build_module.h @@ -8,7 +8,7 @@ #define TVM_RELAY_BUILD_MODULE_H_ #include -#include +#include #include #include @@ -61,13 +61,13 @@ RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef); * \note This will do a reachability analysis and lower all definitions * reachable from the provided expression. * - * \param env The environment. + * \param mod The module. * \param expr The expression with operations to be lowered. * \param target The target to lower the functions to. * * \return The set of lowered operations. */ -Array LowerOps(const Environment& env, const Expr& expr, +Array LowerOps(const Module& mod, const Expr& expr, const std::string& target = "llvm"); } // namespace relay diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 029470c0..1a547048 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -160,7 +160,7 @@ class VarNode : public ExprNode { RELAY_DEFINE_NODE_REF(Var, VarNode, Expr); /*! - * \brief Global variable that leaves in the top-level environment. + * \brief Global variable that leaves in the top-level module. * This is used to enable recursive calls between function. * * \note A GlobalVar may only point to functions. diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 1c382faa..403dd50a 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -4,7 +4,7 @@ * \brief An interpreter for Relay. * * This file implements a simple reference interpreter for Relay programs. - * Given a Relay environment, and a Relay expression it produces a value. + * Given a Relay module, and a Relay expression it produces a value. * * The interpreter's values are a naive representation of the values that * can be produced by a Relay program and are exposed via tvm::Node's @@ -16,7 +16,7 @@ #ifndef TVM_RELAY_INTERPRETER_H_ #define TVM_RELAY_INTERPRETER_H_ -#include +#include #include namespace tvm { @@ -39,7 +39,7 @@ class Value; * Our intent is that this will never be the most efficient implementation of * Relay's semantics, but a readable and clear one. */ -Value Evaluate(Environment env, Expr e); +Value Evaluate(Module mod, Expr e); /*! \brief The base container type of Relay values. */ class ValueNode : public RelayNode { diff --git a/include/tvm/relay/environment.h b/include/tvm/relay/module.h similarity index 76% rename from include/tvm/relay/environment.h rename to include/tvm/relay/module.h index 2ed38957..b04d6fec 100644 --- a/include/tvm/relay/environment.h +++ b/include/tvm/relay/module.h @@ -1,11 +1,11 @@ /*! * Copyright (c) 2018 by Contributors - * \file tvm/relay/environment.h + * \file tvm/relay/module.h * \brief The global environment: contains information needed to * compile & optimize Relay programs. */ -#ifndef TVM_RELAY_ENVIRONMENT_H_ -#define TVM_RELAY_ENVIRONMENT_H_ +#ifndef TVM_RELAY_MODULE_H_ +#define TVM_RELAY_MODULE_H_ #include #include @@ -17,7 +17,7 @@ namespace tvm { namespace relay { -struct Environment; +struct Module; /*! \brief The global environment of Relay programs. * @@ -28,29 +28,29 @@ struct Environment; * options. * * Many operations require access to the global - * Environment. We pass the Environment by value + * Module. We pass the Module by value * in a functional style as an explicit argument, - * but we mutate the Environment while optimizing + * but we mutate the Module while optimizing * Relay programs. * * The functional style allows users to construct custom * environments easily, for example each thread can store - * an Environment while auto-tuning. + * an Module while auto-tuning. * */ -class EnvironmentNode : public RelayNode { +class ModuleNode : public RelayNode { public: /*! \brief A map from ids to all global functions. */ tvm::Map functions; - EnvironmentNode() {} + ModuleNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("functions", &functions); v->Visit("global_var_map_", &global_var_map_); } - TVM_DLL static Environment make(tvm::Map global_funcs); + TVM_DLL static Module make(tvm::Map global_funcs); /*! * \brief Add a function to the global environment. @@ -100,10 +100,10 @@ class EnvironmentNode : public RelayNode { * functions in another environment. * \param other The other environment. */ - void Update(const Environment& other); + void Update(const Module& other); - static constexpr const char* _type_key = "relay.Environment"; - TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); + static constexpr const char* _type_key = "relay.Module"; + TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node); private: /*! \brief A map from string names to global variables that @@ -112,18 +112,18 @@ class EnvironmentNode : public RelayNode { tvm::Map global_var_map_; }; -struct Environment : public NodeRef { - Environment() {} - explicit Environment(NodePtr p) : NodeRef(p) {} +struct Module : public NodeRef { + Module() {} + explicit Module(NodePtr p) : NodeRef(p) {} - inline EnvironmentNode* operator->() const { - return static_cast(node_.get()); + inline ModuleNode* operator->() const { + return static_cast(node_.get()); } - using ContainerType = EnvironmentNode; + using ContainerType = ModuleNode; }; } // namespace relay } // namespace tvm -#endif // TVM_RELAY_ENVIRONMENT_H_ +#endif // TVM_RELAY_MODULE_H_ diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index b2967810..5ff60c70 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -6,7 +6,7 @@ #ifndef TVM_RELAY_PASS_H_ #define TVM_RELAY_PASS_H_ -#include +#include #include #include @@ -21,23 +21,23 @@ namespace relay { * populated with the result type. * * \param expr The expression to type check. - * \param env The environment used for referencing global functions, can be + * \param mod The module used for referencing global functions, can be * None. * * \return A type checked expression with its checked_type field populated. */ -Expr InferType(const Expr& expr, const Environment& env); +Expr InferType(const Expr& expr, const Module& mod); /*! - * \brief Infer the type of a function as if it is mapped to var in the env. + * \brief Infer the type of a function as if it is mapped to var in the mod. * * \param f the function. - * \param env The environment used for referencing global functions. + * \param mod The module used for referencing global functions. * \param var The global variable corresponding to the function. * * \return A type checked Function with its checked_type field populated. - * \note this function mutates env and is not thread-safe. + * \note this function mutates mod and is not thread-safe. */ -Function InferType(const Function& f, const Environment& env, +Function InferType(const Function& f, const Module& mod, const GlobalVar& var); /*! @@ -52,11 +52,11 @@ Function InferType(const Function& f, const Environment& env, * a data type such as `int`, `float`, `uint`. * * \param t The type to check. - * \param env The global environment. + * \param mod The global module. * * \return true if the rules are satisified otherwise false */ -bool KindCheck(const Type& t, const Environment& env); +bool KindCheck(const Type& t, const Module& mod); /*! \brief Compare two expressions for structural equivalence. * diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index c8ccb603..69a8a4fb 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -349,14 +349,14 @@ class TypeRelation; /*! * \brief TypeRelation container. * \note This node is not directly serializable. - * The type function need to be lookedup in the environment. + * The type function need to be lookedup in the module. */ class TypeRelationNode : public TypeConstraintNode { public: /*! * \brief The function on input and output variables which * this is not directly serializable, - * need to be looked-up in the environment. + * need to be looked-up in the module. */ TypeRelationFn func; /*! \brief The type arguments to the type function. */ diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index b0a1fcec..f474eb44 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -5,7 +5,7 @@ from ..api import register_func from . import base from . import ty from . import expr -from . import env +from . import module from . import ir_pass from .build_module import build from .interpreter import create_executor @@ -26,7 +26,7 @@ from .scope_builder import ScopeBuilder Span = base.Span # Env -Environment = env.Environment +Module = module.Module # Type Type = ty.Type diff --git a/python/tvm/relay/_ir_pass.pyi b/python/tvm/relay/_ir_pass.pyi index f1432803..6bf4e2da 100644 --- a/python/tvm/relay/_ir_pass.pyi +++ b/python/tvm/relay/_ir_pass.pyi @@ -1,8 +1,8 @@ -from .env import Environment +from .env import Module from . import ir -def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ... -def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ... +def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ... +def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ... def _get_checked_type(expr: ir.Expr) -> ir.Type: ... def well_formed(expr: ir.Expr) -> bool: ... -def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ... \ No newline at end of file +def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ... diff --git a/python/tvm/relay/_env.py b/python/tvm/relay/_module.py similarity index 56% rename from python/tvm/relay/_env.py rename to python/tvm/relay/_module.py index 25b8715a..b6e74c45 100644 --- a/python/tvm/relay/_env.py +++ b/python/tvm/relay/_module.py @@ -1,5 +1,5 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable -"""The interface to the Environment exposed from C++.""" +"""The interface to the Module exposed from C++.""" from tvm._ffi.function import _init_api -_init_api("relay._env", __name__) +_init_api("relay._module", __name__) diff --git a/python/tvm/relay/_env.pyi b/python/tvm/relay/_module.pyi similarity index 84% rename from python/tvm/relay/_env.pyi rename to python/tvm/relay/_module.pyi index c6b5d0f6..de3aabef 100644 --- a/python/tvm/relay/_env.pyi +++ b/python/tvm/relay/_module.pyi @@ -2,4 +2,4 @@ from typing import Union, Tuple, Dict, List from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId from relay.ir import ShapeExtension, Operator, Defn -class Environment(NodeBase): ... \ No newline at end of file +class Module(NodeBase): ... diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 6b60fd3f..e71571e6 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -5,9 +5,9 @@ from a Relay expression. from ..build_module import build as tvm_build_module from . graph_runtime_codegen import GraphRuntimeCodegen from . import ir_pass -from .env import Environment +from .module import Module -def build(func, params=None, target=None, env=None): +def build(func, params=None, target=None, mod=None): """ Compile a single function to the components needed by the TVM RTS. @@ -29,15 +29,15 @@ def build(func, params=None, target=None, env=None): if target is None: target = 'llvm' - if env is None: - env = Environment({}) + if mod is None: + mod = Module({}) - comp = GraphRuntimeCodegen(env) + comp = GraphRuntimeCodegen(mod) # NB(@jroesch) This creates lowered functions, and generates names for them # # We need these names to emit the correct graph as these are names of the # functions contained in the module. - lowered_ops = ir_pass.lower_ops(env, func) + lowered_ops = ir_pass.lower_ops(mod, func) mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target) # Therefore the call to compile must come after. diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index dd9477aa..d789f281 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -172,7 +172,7 @@ class GlobalVar(Expr): """A global variable in Tvm.Relay. GlobalVar is used to refer to the global functions - stored in the environment. + stored in the module. Parameters ---------- diff --git a/python/tvm/relay/interpreter.py b/python/tvm/relay/interpreter.py index d95943c1..4dfe3e02 100644 --- a/python/tvm/relay/interpreter.py +++ b/python/tvm/relay/interpreter.py @@ -8,7 +8,7 @@ from . import build_module from . import _make from . import _interpreter from . import ir_pass -from .env import Environment +from .module import Module from .expr import Call, Constant, GlobalVar, Function, const from .scope_builder import ScopeBuilder from .._ffi.base import integer_types @@ -90,24 +90,24 @@ def _arg_to_ast(arg): class Executor(object): """An abstract interface for executing Relay programs.""" - def __init__(self, env=None): + def __init__(self, mod=None): """ Parameters ---------- - env: relay.Environment - The environment. + mod: relay.Module + The module. """ - if env is None: - self.env = Environment({}) + if mod is None: + self.mod = Module({}) else: - self.env = env + self.mod = mod def optimize(self, expr): # TODO: We need to move this optimization code into the optimizer/pass manager - ck_expr = ir_pass.infer_type(expr, env=self.env) - fused_expr = ir_pass.fuse_ops(self.env, ck_expr) - ck_fused = ir_pass.infer_type(fused_expr, env=self.env) + ck_expr = ir_pass.infer_type(expr, mod=self.mod) + fused_expr = ir_pass.fuse_ops(self.mod, ck_expr) + ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod) return ck_fused def _make_executor(self, _): @@ -153,8 +153,8 @@ class Interpreter(Executor): """ A wrapper around the Relay interpreter, implements the excecutor interface. """ - def __init__(self, env=None): - Executor.__init__(self, env) + def __init__(self, mod=None): + Executor.__init__(self, mod) def _make_executor(self, expr): def _interp_wrapper(*args): @@ -163,28 +163,28 @@ class Interpreter(Executor): relay_args.append(_arg_to_ast(arg)) if isinstance(expr, GlobalVar): - func = self.env[expr] + func = self.mod[expr] func = self.optimize(func) - self.env._add(expr, func, True) + self.mod._add(expr, func, True) opt_expr = Call(expr, relay_args) - return _interpreter.evaluate(self.env, opt_expr) + return _interpreter.evaluate(self.mod, opt_expr) else: call = Call(expr, relay_args) opt_expr = self.optimize(call) - return _interpreter.evaluate(self.env, opt_expr) + return _interpreter.evaluate(self.mod, opt_expr) return _interp_wrapper class GraphRuntime(Executor): """A wrapper around the TVM graph runtime, implements the Executor interface.""" - def __init__(self, env=None): - Executor.__init__(self, env) + def __init__(self, mod=None): + Executor.__init__(self, mod) def _make_executor(self, expr): def _graph_wrapper(*args): func = self.optimize(expr) - graph_json, mod, params = build_module.build(func, env=self.env) + graph_json, mod, params = build_module.build(func, mod=self.mod) assert params is None gmodule = tvm_runtime.create(graph_json, mod, cpu(0)) # Create map of inputs. @@ -199,10 +199,10 @@ class GraphRuntime(Executor): return _graph_wrapper -def create_executor(mode='debug', env=None): +def create_executor(mode='debug', mod=None): if mode == 'debug': - return Interpreter(env) + return Interpreter(mod) elif mode == 'graph': - return GraphRuntime(env) + return GraphRuntime(mod) else: raise Exception("unknown mode {0}".format(mode)) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index f3950fff..989e5ad7 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -11,16 +11,16 @@ from .expr import Expr from .ty import Type -def infer_type(expr, env=None): - """Infer the type of expr under the context of env. +def infer_type(expr, mod=None): + """Infer the type of expr under the context of mod. Parameters ---------- expr: tvm.relay.Expr The input expression. - env: Optional[tvm.relay.Environment] - The global environment. + mod: Optional[tvm.relay.Module] + The global module. Returns @@ -28,7 +28,7 @@ def infer_type(expr, env=None): checked_expr : tvm.relay.Expr The checked expression. """ - return _ir_pass.infer_type(expr, env) + return _ir_pass.infer_type(expr, mod) def backward_fold_scale_axis(expr): @@ -93,7 +93,7 @@ def well_formed(expr): return _ir_pass.well_formed(expr) -def check_kind(t, env=None): +def check_kind(t, mod=None): """Check that the type is well kinded. For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes. @@ -102,8 +102,8 @@ def check_kind(t, env=None): t: tvm.relay.Type The type to check - env: tvm.relay.Environment, optional - The global environment + mod: tvm.relay.Module, optional + The global module Returns ------- @@ -117,8 +117,8 @@ def check_kind(t, env=None): assert not check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)])) assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) """ - if env is not None: - return _ir_pass.check_kind(t, env) + if mod is not None: + return _ir_pass.check_kind(t, mod) else: return _ir_pass.check_kind(t) @@ -256,8 +256,8 @@ def structural_hash(value): "relay.Expr or relay.Type").format(type(value)) raise TypeError(msg) -def fuse_ops(expr, env): - return _ir_pass.FuseOps(env, expr) +def fuse_ops(expr, mod): + return _ir_pass.FuseOps(mod, expr) -def lower_ops(env, expr, target='llvm'): - return _ir_pass.LowerOps(env, expr, target) +def lower_ops(mod, expr, target='llvm'): + return _ir_pass.LowerOps(mod, expr, target) diff --git a/python/tvm/relay/env.py b/python/tvm/relay/module.py similarity index 72% rename from python/tvm/relay/env.py rename to python/tvm/relay/module.py index 37e0999d..024c6baf 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/module.py @@ -1,18 +1,18 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import -"""A global environment storing everything needed to interpret or compile a Relay program.""" +"""A global module storing everything needed to interpret or compile a Relay program.""" from .base import register_relay_node, RelayNode from .._ffi import base as _base from . import _make -from . import _env +from . import _module from . import expr as _expr @register_relay_node -class Environment(RelayNode): - """The global Relay environment containing collection of functions. +class Module(RelayNode): + """The global Relay module containing collection of functions. Each global function is identified by an unique tvm.relay.GlobalVar. - tvm.relay.GlobalVar and Environment is necessary in order to enable + tvm.relay.GlobalVar and Module is necessary in order to enable recursions in function to avoid cyclic reference in the function.x Parameters @@ -32,10 +32,10 @@ class Environment(RelayNode): raise TypeError("Expect functions to be Dict[GlobalVar, Function]") mapped_funcs[k] = v functions = mapped_funcs - self.__init_handle_by_constructor__(_make.Environment, functions) + self.__init_handle_by_constructor__(_make.Module, functions) def __setitem__(self, var, func): - """Add a function to the environment. + """Add a function to the module. Parameters --------- @@ -50,7 +50,7 @@ class Environment(RelayNode): def _add(self, var, func, update=False): if isinstance(var, _base.string_types): var = _expr.GlobalVar(var) - return _env.Environment_Add(self, var, func, update) + return _module.Module_Add(self, var, func, update) def __getitem__(self, var): """Lookup a global function by name or by variable. @@ -66,21 +66,21 @@ class Environment(RelayNode): The function referenced by :code:`var`. """ if isinstance(var, _base.string_types): - return _env.Environment_Lookup_str(self, var) + return _module.Module_Lookup_str(self, var) else: - return _env.Environment_Lookup(self, var) + return _module.Module_Lookup(self, var) def update(self, other): - """Insert functions in another Environment to current one. + """Insert functions in another Module to current one. Parameters ---------- - other: Environment - The environment to merge into the current Environment. + other: Module + The module to merge into the current Module. """ if isinstance(other, dict): - other = Environment(other) - return _env.Environment_Update(self, other) + other = Module(other) + return _module.Module_Update(self, other) def get_global_var(self, name): """Get a global variable in the function by name. @@ -99,4 +99,4 @@ class Environment(RelayNode): ------ tvm.TVMError if we cannot find corresponding global var. """ - return _env.Environment_GetGlobalVar(self, name) + return _module.Module_GetGlobalVar(self, name) diff --git a/src/relay/interpreter.cc b/src/relay/interpreter.cc index 534a2a98..5db7b66e 100644 --- a/src/relay/interpreter.cc +++ b/src/relay/interpreter.cc @@ -183,7 +183,7 @@ struct ExprEqual { }; struct Interpreter : ExprFunctor { - Environment env; + Module mod; Stack stack; using JitKey = Function; @@ -197,8 +197,8 @@ struct Interpreter : ExprFunctor { return f(); } - Interpreter(Environment env) : env(env), operator_map_() {} - Interpreter(Environment env, OpMap operator_map) : env(env), operator_map_(operator_map) {} + Interpreter(Module mod) : mod(mod), operator_map_() {} + Interpreter(Module mod, OpMap operator_map) : mod(mod), operator_map_(operator_map) {} void extend(const Var& id, Value v) { this->stack.current_frame().locals.Set(id, v); @@ -223,7 +223,7 @@ struct Interpreter : ExprFunctor { } Value VisitExpr_(const GlobalVarNode* op) override { - return Eval(this->env->Lookup(GetRef(op))); + return Eval(this->mod->Lookup(GetRef(op))); } Value VisitExpr_(const OpNode* id) override { @@ -251,14 +251,14 @@ struct Interpreter : ExprFunctor { Value VisitExpr_(const FunctionNode* func_node) override { auto func = GetRef(func_node); - tvm::Map captured_env; + tvm::Map captured_mod; Array free_vars = FreeVars(func); for (const auto& var : free_vars) { - captured_env.Set(var, Eval(var)); + captured_mod.Set(var, Eval(var)); } - return ClosureNode::make(captured_env, func); + return ClosureNode::make(captured_mod, func); } inline Value InvokeCompiledOp(PackedFunc func, const Array& args, @@ -315,7 +315,7 @@ struct Interpreter : ExprFunctor { locals.Set(func->params[i], args[i]); } - // Add the var to value mappings from the Closure's environment. + // Add the var to value mappings from the Closure's modironment. for (auto it = closure->env.begin(); it != closure->env.end(); ++it) { CHECK_EQ(locals.count((*it).first), 0); locals.Set((*it).first, (*it).second); @@ -384,9 +384,9 @@ struct Interpreter : ExprFunctor { } }; -Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) { +Interpreter::OpMap CompileOperators(const Module& mod, const Expr& e) { Interpreter::OpMap op_map; - auto lowered_ops = LowerOps(env, e); + auto lowered_ops = LowerOps(mod, e); RELAY_LOG(INFO) << "LoweredFuncs: " << lowered_ops << std::endl; if (lowered_ops.size()) { const PackedFunc* fbuild_ptr = Registry::Get("relay.op.compiler._build"); @@ -399,7 +399,7 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) { lowered_funcs.push_back(lop->lowered_func); } - Module module = fbuild(lowered_funcs); + runtime::Module module = fbuild(lowered_funcs); // Loop over the lowered operations to map them into the operator map. for (auto lop : lowered_ops) { @@ -415,17 +415,17 @@ Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) { return op_map; } -Value Evaluate(Environment env, Expr e) { - auto op_map = CompileOperators(env, e); - Interpreter interp(env, op_map); +Value Evaluate(Module mod, Expr e) { + auto op_map = CompileOperators(mod, e); + Interpreter interp(mod, op_map); return interp.Eval(e); } TVM_REGISTER_API("relay._interpreter.evaluate") .set_body([](TVMArgs args, TVMRetValue* ret) { - Environment env = args[0]; + Module mod = args[0]; Expr expr = args[1]; - *ret = Evaluate(env, expr); + *ret = Evaluate(mod, expr); }); } // namespace relay diff --git a/src/relay/ir/environment.cc b/src/relay/ir/module.cc similarity index 55% rename from src/relay/ir/environment.cc rename to src/relay/ir/module.cc index 262758ba..4443ed50 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/module.cc @@ -1,9 +1,9 @@ /*! * Copyright (c) 2018 by Contributors - * \file environment.cc - * \brief The global environment in Relay. + * \file module.cc + * \brief The global module in Relay. */ -#include +#include #include #include @@ -13,8 +13,8 @@ namespace relay { using tvm::IRPrinter; using namespace runtime; -Environment EnvironmentNode::make(tvm::Map global_funcs) { - auto n = make_node(); +Module ModuleNode::make(tvm::Map global_funcs) { + auto n = make_node(); n->functions = std::move(global_funcs); for (const auto& kv : n->functions) { @@ -23,22 +23,22 @@ Environment EnvironmentNode::make(tvm::Map global_funcs) { << "Duplicate global function name " << kv.first->name_hint; n->global_var_map_.Set(kv.first->name_hint, kv.first); } - return Environment(n); + return Module(n); } -GlobalVar EnvironmentNode::GetGlobalVar(const std::string& name) { +GlobalVar ModuleNode::GetGlobalVar(const std::string& name) { auto it = global_var_map_.find(name); CHECK(it != global_var_map_.end()) - << "Cannot find global var " << name << " in the Environment"; + << "Cannot find global var " << name << " in the Module"; return (*it).second; } -void EnvironmentNode::Add(const GlobalVar& var, +void ModuleNode::Add(const GlobalVar& var, const Function& func, bool update) { - // Type check the item before we add it to the environment. - auto env = GetRef(this); - Function checked_func = InferType(func, env, var); + // Type check the item before we add it to the modironment. + auto mod = GetRef(this); + Function checked_func = InferType(func, mod, var); auto type = checked_func->checked_type(); CHECK(type.as() == nullptr); if (functions.find(var) != functions.end()) { @@ -46,7 +46,7 @@ void EnvironmentNode::Add(const GlobalVar& var, << "Already have definition for " << var->name_hint; auto old_type = functions[var].as()->checked_type(); CHECK(AlphaEqual(type, old_type)) - << "Environment#update changes type, not possible in this mode."; + << "Module#update changes type, not possible in this mode."; } this->functions.Set(var, checked_func); @@ -62,79 +62,79 @@ void EnvironmentNode::Add(const GlobalVar& var, global_var_map_.Set(var->name_hint, var); } -void EnvironmentNode::Update(const GlobalVar& var, const Function& func) { +void ModuleNode::Update(const GlobalVar& var, const Function& func) { this->Add(var, func, true); } -void EnvironmentNode::Remove(const GlobalVar& var) { +void ModuleNode::Remove(const GlobalVar& var) { auto functions_node = this->functions.CopyOnWrite(); functions_node->data.erase(var.node_); auto gvar_node = global_var_map_.CopyOnWrite(); gvar_node->data.erase(var->name_hint); } -Function EnvironmentNode::Lookup(const GlobalVar& var) { +Function ModuleNode::Lookup(const GlobalVar& var) { auto it = functions.find(var); CHECK(it != functions.end()) << "There is no definition of " << var->name_hint; return (*it).second; } -Function EnvironmentNode::Lookup(const std::string& name) { +Function ModuleNode::Lookup(const std::string& name) { GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id); } -void EnvironmentNode::Update(const Environment& env) { - for (auto pair : env->functions) { +void ModuleNode::Update(const Module& mod) { + for (auto pair : mod->functions) { this->Update(pair.first, pair.second); } } -TVM_REGISTER_NODE_TYPE(EnvironmentNode); +TVM_REGISTER_NODE_TYPE(ModuleNode); -TVM_REGISTER_API("relay._make.Environment") +TVM_REGISTER_API("relay._make.Module") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = EnvironmentNode::make(args[0]); + *ret = ModuleNode::make(args[0]); }); -TVM_REGISTER_API("relay._env.Environment_Add") +TVM_REGISTER_API("relay._module.Module_Add") .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - env->Add(args[1], args[2], args[3]); + Module mod = args[0]; + mod->Add(args[1], args[2], args[3]); }); -TVM_REGISTER_API("relay._env.Environment_GetGlobalVar") +TVM_REGISTER_API("relay._module.Module_GetGlobalVar") .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - *ret = env->GetGlobalVar(args[1]); + Module mod = args[0]; + *ret = mod->GetGlobalVar(args[1]); }); -TVM_REGISTER_API("relay._env.Environment_Lookup") +TVM_REGISTER_API("relay._module.Module_Lookup") .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; + Module mod = args[0]; GlobalVar var = args[1]; - *ret = env->Lookup(var); + *ret = mod->Lookup(var); }); -TVM_REGISTER_API("relay._env.Environment_Lookup_str") +TVM_REGISTER_API("relay._module.Module_Lookup_str") .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; + Module mod = args[0]; std::string var_name = args[1]; - auto var = env->GetGlobalVar(var_name); - *ret = env->Lookup(var); + auto var = mod->GetGlobalVar(var_name); + *ret = mod->Lookup(var); }); -TVM_REGISTER_API("relay._env.Environment_Update") +TVM_REGISTER_API("relay._module.Module_Update") .set_body([](TVMArgs args, TVMRetValue *ret) { - Environment env = args[0]; - env->Update(args[1]); + Module mod = args[0]; + mod->Update(args[1]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch( - [](const EnvironmentNode *node, tvm::IRPrinter *p) { - p->stream << "EnvironmentNode( " << node->functions << ")"; +.set_dispatch( + [](const ModuleNode *node, tvm::IRPrinter *p) { + p->stream << "ModuleNode( " << node->functions << ")"; }); } // namespace relay diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 8056adc9..04f51a14 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -3,7 +3,7 @@ * \file text_printer.cc * \brief Text printer to print relay in text form. */ -#include +#include #include #include #include "type_functor.h" @@ -133,8 +133,8 @@ class TextPrinter : std::string Print(const NodeRef& node) { if (node.as()) { this->PrintFunc(Downcast(node)); - } else if (node.as()) { - this->PrintEnv(Downcast(node)); + } else if (node.as()) { + this->PrintEnv(Downcast(node)); } else if (node.as_derived()) { this->PrintType(Downcast(node), stream_); } else if (node.as_derived()) { @@ -158,9 +158,9 @@ class TextPrinter : stream_ << "\n"; } - void PrintEnv(const Environment& env) { + void PrintEnv(const Module& mod) { int counter = 0; - for (const auto& kv : env->functions) { + for (const auto& kv : mod->functions) { std::ostringstream os; if (counter++ != 0) { stream_ << "\n"; diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 3aea1293..f5538331 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -20,12 +20,12 @@ namespace relay { using namespace runtime; struct AbstractFusableOps : ExprMutator { - Environment env; + Module mod; Array fusable_funcs; int counter = 0; size_t expr_hash; - AbstractFusableOps(Environment env, size_t expr_hash) : env(env), expr_hash(expr_hash) {} + AbstractFusableOps(Module mod, size_t expr_hash) : mod(mod), expr_hash(expr_hash) {} Expr VisitExpr_(const CallNode* call) { if (auto op_node = call->op.as()) { @@ -55,7 +55,7 @@ struct AbstractFusableOps : ExprMutator { func_name += "_"; func_name += std::to_string(expr_hash); auto gv = GlobalVarNode::make(func_name); - env->Add(gv, func); + mod->Add(gv, func); fusable_funcs.push_back(gv); return CallNode::make(gv, args, Attrs()); } else { @@ -64,12 +64,12 @@ struct AbstractFusableOps : ExprMutator { } }; -Expr FuseOps(const Environment& env, const Expr& e) { +Expr FuseOps(const Module& mod, const Expr& e) { // First we convert all chains of fusable ops into // abstracted functions which we mark as primtive // then we convert these primtive functions into // new operators. - auto abstract = AbstractFusableOps(env, StructuralHash()(e)); + auto abstract = AbstractFusableOps(mod, StructuralHash()(e)); auto abstracted_e = abstract.VisitExpr(e); RELAY_LOG(INFO) << "FuseOps: before=" << e << "Fuse: after=" << abstracted_e; diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 81e72c6d..7253a600 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -99,7 +99,7 @@ struct KindChecker : TypeVisitor { } }; -bool KindCheck(const Type& t, const Environment& env) { +bool KindCheck(const Type& t, const Module& mod) { KindChecker kc; return kc.Check(t); } @@ -107,7 +107,7 @@ bool KindCheck(const Type& t, const Environment& env) { TVM_REGISTER_API("relay._ir_pass.check_kind") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 1) { - *ret = KindCheck(args[0], EnvironmentNode::make({})); + *ret = KindCheck(args[0], ModuleNode::make({})); } else { *ret = KindCheck(args[0], args[1]); } diff --git a/src/relay/pass/lower_ops.cc b/src/relay/pass/lower_ops.cc index 6bab9a92..f2c8ceba 100644 --- a/src/relay/pass/lower_ops.cc +++ b/src/relay/pass/lower_ops.cc @@ -28,12 +28,12 @@ LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) { } struct AbstractLocalFunctions : ExprMutator { - Environment env; + Module mod; size_t expr_hash; int counter = 0; std::unordered_set visited_funcs; - explicit AbstractLocalFunctions(Environment env) - : env(env), expr_hash(0), counter(0), visited_funcs() {} + explicit AbstractLocalFunctions(Module mod) + : mod(mod), expr_hash(0), counter(0), visited_funcs() {} Expr Abstract(const Expr& e) { expr_hash = StructuralHash()(e); @@ -44,7 +44,7 @@ struct AbstractLocalFunctions : ExprMutator { auto gvar = GetRef(gvar_node); auto it = visited_funcs.find(gvar); if (it == visited_funcs.end()) { - auto func = env->Lookup(gvar); + auto func = mod->Lookup(gvar); visited_funcs.insert(gvar); auto new_func = FunctionNode::make( func->params, @@ -52,7 +52,7 @@ struct AbstractLocalFunctions : ExprMutator { func->ret_type, func->type_params, func->attrs); - env->Update(gvar, new_func); + mod->Update(gvar, new_func); } return gvar; } @@ -70,7 +70,7 @@ struct AbstractLocalFunctions : ExprMutator { abs_func += std::to_string(expr_hash); auto gv = GlobalVarNode::make(abs_func); auto lifted_func = FunctionNode::make(params, func, Type(), {}, {}); - env->Add(gv, lifted_func); + mod->Add(gv, lifted_func); Array args; for (auto free_var : free_vars) { args.push_back(free_var); @@ -80,8 +80,8 @@ struct AbstractLocalFunctions : ExprMutator { }; struct LiveFunctions : ExprVisitor { - Environment env; - explicit LiveFunctions(Environment env) : env(env), global_funcs() {} + Module mod; + explicit LiveFunctions(Module mod) : mod(mod), global_funcs() {} std::unordered_set visited_funcs; std::unordered_set global_funcs; @@ -100,7 +100,7 @@ struct LiveFunctions : ExprVisitor { GlobalVar var = GetRef(var_node); auto it = visited_funcs.find(var); if (it == visited_funcs.end()) { - auto func = env->Lookup(var); + auto func = mod->Lookup(var); visited_funcs.insert(var); // The last pass has trasnformed functions of the form: // @@ -134,7 +134,7 @@ struct LiveFunctions : ExprVisitor { RELAY_LOG(INFO) << "LiveOps: CallNode=" << GetRef(call); if (auto gv_node = call->op.as()) { GlobalVar gvar = GetRef(gv_node); - Function func = env->Lookup(gvar); + Function func = mod->Lookup(gvar); auto attr = FunctionGetAttr(func, "Primitive"); @@ -159,15 +159,15 @@ using FCompute = TypedPackedFunc( using FSchedule = TypedPackedFunc&, std::string)>; /*! \brief Return the set of operators in their TVM format. */ -Array LowerOps(const Environment& env, const Expr& e, +Array LowerOps(const Module& mod, const Expr& e, const std::string& target) { RELAY_LOG(INFO) << "LowerOps: e=" << e; auto flower_ptr = Registry::Get("relay.op.compiler._lower"); CHECK(flower_ptr); PackedFunc flower = *flower_ptr; - auto abstracted_e = AbstractLocalFunctions(env).Abstract(e); - auto live_funcs = LiveFunctions(env); + auto abstracted_e = AbstractLocalFunctions(mod).Abstract(e); + auto live_funcs = LiveFunctions(mod); live_funcs.VisitExpr(abstracted_e); auto schedule_reg = Op::GetAttr("FTVMSchedule"); @@ -176,7 +176,7 @@ Array LowerOps(const Environment& env, const Expr& e, Array lowered_funcs; for (auto func_name : live_funcs.global_funcs) { - auto func = env->Lookup(func_name); + auto func = mod->Lookup(func_name); auto call = Downcast(func->body); auto op_node = call->op.as(); CHECK(op_node) << "violated invariant that primtiive calls contain a single op call"; @@ -205,7 +205,7 @@ Array LowerOps(const Environment& env, const Expr& e, LoweredFunc lf = flower(op->name + std::to_string(hash), schedule, inputs, outputs); func = FunctionSetAttr(func, "LoweredFunc", lf); - env->Add(func_name, func, true); + mod->Add(func_name, func, true); lowered_funcs.push_back(LoweredOpNode::make(func, lf)); } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 864b7ad7..b224a099 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -104,8 +104,8 @@ class TypeInferencer : private ExprFunctor { // constructors TypeInferencer() { } - explicit TypeInferencer(Environment env) - : env_(env) { + explicit TypeInferencer(Module mod) + : mod_(mod) { } // inference the type of expr. @@ -115,7 +115,7 @@ class TypeInferencer : private ExprFunctor { // type resolver that maps back to type class Resolver; // internal environment - Environment env_; + Module mod_; // map from expression to checked type // type inferencer will populate it up std::unordered_map type_map_; @@ -164,9 +164,9 @@ class TypeInferencer : private ExprFunctor { Type VisitExpr_(const GlobalVarNode* op) final { GlobalVar var = GetRef(op); - CHECK(env_.defined()) + CHECK(mod_.defined()) << "Cannot do type inference without a global variable"; - Expr e = env_->Lookup(var); + Expr e = mod_->Lookup(var); return e->checked_type(); } @@ -511,20 +511,20 @@ Expr TypeInferencer::Infer(Expr expr) { } -Expr InferType(const Expr& expr, const Environment& env) { - auto e = TypeInferencer(env).Infer(expr); +Expr InferType(const Expr& expr, const Module& mod) { + auto e = TypeInferencer(mod).Infer(expr); CHECK(WellFormed(e)); return e; } Function InferType(const Function& func, - const Environment& env, + const Module& mod, const GlobalVar& var) { Function func_copy = Function(make_node(*func.operator->())); func_copy->checked_type_ = func_copy->func_type_annotation(); - env->functions.Set(var, func_copy); - Expr func_ret = TypeInferencer(env).Infer(func_copy); - auto map_node = env->functions.CopyOnWrite(); + mod->functions.Set(var, func_copy); + Expr func_ret = TypeInferencer(mod).Infer(func_copy); + auto map_node = mod->functions.CopyOnWrite(); map_node->data.erase(var.node_); CHECK(WellFormed(func_ret)); return Downcast(func_ret); diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index e1a81d3c..385bde97 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -11,7 +11,7 @@ TEST(Relay, SelfReference) { auto x = relay::VarNode::make("x", type_a); auto f = relay::FunctionNode::make(tvm::Array{ x }, x, type_b, Array{}); auto fx = relay::CallNode::make(f, Array{ x }); - auto type_fx = relay::InferType(fx, relay::EnvironmentNode::make(Map{})); + auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map{})); CHECK_EQ(type_fx->checked_type(), type_a); } diff --git a/tests/python/relay/test_graph_runtime.py b/tests/python/relay/test_graph_runtime.py index 38acc5df..7b89831d 100644 --- a/tests/python/relay/test_graph_runtime.py +++ b/tests/python/relay/test_graph_runtime.py @@ -6,10 +6,10 @@ from tvm.relay.ir_pass import infer_type from tvm.relay.interpreter import Interpreter from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.op import add -from tvm.relay.env import Environment +from tvm.relay.module import Module # @tq, @jr should we put this in testing ns? -def check_rts(expr, args, expected_result, env=None): +def check_rts(expr, args, expected_result, mod=None): """ Check that evaluating `expr` applied to the arguments produces `result` on both the evaluator and TVM runtime. @@ -25,8 +25,8 @@ def check_rts(expr, args, expected_result, env=None): expected_result: The expected result of running the expression. """ - intrp = create_executor('graph', env=env) - graph = create_executor('graph', env=env) + intrp = create_executor('graph', mod=mod) + graph = create_executor('graph', mod=mod) eval_result = intrp.evaluate(expr)(*args) rts_result = graph.evaluate(expr)(*args) np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy()) diff --git a/tests/python/relay/test_interpreter.py b/tests/python/relay/test_interpreter.py index f2eaa3d0..b7214965 100644 --- a/tests/python/relay/test_interpreter.py +++ b/tests/python/relay/test_interpreter.py @@ -7,8 +7,8 @@ from tvm.relay.scope_builder import ScopeBuilder from tvm.relay import testing, create_executor -def check_eval(expr, args, expected_result, env=None, rtol=1e-07): - intrp = create_executor(env=env) +def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): + intrp = create_executor(mod=mod) result = intrp.evaluate(expr)(*args) np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol) @@ -87,7 +87,7 @@ def test_subtract(): check_eval(func, [i_data], 0) def test_simple_loop(): - env = relay.env.Environment({}) + mod = relay.module.Module({}) sum_up = relay.GlobalVar('sum_up') i = relay.var('i', shape=[], dtype='int32') sb = ScopeBuilder() @@ -98,12 +98,12 @@ def test_simple_loop(): rec_call = relay.Call(sum_up, [one_less]) sb.ret(op.add(rec_call, i)) func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32')) - env[sum_up] = func + mod[sum_up] = func i_data = np.array(10, dtype='int32') - check_eval(sum_up, [i_data], sum(range(1, 11)), env=env) + check_eval(sum_up, [i_data], sum(range(1, 11)), mod=mod) def test_loop(): - env = relay.env.Environment({}) + mod = relay.module.Module({}) sum_up = relay.GlobalVar('sum_up') i = relay.var('i', shape=[], dtype='int32') accum = relay.var('accum', shape=[], dtype='int32') @@ -115,10 +115,10 @@ def test_loop(): new_accum = op.add(accum, i) sb.ret(relay.Call(sum_up, [one_less, new_accum])) func = relay.Function([i, accum], sb.get()) - env[sum_up] = func + mod[sum_up] = func i_data = np.array(10, dtype='int32') accum_data = np.array(0, dtype='int32') - check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), env=env) + check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod) def test_mlp(): pass diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index aa944bc2..dd790a6d 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -28,7 +28,7 @@ def test_env(): z = relay.add(x, y) z = relay.add(z, z) f = relay.Function([x, y], z) - env = relay.Environment() + env = relay.Module() env["myf"] = f text = env.astext() assert "def @myf" in text diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 31d350dc..c1f06ccc 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -9,8 +9,8 @@ from tvm.relay import op from tvm.relay.scope_builder import ScopeBuilder -def assert_has_type(expr, typ, env=relay.env.Environment({})): - checked_expr = infer_type(expr, env) +def assert_has_type(expr, typ, mod=relay.module.Module({})): + checked_expr = infer_type(expr, mod) checked_type = checked_expr.checked_type if checked_type != typ: raise RuntimeError("Type mismatch %s vs %s" % ( @@ -105,10 +105,10 @@ def test_recursion(): sb.ret(data) with sb.else_scope(): sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data))) - env = relay.Environment() - env[f] = relay.Function([n, data], sb.get()) - assert "%3 = @f(%1, %2)" in env.astext() - assert env[f].checked_type == relay.FuncType([ti32, tf32], tf32) + mod = relay.Module() + mod[f] = relay.Function([n, data], sb.get()) + assert "%3 = @f(%1, %2)" in mod.astext() + assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32) # This currently fails and should pass under the type system. # @@ -179,12 +179,12 @@ def test_self_reference(): def test_global_var_cow_issue(): - env = relay.env.Environment({}) + mod = relay.Module({}) gv = relay.GlobalVar("foo") x = relay.var('x', shape=[]) func = relay.Function([x], relay.Call(gv, [x]), relay.TensorType([], 'float32')) - env[gv] = func + mod[gv] = func def test_equal():