Clean up pass.h (#3312)
This commit is contained in:
Родитель
0af5c21614
Коммит
e3d6074a5b
|
@ -33,7 +33,8 @@ compiler stack.
|
|||
expr
|
||||
frontend
|
||||
image
|
||||
ir_pass
|
||||
analysis
|
||||
transform
|
||||
module
|
||||
nn
|
||||
op
|
||||
|
|
|
@ -18,42 +18,21 @@
|
|||
*/
|
||||
|
||||
/*!
|
||||
* \file tvm/relay/pass.h
|
||||
* \brief The set of Relay passes written in C++.
|
||||
*/
|
||||
#ifndef TVM_RELAY_PASS_H_
|
||||
#define TVM_RELAY_PASS_H_
|
||||
* \file tvm/relay/analysis.h
|
||||
* \brief The set of Relay analysis passes written in C++.
|
||||
*/
|
||||
#ifndef TVM_RELAY_ANALYSIS_H_
|
||||
#define TVM_RELAY_ANALYSIS_H_
|
||||
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/packed_func_ext.h>
|
||||
#include <tvm/relay/adt.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <tvm/relay/module.h>
|
||||
#include <tvm/relay/op_attr_types.h>
|
||||
#include <tvm/relay/type.h>
|
||||
#include <tvm/relay/adt.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include <tvm/runtime/vm.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
/*!
|
||||
* \brief Infer the type of an expression.
|
||||
*
|
||||
* The result of type checking is a new expression with unambigous
|
||||
* type information filled in, as well as it's checked type field
|
||||
* populated with the result type.
|
||||
*
|
||||
* \param expr The expression to type check.
|
||||
* \param mod The module used for referencing global functions, can be
|
||||
* None.
|
||||
*
|
||||
* \return A type checked expression with its checked_type field populated.
|
||||
*/
|
||||
TVM_DLL Expr InferType(const Expr& expr, const Module& mod);
|
||||
|
||||
/*!
|
||||
* \brief Infer the type of a function as if it is mapped to var in the mod.
|
||||
*
|
||||
|
@ -64,7 +43,8 @@ TVM_DLL Expr InferType(const Expr& expr, const Module& mod);
|
|||
* \return A type checked Function with its checked_type field populated.
|
||||
* \note this function mutates mod and is not thread-safe.
|
||||
*/
|
||||
TVM_DLL Function InferType(const Function& f, const Module& mod,
|
||||
TVM_DLL Function InferType(const Function& f,
|
||||
const Module& mod,
|
||||
const GlobalVar& var);
|
||||
|
||||
/*!
|
||||
|
@ -271,58 +251,6 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
|
|||
*/
|
||||
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
|
||||
|
||||
/*!
|
||||
* \brief Fold constant expressions.
|
||||
*
|
||||
* \param expr the expression to be optimized.
|
||||
*
|
||||
* \return The optimized expression.
|
||||
*/
|
||||
TVM_DLL Expr FoldConstant(const Expr& expr);
|
||||
|
||||
/*!
|
||||
* \brief Fuse operations into expr into seperate functions.
|
||||
*
|
||||
* \param expr The expression.
|
||||
* \param fuse_opt_level Optimization level.
|
||||
* \param mod the module.
|
||||
*
|
||||
* \return The optimized expression.
|
||||
*/
|
||||
TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod);
|
||||
|
||||
/*!
|
||||
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
|
||||
*
|
||||
* \param expr The expression.
|
||||
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
|
||||
* rule function.
|
||||
* \param fcontext Additional callback to provide context argument for each call node.
|
||||
* \param fmulti_ref_trigger Transformation function to be called when
|
||||
* an Expr consumed by multiple callers.
|
||||
* \return The rewritten expression.
|
||||
*/
|
||||
TVM_DLL Expr ForwardRewrite(const Expr& expr,
|
||||
const std::string& rewrite_map_attr_name,
|
||||
std::function<NodeRef(const Call&)> fcontext = nullptr,
|
||||
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
|
||||
|
||||
/*!
|
||||
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
|
||||
*
|
||||
* \param expr The expression.
|
||||
* \param rewrite_func The rewrite func that will apply to all operators.
|
||||
* \param fcontext Additional callback to provide context argument for each call node.
|
||||
* \param fmulti_ref_trigger Transformation function to be called when
|
||||
* an Expr consumed by multiple callers.
|
||||
*
|
||||
* \return The rewritten expression.
|
||||
*/
|
||||
TVM_DLL Expr ForwardRewrite(const Expr& expr,
|
||||
const FForwardRewrite& rewrite_func,
|
||||
std::function<NodeRef(const Call&)> fcontext = nullptr,
|
||||
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
|
||||
|
||||
/*!
|
||||
* \brief Rewrite the annotated program.
|
||||
*
|
||||
|
@ -364,19 +292,6 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
|
|||
*/
|
||||
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
|
||||
|
||||
/*!
|
||||
* \brief Bind the free variables to a Relay expression.
|
||||
*
|
||||
* Parameter binding can only happen if expr is a Function.
|
||||
* binds cannot change internal arguments of internal functions.
|
||||
*
|
||||
* \param expr The function to be binded.
|
||||
* \param binds The map of arguments to
|
||||
*
|
||||
* \return The expression with all free vars bound.
|
||||
*/
|
||||
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
|
||||
|
||||
/*! \brief A hashing structure in the style of std::hash. */
|
||||
struct StructuralHash {
|
||||
/*! \brief Hash a Relay type.
|
||||
|
@ -388,7 +303,6 @@ struct StructuralHash {
|
|||
* \return the hash value.
|
||||
*/
|
||||
size_t operator()(const Type& type) const;
|
||||
|
||||
/*! \brief Hash a Relay expression.
|
||||
*
|
||||
* Implements structural hashing of a Relay expression.
|
||||
|
@ -400,20 +314,7 @@ struct StructuralHash {
|
|||
size_t operator()(const Expr& expr) const;
|
||||
};
|
||||
|
||||
namespace vm {
|
||||
|
||||
/*!
|
||||
* \brief Compile a module, and construct the virtual machine.
|
||||
*
|
||||
* \param mod The module to compile.
|
||||
*
|
||||
* \return The constructed virtual machine.
|
||||
*/
|
||||
runtime::vm::VirtualMachine CompileModule(const Module& mod);
|
||||
|
||||
} // namespace vm
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_RELAY_PASS_H_
|
||||
#endif // TVM_RELAY_ANALYSIS_H_
|
|
@ -378,36 +378,6 @@ TVM_DLL Pass FoldConstant();
|
|||
*/
|
||||
TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
|
||||
|
||||
/*!
|
||||
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
|
||||
*
|
||||
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
|
||||
* rule function.
|
||||
* \param fcontext Additional callback to provide context argument for each call node.
|
||||
* \param fmulti_ref_trigger Transformation function to be called when
|
||||
* an Expr consumed by multiple callers.
|
||||
*
|
||||
* \return The pass.
|
||||
*/
|
||||
TVM_DLL Pass ForwardRewrite(const std::string& rewrite_map_attr_name,
|
||||
std::function<NodeRef(const Call&)> fcontext = nullptr,
|
||||
std::function<Expr(const Expr&)>
|
||||
fmulti_ref_trigger = nullptr);
|
||||
|
||||
/*!
|
||||
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
|
||||
*
|
||||
* \param rewrite_func The rewrite func that will apply to all operators.
|
||||
* \param fcontext Additional callback to provide context argument for each call node.
|
||||
* \param fmulti_ref_trigger Transformation function to be called when
|
||||
* an Expr consumed by multiple callers.
|
||||
*
|
||||
* \return The pass.
|
||||
*/
|
||||
TVM_DLL Pass ForwardRewrite(const FForwardRewrite& rewrite_func,
|
||||
std::function<NodeRef(const Call&)> fcontext = nullptr,
|
||||
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
|
||||
|
||||
/*!
|
||||
* \brief Rewrite the annotated program.
|
||||
*
|
||||
|
@ -554,21 +524,68 @@ TVM_DLL Pass CanonicalizeCast();
|
|||
*/
|
||||
TVM_DLL Pass EtaExpand();
|
||||
|
||||
/*!
|
||||
* \brief This is a helper function that runs a some optimization passes on
|
||||
* a certain expression and returns the optimized version. With the help of this
|
||||
* function, users don't need to manually construct a module, then perform
|
||||
* passes, and finally and extract the target function/expression from the
|
||||
* returned module frequently.
|
||||
*
|
||||
* \param expr The expression to be optimized.
|
||||
* \param passes The passses that will be applied on the given expression.
|
||||
*
|
||||
* \return The optimized expression.
|
||||
*/
|
||||
TVM_DLL Expr OptimizeOnExpr(const Expr& expr, const Array<Pass>& passes);
|
||||
|
||||
} // namespace transform
|
||||
|
||||
/*!
|
||||
* \brief Bind the free variables to a Relay expression. This is a helper
|
||||
* function usually called by other pass functions to help optimizations.
|
||||
*
|
||||
* \param expr The input expression.
|
||||
* \param binds The variable to expression map that will be used to help the
|
||||
* binding.
|
||||
*
|
||||
* \return The updated expression.
|
||||
*/
|
||||
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
|
||||
|
||||
/*!
|
||||
* \brief Infer the type of a function as if it is mapped to var in the mod.
|
||||
*
|
||||
* \param f the function.
|
||||
* \param mod The module used for referencing global functions.
|
||||
* \param var The global variable corresponding to the function.
|
||||
*
|
||||
* \return A type checked Function with its checked_type field populated.
|
||||
* \note this function mutates mod and is not thread-safe.
|
||||
*/
|
||||
TVM_DLL Function InferType(const Function& f,
|
||||
const Module& mod,
|
||||
const GlobalVar& var);
|
||||
|
||||
/*!
|
||||
* \brief Apply rewrite rules to rewrite the expr in post DFS order. This
|
||||
* function is used as a helper function to rewrtie an expression in a pass.
|
||||
*
|
||||
* \param expr The expression.
|
||||
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
|
||||
* rule function.
|
||||
* \param fcontext Additional callback to provide context argument for each call node.
|
||||
* \param fmulti_ref_trigger Transformation function to be called when
|
||||
* an Expr consumed by multiple callers.
|
||||
* \return The rewritten expression.
|
||||
*/
|
||||
TVM_DLL Expr ForwardRewrite(const Expr& expr,
|
||||
const std::string& rewrite_map_attr_name,
|
||||
std::function<NodeRef(const Call&)> fcontext = nullptr,
|
||||
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
|
||||
|
||||
/*!
|
||||
* \brief Apply rewrite rules to rewrite the expr in post DFS order. This
|
||||
* function is used as a helper function to rewrtie an expression in a pass.
|
||||
*
|
||||
* \param expr The expression.
|
||||
* \param rewrite_func The rewrite func that will apply to all operators.
|
||||
* \param fcontext Additional callback to provide context argument for each call node.
|
||||
* \param fmulti_ref_trigger Transformation function to be called when
|
||||
* an Expr consumed by multiple callers.
|
||||
*
|
||||
* \return The rewritten expression.
|
||||
*/
|
||||
TVM_DLL Expr ForwardRewrite(const Expr& expr,
|
||||
const FForwardRewrite& rewrite_func,
|
||||
std::function<NodeRef(const Call&)> fcontext = nullptr,
|
||||
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ import nnvm
|
|||
from nnvm import testing
|
||||
from nnvm import to_relay
|
||||
import tvm
|
||||
from tvm.relay import ir_pass
|
||||
from tvm.relay import transform
|
||||
from tvm.relay import create_executor
|
||||
from tvm.contrib import graph_runtime
|
||||
import numpy as np
|
||||
|
@ -41,10 +41,11 @@ def check_model(sym, shapes, dtypes, params):
|
|||
nnvm_rts.run(**inputs)
|
||||
nnvm_out = nnvm_rts.get_output(0)
|
||||
relay_model, params = to_relay.to_relay(net, shapes, dtypes, params)
|
||||
relay_model = ir_pass.infer_type(relay_model)
|
||||
relay_rts = create_executor(kind='graph', ctx=tvm.cpu(0), target='llvm')
|
||||
mod = tvm.relay.Module.from_expr(relay_model)
|
||||
mod = transform.InferType()(mod)
|
||||
relay_rts = create_executor(kind='graph', mod=mod, ctx=tvm.cpu(0), target='llvm')
|
||||
inputs.update(params)
|
||||
relay_out = relay_rts.evaluate(relay_model)(*list(inputs.values()))
|
||||
relay_out = relay_rts.evaluate()(*list(inputs.values()))
|
||||
np.testing.assert_allclose(nnvm_out.asnumpy(), relay_out.asnumpy())
|
||||
|
||||
# def test_mlp():
|
||||
|
|
|
@ -21,6 +21,7 @@ import threading
|
|||
import topi
|
||||
|
||||
from tvm import relay, autotvm
|
||||
from tvm.relay import transform
|
||||
from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
|
||||
from tvm.relay.ty import TupleType, TensorType
|
||||
from tvm.autotvm.task import TaskExtractEnv
|
||||
|
@ -80,6 +81,14 @@ def expr2graph(expr, target_ops, node_dict, node_list):
|
|||
task_pos += 1
|
||||
|
||||
|
||||
def _infer_type(node):
|
||||
"""A method to infer the type of a relay expression."""
|
||||
mod = relay.Module.from_expr(node)
|
||||
mod = transform.InferType()(mod)
|
||||
entry = mod[mod.entry_func]
|
||||
return entry if isinstance(node, relay.Function) else entry.body
|
||||
|
||||
|
||||
def _expr2graph_impl(expr, target_ops, node_dict, node_list):
|
||||
"""Implementation to convert relay expr to graph data structure
|
||||
"""
|
||||
|
@ -99,7 +108,7 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
|
|||
node_entry["inputs"] += node_list[in_node_idx]["inputs"]
|
||||
else:
|
||||
node_entry["inputs"].append([in_node_idx, 0, 0])
|
||||
infer_out = relay.ir_pass.infer_type(node)
|
||||
infer_out = _infer_type(node)
|
||||
out_type = infer_out._checked_type_
|
||||
if isinstance(out_type, TensorType):
|
||||
node_entry["types"].append(out_type)
|
||||
|
@ -168,7 +177,7 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
|
|||
node_dict[node] = node_index
|
||||
node_list.append(node_entry)
|
||||
|
||||
relay.ir_pass.post_order_visit(expr, _traverse_expr)
|
||||
relay.analysis.post_order_visit(expr, _traverse_expr)
|
||||
|
||||
|
||||
def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_names):
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
# pylint: disable=eval-used,invalid-name,too-many-arguments
|
||||
"""Utility functions"""
|
||||
from tvm import relay
|
||||
from tvm.relay import transform
|
||||
|
||||
|
||||
def has_multiple_inputs(node_list, node_idx, input_names):
|
||||
|
@ -107,4 +108,7 @@ def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
|
|||
rebind_dict[var] = updated_input_dict[var.name_hint]
|
||||
updated_expr = relay.expr.bind(expr, rebind_dict)
|
||||
|
||||
return relay.ir_pass.infer_type(updated_expr)
|
||||
mod = relay.Module.from_expr(updated_expr)
|
||||
mod = transform.InferType()(mod)
|
||||
entry = mod[mod.entry_func]
|
||||
return entry if isinstance(updated_expr, relay.Function) else entry.body
|
||||
|
|
|
@ -24,7 +24,7 @@ from . import expr
|
|||
from . import expr_functor
|
||||
from . import module
|
||||
from . import adt
|
||||
from . import ir_pass
|
||||
from . import analysis
|
||||
from . import transform
|
||||
from .build_module import build, create_executor
|
||||
from .transform import build_config
|
||||
|
@ -32,6 +32,7 @@ from . import prelude
|
|||
from . import parser
|
||||
from . import debug
|
||||
from . import param_dict
|
||||
from . import feature
|
||||
|
||||
# Root operators
|
||||
from .op import Op
|
||||
|
@ -101,7 +102,7 @@ const = expr.const
|
|||
bind = expr.bind
|
||||
module_pass = transform.module_pass
|
||||
function_pass = transform.function_pass
|
||||
alpha_equal = ir_pass.alpha_equal
|
||||
alpha_equal = analysis.alpha_equal
|
||||
|
||||
# ExprFunctor
|
||||
ExprFunctor = expr_functor.ExprFunctor
|
||||
|
@ -122,3 +123,6 @@ Pass = transform.Pass
|
|||
ModulePass = transform.ModulePass
|
||||
FunctionPass = transform.FunctionPass
|
||||
Sequential = transform.Sequential
|
||||
|
||||
# Feature
|
||||
Feature = feature.Feature
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""FFI exposing the Relay type inference and checking."""
|
||||
"""FFI exposing the passes for Relay program analysis."""
|
||||
|
||||
from tvm._ffi.function import _init_api
|
||||
|
||||
_init_api("relay._ir_pass", __name__)
|
||||
_init_api("relay._analysis", __name__)
|
|
@ -20,7 +20,7 @@
|
|||
This file contains the set of passes for Relay, which exposes an interface for
|
||||
configuring the passes and scripting them in Python.
|
||||
"""
|
||||
from . import _ir_pass
|
||||
from . import _analysis
|
||||
from . import _make
|
||||
from .expr import Expr
|
||||
from .ty import Type
|
||||
|
@ -41,71 +41,7 @@ def post_order_visit(expr, fvisit):
|
|||
fvisit : function
|
||||
The visitor function to be applied.
|
||||
"""
|
||||
return _ir_pass.post_order_visit(expr, fvisit)
|
||||
|
||||
def infer_type(expr, mod=None):
|
||||
"""Infer the type of expr under the context of mod.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr: tvm.relay.Expr
|
||||
The input expression.
|
||||
|
||||
mod: Optional[tvm.relay.Module]
|
||||
The global module.
|
||||
|
||||
Returns
|
||||
-------
|
||||
checked_expr : tvm.relay.Expr
|
||||
The checked expression.
|
||||
"""
|
||||
return _ir_pass.infer_type(expr, mod)
|
||||
|
||||
|
||||
def backward_fold_scale_axis(expr):
|
||||
"""Backward fold axis scaling into weights of conv2d/dense.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : tvm.relay.Expr
|
||||
The input expression, we expect that expr's types
|
||||
should be fully inferred by infer_type.
|
||||
|
||||
Returns
|
||||
-------
|
||||
folded_expr : tvm.relay.Expr
|
||||
The folded expression after transformation.
|
||||
|
||||
Note
|
||||
----
|
||||
It is recommended to call backward_fold_scale_axis
|
||||
before using forward_fold_scale_axis.
|
||||
As backward folding targets common conv-bn pattern.
|
||||
"""
|
||||
return _ir_pass.backward_fold_scale_axis(expr)
|
||||
|
||||
|
||||
def forward_fold_scale_axis(expr):
|
||||
"""Fold the scaling of axis into weights of conv2d/dense.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : tvm.relay.Expr
|
||||
The input expression, we expect that expr's types
|
||||
should be fully inferred by infer_type.
|
||||
|
||||
Returns
|
||||
-------
|
||||
folded_expr : tvm.relay.Expr
|
||||
The folded expression after transformation.
|
||||
|
||||
Note
|
||||
----
|
||||
It is recommended to call backward_fold_scale_axis
|
||||
before using forward_fold_scale_axis.
|
||||
As backward folding targets common conv-bn pattern.
|
||||
"""
|
||||
return _ir_pass.forward_fold_scale_axis(expr)
|
||||
return _analysis.post_order_visit(expr, fvisit)
|
||||
|
||||
|
||||
def well_formed(expr):
|
||||
|
@ -121,12 +57,13 @@ def well_formed(expr):
|
|||
well_form : bool
|
||||
Whether the input expression is well formed
|
||||
"""
|
||||
return _ir_pass.well_formed(expr)
|
||||
return _analysis.well_formed(expr)
|
||||
|
||||
|
||||
def check_kind(t, mod=None):
|
||||
"""Check that the type is well kinded and return the kind.
|
||||
For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes.
|
||||
For example, this mean type cannot has tensor of tensor, or is a tuple type
|
||||
of 2 shapes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -149,9 +86,9 @@ def check_kind(t, mod=None):
|
|||
assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) == Type
|
||||
"""
|
||||
if mod is not None:
|
||||
return _ir_pass.check_kind(t, mod)
|
||||
return _analysis.check_kind(t, mod)
|
||||
else:
|
||||
return _ir_pass.check_kind(t)
|
||||
return _analysis.check_kind(t)
|
||||
|
||||
|
||||
def free_vars(expr):
|
||||
|
@ -173,7 +110,7 @@ def free_vars(expr):
|
|||
neural networks: usually this means weights of previous
|
||||
are ordered first.
|
||||
"""
|
||||
return _ir_pass.free_vars(expr)
|
||||
return _analysis.free_vars(expr)
|
||||
|
||||
|
||||
def bound_vars(expr):
|
||||
|
@ -189,7 +126,7 @@ def bound_vars(expr):
|
|||
free : List[tvm.relay.Var]
|
||||
The list of bound variables in post-DFS order.
|
||||
"""
|
||||
return _ir_pass.bound_vars(expr)
|
||||
return _analysis.bound_vars(expr)
|
||||
|
||||
|
||||
def all_vars(expr):
|
||||
|
@ -205,7 +142,7 @@ def all_vars(expr):
|
|||
free : List[tvm.relay.Var]
|
||||
The list of all variables in post-DFS order.
|
||||
"""
|
||||
return _ir_pass.all_vars(expr)
|
||||
return _analysis.all_vars(expr)
|
||||
|
||||
|
||||
def free_type_vars(expr, mod=None):
|
||||
|
@ -225,7 +162,7 @@ def free_type_vars(expr, mod=None):
|
|||
The list of free type variables in post-DFS order
|
||||
"""
|
||||
use_mod = mod if mod is not None else Module()
|
||||
return _ir_pass.free_type_vars(expr, use_mod)
|
||||
return _analysis.free_type_vars(expr, use_mod)
|
||||
|
||||
|
||||
def bound_type_vars(expr, mod=None):
|
||||
|
@ -245,7 +182,7 @@ def bound_type_vars(expr, mod=None):
|
|||
The list of bound type variables in post-DFS order
|
||||
"""
|
||||
use_mod = mod if mod is not None else Module()
|
||||
return _ir_pass.bound_type_vars(expr, use_mod)
|
||||
return _analysis.bound_type_vars(expr, use_mod)
|
||||
|
||||
|
||||
def all_type_vars(expr, mod=None):
|
||||
|
@ -255,6 +192,7 @@ def all_type_vars(expr, mod=None):
|
|||
----------
|
||||
expr : Union[tvm.relay.Expr,tvm.relay.Type]
|
||||
The input expression/type
|
||||
|
||||
mod : Optional[tvm.relay.Module]
|
||||
The global module
|
||||
|
||||
|
@ -264,41 +202,7 @@ def all_type_vars(expr, mod=None):
|
|||
The list of all type variables in post-DFS order
|
||||
"""
|
||||
use_mod = mod if mod is not None else Module()
|
||||
return _ir_pass.all_type_vars(expr, use_mod)
|
||||
|
||||
|
||||
def simplify_inference(expr):
|
||||
""" Simplify the data-flow graph for inference phase.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : tvm.relay.Expr
|
||||
The input Expression
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : tvm.relay.Expr
|
||||
An expression which is semantically equal to the input expression,
|
||||
but with some simplification
|
||||
"""
|
||||
return _ir_pass.simplify_inference(expr)
|
||||
|
||||
|
||||
def canonicalize_ops(expr):
|
||||
""" Canonicalize special operators to basic operators.
|
||||
This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : tvm.relay.Expr
|
||||
The input Expression
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : tvm.relay.Expr
|
||||
An expression without bias_add
|
||||
"""
|
||||
return _ir_pass.canonicalize_ops(expr)
|
||||
return _analysis.all_type_vars(expr, use_mod)
|
||||
|
||||
|
||||
def alpha_equal(lhs, rhs):
|
||||
|
@ -342,128 +246,6 @@ def graph_equal(lhs, rhs):
|
|||
return bool(_make._graph_equal(lhs, rhs))
|
||||
|
||||
|
||||
def structural_hash(value):
|
||||
"""Hash a Relay expression structurally.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : Union[tvm.relay.Expr, tvm.relay.Type]
|
||||
The expression to hash.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : int
|
||||
The hash value
|
||||
"""
|
||||
if isinstance(value, Expr):
|
||||
return int(_ir_pass._expr_hash(value))
|
||||
elif isinstance(value, Type):
|
||||
return int(_ir_pass._type_hash(value))
|
||||
else:
|
||||
msg = ("found value of type {0} expected" +
|
||||
"relay.Expr or relay.Type").format(type(value))
|
||||
raise TypeError(msg)
|
||||
|
||||
|
||||
def fold_constant(expr):
|
||||
"""Fold the constant expression in expr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : tvm.relay.Expr
|
||||
The input expression.
|
||||
|
||||
Returns
|
||||
-------
|
||||
transformed_expr : tvm.relay.Expr
|
||||
The transformed expression.
|
||||
"""
|
||||
return _ir_pass.FoldConstant(expr)
|
||||
|
||||
|
||||
def fuse_ops(expr, opt_level=1, mod=None):
|
||||
"""Fuse operators in expr together.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : tvm.relay.Expr
|
||||
The input expression.
|
||||
|
||||
opt_level : int
|
||||
The level of fuse optimization.
|
||||
|
||||
mod : tvm.relay.Module
|
||||
The module to perform fusion over.
|
||||
|
||||
Returns
|
||||
-------
|
||||
transformed_expr : tvm.relay.Expr
|
||||
Transformed expression, containing fused result.
|
||||
"""
|
||||
return _ir_pass.FuseOps(expr, opt_level, mod)
|
||||
|
||||
|
||||
def combine_parallel_conv2d(expr, min_num_branches=3):
|
||||
"""Combine multiple conv2d into one.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : tvm.relay.Expr
|
||||
The input expression.
|
||||
|
||||
min_num_branches : int
|
||||
The minimum number of parallel branches when the transformation should be applied.
|
||||
|
||||
Returns
|
||||
-------
|
||||
transformed_expr : tvm.relay.Expr
|
||||
Transformed expression
|
||||
"""
|
||||
return _ir_pass.CombineParallelConv2D(expr, min_num_branches)
|
||||
|
||||
|
||||
def alter_op_layout(expr):
|
||||
"""Alternate the layouts of operators or replace primitive operators with
|
||||
other expressions.
|
||||
This pass can be used for computing convolution in custom layouts or
|
||||
other general weight pre-transformation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : tvm.relay.Expr
|
||||
The input expression.
|
||||
|
||||
Returns
|
||||
-------
|
||||
transformed_expr : tvm.relay.Expr
|
||||
Transformed expression with alternated layout.
|
||||
"""
|
||||
return _ir_pass.AlterOpLayout(expr)
|
||||
|
||||
|
||||
def rewrite_annotated_ops(expr, fallback_device):
|
||||
"""Rewrite the annotated program where annotation operators, e.g.
|
||||
`on_deivce`, mark which device an expression should be scheduled to.
|
||||
This pass helps heterogeneous execution where different operators may need
|
||||
to be allocated on various devices.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : tvm.relay.Expr
|
||||
The input expression.
|
||||
|
||||
fallback_device : int
|
||||
The fallback device type. It is also used as the default device for
|
||||
operators with no annotated device.
|
||||
|
||||
Returns
|
||||
-------
|
||||
transformed_expr : tvm.relay.Expr
|
||||
Transformed expression with cross device data copy operators.
|
||||
"""
|
||||
return _ir_pass.RewriteDeviceAnnotation(expr, fallback_device)
|
||||
|
||||
|
||||
def collect_device_info(expr):
|
||||
"""Collect the device allocation map for the given expression. The device
|
||||
ids are propagated from the `device_copy` operators.
|
||||
|
@ -478,7 +260,7 @@ def collect_device_info(expr):
|
|||
ret : Dict[tvm.relay.expr, int]
|
||||
A dictionary mapping tvm.relay.Expr to device type.
|
||||
"""
|
||||
return _ir_pass.CollectDeviceInfo(expr)
|
||||
return _analysis.CollectDeviceInfo(expr)
|
||||
|
||||
|
||||
def collect_device_annotation_ops(expr):
|
||||
|
@ -495,38 +277,7 @@ def collect_device_annotation_ops(expr):
|
|||
A dictionary mapping tvm.relay.Expr to device type where the keys are
|
||||
annotation expressions.
|
||||
"""
|
||||
return _ir_pass.CollectDeviceAnnotationOps(expr)
|
||||
|
||||
|
||||
def gradient(expr, mod=None, mode='higher_order'):
|
||||
"""
|
||||
Transform the input function,
|
||||
returning a function that calculate the original result,
|
||||
paired with gradient of the input.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : tvm.relay.Expr
|
||||
The input expression, which is a Function or a GlobalVar.
|
||||
|
||||
mod : Optional[tvm.relay.Module]
|
||||
|
||||
mode : Optional[String]
|
||||
The mode of the automatic differentiation algorithm.
|
||||
'first_order' only work on first order code, but will not produce reference nor closure.
|
||||
'higher_order' work on all code using reference and closure.
|
||||
|
||||
Returns
|
||||
-------
|
||||
expr : tvm.relay.Expr
|
||||
The transformed expression.
|
||||
"""
|
||||
if mode == 'first_order':
|
||||
return _ir_pass.first_order_gradient(expr, mod)
|
||||
elif mode == 'higher_order':
|
||||
return _ir_pass.gradient(expr, mod)
|
||||
else:
|
||||
raise Exception('unknown mode')
|
||||
return _analysis.CollectDeviceAnnotationOps(expr)
|
||||
|
||||
|
||||
def get_total_mac_number(expr):
|
||||
|
@ -543,27 +294,7 @@ def get_total_mac_number(expr):
|
|||
result : int64
|
||||
The number of MACs (multiply-accumulate) of a model
|
||||
"""
|
||||
return _ir_pass.GetTotalMacNumber(expr)
|
||||
|
||||
|
||||
def eliminate_common_subexpr(expr, fskip=None):
|
||||
"""
|
||||
Eliminate common subexpressions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : tvm.relay.Expr
|
||||
The input expression.
|
||||
|
||||
fskip : function
|
||||
The callback function that decides whether an expression should be skipped.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : tvm.relay.Expr
|
||||
The output expression.
|
||||
"""
|
||||
return _ir_pass.eliminate_common_subexpr(expr, fskip)
|
||||
return _analysis.GetTotalMacNumber(expr)
|
||||
|
||||
|
||||
def unmatched_cases(match, mod=None):
|
||||
|
@ -574,15 +305,16 @@ def unmatched_cases(match, mod=None):
|
|||
----------
|
||||
match : tvm.relay.Match
|
||||
The match expression
|
||||
|
||||
mod : Optional[tvm.relay.Module]
|
||||
The module (defaults to an empty module)
|
||||
|
||||
Returns
|
||||
-------
|
||||
missing_patterns : [tvm.relay.Pattern]
|
||||
Patterns that the match expression does not catch.
|
||||
Patterns that the match expression does not catch.
|
||||
"""
|
||||
return _ir_pass.unmatched_cases(match, mod)
|
||||
return _analysis.unmatched_cases(match, mod)
|
||||
|
||||
|
||||
def detect_feature(a, b=None):
|
||||
|
@ -605,4 +337,27 @@ def detect_feature(a, b=None):
|
|||
"""
|
||||
if isinstance(a, Module):
|
||||
a, b = b, a
|
||||
return set([Feature(int(x)) for x in _ir_pass.detect_feature(a, b)])
|
||||
return set([Feature(int(x)) for x in _analysis.detect_feature(a, b)])
|
||||
|
||||
|
||||
def structural_hash(value):
|
||||
"""Hash a Relay expression structurally.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : Union[tvm.relay.Expr, tvm.relay.Type]
|
||||
The expression to hash.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : int
|
||||
The hash value
|
||||
"""
|
||||
if isinstance(value, Expr):
|
||||
return int(_analysis._expr_hash(value))
|
||||
elif isinstance(value, Type):
|
||||
return int(_analysis._type_hash(value))
|
||||
else:
|
||||
msg = ("found value of type {0} expected" +
|
||||
"relay.Expr or relay.Type").format(type(value))
|
||||
raise TypeError(msg)
|
|
@ -21,7 +21,7 @@ from __future__ import absolute_import
|
|||
import numpy as np
|
||||
|
||||
from . import _backend
|
||||
from .. import _make, ir_pass, transform
|
||||
from .. import _make, analysis, transform
|
||||
from .. import module
|
||||
from ... import register_func, nd
|
||||
from ..base import NodeBase, register_relay_node
|
||||
|
@ -239,7 +239,7 @@ class Executor(object):
|
|||
return self._make_executor()
|
||||
|
||||
if isinstance(expr, Function):
|
||||
assert not ir_pass.free_vars(expr)
|
||||
assert not analysis.free_vars(expr)
|
||||
|
||||
if isinstance(expr, (Function, GlobalVar)):
|
||||
return self._make_executor(expr)
|
||||
|
|
|
@ -19,7 +19,7 @@ from typing import List
|
|||
import tvm
|
||||
from .base import Span, NodeBase
|
||||
from .ty import Type, TypeParam
|
||||
from ._ir_pass import _get_checked_type
|
||||
from ._analysis import _get_checked_type
|
||||
|
||||
|
||||
class Expr(NodeBase):
|
||||
|
@ -128,4 +128,4 @@ class If(Expr):
|
|||
|
||||
def __init__(self, cond, true_value, false_value):
|
||||
# type: (Expr, Expr, Expr) -> None
|
||||
...
|
||||
...
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
"""Caffe2 frontend"""
|
||||
from __future__ import absolute_import as _abs
|
||||
import tvm
|
||||
from .. import ir_pass
|
||||
from .. import analysis
|
||||
from .. import expr as _expr
|
||||
from .. import module as _module
|
||||
from .. import op as _op
|
||||
|
@ -450,7 +450,7 @@ class Caffe2NetDef(object):
|
|||
else:
|
||||
outputs = out[0]
|
||||
|
||||
func = _expr.Function(ir_pass.free_vars(outputs), outputs)
|
||||
func = _expr.Function(analysis.free_vars(outputs), outputs)
|
||||
self._mod[self._mod.entry_func] = func
|
||||
|
||||
return self._mod, self._params
|
||||
|
|
|
@ -19,8 +19,8 @@ from __future__ import absolute_import as _abs
|
|||
import logging
|
||||
from topi.util import get_const_tuple
|
||||
from .. import expr as _expr
|
||||
from .. import expr as _expr
|
||||
from .. import ir_pass
|
||||
from .. import module as _module
|
||||
from .. import transform as _transform
|
||||
from .. import op as _op
|
||||
|
||||
|
||||
|
@ -407,9 +407,17 @@ def get_name(node):
|
|||
name = node.name_hint
|
||||
return name
|
||||
|
||||
|
||||
def infer_type(node):
|
||||
"""A method to infer the type of an intermediate node in the relay graph."""
|
||||
mod = _module.Module.from_expr(node)
|
||||
mod = _transform.InferType()(mod)
|
||||
entry = mod[mod.entry_func]
|
||||
return entry if isinstance(node, _expr.Function) else entry.body
|
||||
|
||||
def infer_shape(inputs):
|
||||
"""A method to get the output shape of an intermediate node in the graph."""
|
||||
out_type = ir_pass.infer_type(inputs)
|
||||
out_type = infer_type(inputs)
|
||||
out_shapes = get_const_tuple(out_type.checked_type.shape)
|
||||
return out_shapes
|
||||
|
||||
|
@ -417,7 +425,7 @@ def infer_channels(inputs, transpose=False):
|
|||
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
|
||||
these attributes. We check the shape of weights provided to get the number.
|
||||
"""
|
||||
out_type = ir_pass.infer_type(inputs)
|
||||
out_type = infer_type(inputs)
|
||||
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
|
||||
channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
|
||||
return channels
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
from __future__ import absolute_import as _abs
|
||||
import numpy as np
|
||||
import tvm
|
||||
from .. import ir_pass
|
||||
from .. import analysis
|
||||
from .. import expr as _expr
|
||||
from .. import module as _module
|
||||
from .. import op as _op
|
||||
|
@ -462,6 +462,6 @@ def from_coreml(model, shape=None):
|
|||
for o in spec.description.output]
|
||||
# for now return first output
|
||||
outexpr = outexpr[0]
|
||||
func = _expr.Function(ir_pass.free_vars(outexpr), outexpr)
|
||||
func = _expr.Function(analysis.free_vars(outexpr), outexpr)
|
||||
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
|
||||
return _module.Module.from_expr(func), params
|
||||
|
|
|
@ -23,7 +23,7 @@ from __future__ import absolute_import as _abs
|
|||
from enum import Enum
|
||||
import numpy as np
|
||||
import tvm
|
||||
from .. import ir_pass
|
||||
from .. import analysis
|
||||
from .. import expr as _expr
|
||||
from .. import module as _module
|
||||
from .common import get_relay_op, new_var
|
||||
|
@ -820,7 +820,7 @@ class GraphProto(object):
|
|||
|
||||
outputs = _as_list(sym) + self._outs
|
||||
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
|
||||
sym = _expr.Function(ir_pass.free_vars(outputs), outputs)
|
||||
sym = _expr.Function(analysis.free_vars(outputs), outputs)
|
||||
return _module.Module.from_expr(sym), self._tvmparams
|
||||
|
||||
def from_darknet(net,
|
||||
|
|
|
@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs
|
|||
import sys
|
||||
import numpy as np
|
||||
import tvm
|
||||
from .. import ir_pass
|
||||
from .. import analysis
|
||||
from .. import expr as _expr
|
||||
from .. import module as _module
|
||||
from .. import op as _op
|
||||
|
@ -743,6 +743,6 @@ def from_keras(model, shape=None):
|
|||
outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \
|
||||
for oc in model._output_coordinates]
|
||||
outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr)
|
||||
func = _expr.Function(ir_pass.free_vars(outexpr), outexpr)
|
||||
func = _expr.Function(analysis.free_vars(outexpr), outexpr)
|
||||
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
|
||||
return _module.Module.from_expr(func), params
|
||||
|
|
|
@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs
|
|||
|
||||
import json
|
||||
import tvm
|
||||
from .. import ir_pass
|
||||
from .. import analysis, transform
|
||||
from .. import expr as _expr
|
||||
from .. import op as _op
|
||||
from .. import module as _module
|
||||
|
@ -41,6 +41,13 @@ _activation_map = {
|
|||
"relu" : _op.nn.relu
|
||||
}
|
||||
|
||||
def _infer_type(node):
|
||||
"""A method to infer the type of an intermediate node in the relay graph."""
|
||||
mod = _module.Module.from_expr(node)
|
||||
mod = transform.InferType()(mod)
|
||||
entry = mod[mod.entry_func]
|
||||
return entry if isinstance(node, _expr.Function) else entry.body
|
||||
|
||||
def _mx_fully_connected(inputs, attrs):
|
||||
import mxnet as mx
|
||||
units = attrs.get_int("num_hidden")
|
||||
|
@ -89,7 +96,8 @@ def _mx_activations(inputs, attrs):
|
|||
|
||||
def _mx_compare(new_op, wrapper):
|
||||
def impl(inputs, attrs):
|
||||
dtype = ir_pass.infer_type(inputs[0]).checked_type.dtype
|
||||
expr = _infer_type(inputs[0])
|
||||
dtype = expr.checked_type.dtype
|
||||
return wrapper(new_op)(inputs, attrs).astype(dtype)
|
||||
return impl
|
||||
|
||||
|
@ -258,7 +266,8 @@ def _mx_slice_like(inputs, attrs):
|
|||
|
||||
def _mx_slice_axis(inputs, attrs):
|
||||
assert len(inputs) == 1
|
||||
shape = ir_pass.infer_type(inputs[0]).checked_type.shape
|
||||
expr = _infer_type(inputs[0])
|
||||
shape = expr.checked_type.shape
|
||||
axis = attrs.get_int("axis")
|
||||
ax_beg = attrs.get_int("begin")
|
||||
ax_end = attrs.get_str("end")
|
||||
|
@ -302,7 +311,8 @@ def _mx_crop_like(inputs, attrs):
|
|||
if offset == (0, 0):
|
||||
new_attrs["axes"] = (2, 3)
|
||||
return _op.slice_like(*inputs, **new_attrs)
|
||||
like_shape = ir_pass.infer_type(inputs[1]).checked_type.shape
|
||||
expr = _infer_type(inputs[1])
|
||||
like_shape = expr.checked_type.shape
|
||||
new_attrs['begin'] = [0, 0, offset[0], offset[1]]
|
||||
new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2],
|
||||
offset[1]+like_shape[3]]
|
||||
|
@ -532,7 +542,8 @@ def _mx_resize(inputs, attrs):
|
|||
scale_width = attrs.get_float("scale_width", None)
|
||||
height = attrs.get_int("height", 1)
|
||||
width = attrs.get_int("width", 1)
|
||||
shape = ir_pass.infer_type(inputs[0]).checked_type.shape
|
||||
expr = _infer_type(inputs[0])
|
||||
shape = expr.checked_type.shape
|
||||
if scale_height is not None:
|
||||
height = (scale_height * shape[2]).astype("int32")
|
||||
if scale_width is not None:
|
||||
|
@ -639,7 +650,8 @@ def _mx_broadcast_axis(inputs, attrs):
|
|||
assert len(axis) == len(size)
|
||||
if len(axis) == 0:
|
||||
return inputs[0]
|
||||
src_shape = ir_pass.infer_type(inputs[0])._checked_type_.shape
|
||||
expr = _infer_type(inputs[0])
|
||||
src_shape = expr.checked_type.shape
|
||||
tgt_shape = []
|
||||
for i, dim in enumerate(src_shape):
|
||||
if i not in axis:
|
||||
|
@ -734,7 +746,8 @@ def _mx_rnn_layer(inputs, attrs):
|
|||
return out, [out]
|
||||
|
||||
def _gru_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
|
||||
dtype = ir_pass.infer_type(data).checked_type.dtype
|
||||
expr = _infer_type(data)
|
||||
dtype = expr.checked_type.dtype
|
||||
i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
|
||||
h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
|
||||
i2h_r, i2h_z, i2h = _op.split(i2h, indices_or_sections=3, axis=1)
|
||||
|
@ -776,7 +789,8 @@ def _mx_rnn_layer(inputs, attrs):
|
|||
seq_data = inputs[0]
|
||||
concat_weight = inputs[1]
|
||||
init_states = inputs[2:]
|
||||
data_shape = ir_pass.infer_type(seq_data).checked_type.shape
|
||||
expr = _infer_type(seq_data)
|
||||
data_shape = expr.checked_type.shape
|
||||
seq_len = int(data_shape[0])
|
||||
assert len(concat_weight) == num_layers * 4 * direct
|
||||
|
||||
|
@ -1099,7 +1113,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, mod=None):
|
|||
|
||||
outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]]
|
||||
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
|
||||
func = _expr.Function(ir_pass.free_vars(outputs), outputs)
|
||||
func = _expr.Function(analysis.free_vars(outputs), outputs)
|
||||
return func
|
||||
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ import logging
|
|||
import numpy as np
|
||||
import tvm
|
||||
from ... import nd as _nd
|
||||
from .. import ir_pass
|
||||
from .. import analysis
|
||||
from .. import transform as _transform
|
||||
from .. import expr as _expr
|
||||
from .. import module as _module
|
||||
|
@ -412,7 +412,7 @@ class Reshape(OnnxOpConverter):
|
|||
else:
|
||||
data, shape = inputs
|
||||
logging.warning("Constant evaluating Reshape's shape argument, may reduce performance")
|
||||
shape_params = ir_pass.free_vars(shape)
|
||||
shape_params = analysis.free_vars(shape)
|
||||
func = _expr.Function(shape_params, shape)
|
||||
mod = _module.Module.from_expr(func)
|
||||
seq = _transform.Sequential([_transform.InferType(),
|
||||
|
@ -1106,7 +1106,7 @@ class GraphProto(object):
|
|||
# now return the outputs
|
||||
outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
|
||||
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
|
||||
func = _expr.Function(ir_pass.free_vars(outputs), outputs)
|
||||
func = _expr.Function(analysis.free_vars(outputs), outputs)
|
||||
return _module.Module.from_expr(func), self._params
|
||||
|
||||
def _parse_value_proto(self, value_proto):
|
||||
|
|
|
@ -27,7 +27,8 @@ import numpy as np
|
|||
|
||||
import tvm
|
||||
from topi.util import get_const_tuple
|
||||
from .. import ir_pass
|
||||
from .. import analysis
|
||||
from .. import transform as _transform
|
||||
from .. import expr as _expr
|
||||
from .. import op as _op
|
||||
from ..expr_functor import ExprMutator
|
||||
|
@ -38,9 +39,9 @@ __all__ = ['from_tensorflow']
|
|||
def _infer_value(input_val, params):
|
||||
from tvm.contrib import graph_runtime
|
||||
# Check that all free variables have associated parameters.
|
||||
assert all(var.name_hint in params.keys() for var in ir_pass.free_vars(
|
||||
assert all(var.name_hint in params.keys() for var in analysis.free_vars(
|
||||
input_val)), "All inputs to infer must be available in params."
|
||||
func = _expr.Function(ir_pass.free_vars(input_val), input_val)
|
||||
func = _expr.Function(analysis.free_vars(input_val), input_val)
|
||||
with tvm.relay.build_config(opt_level=0):
|
||||
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
|
||||
ctx = tvm.context("llvm", 0)
|
||||
|
@ -235,9 +236,16 @@ def _infer_out_shapes(inputs, params):
|
|||
"""A method to get the output shape of intermediate nodes in the relay graph."""
|
||||
return [_infer_shape(inputs, params)]
|
||||
|
||||
def _infer_type(node):
|
||||
"""A method to infer the type of an intermediate node in the relay graph."""
|
||||
mod = _module.Module.from_expr(node)
|
||||
mod = _transform.InferType()(mod)
|
||||
entry = mod[mod.entry_func]
|
||||
return entry if isinstance(node, _expr.Function) else entry.body
|
||||
|
||||
def _infer_shape(node, params=None):
|
||||
"""A method to get the output shape of an intermediate node in the relay graph."""
|
||||
out_type = ir_pass.infer_type(node)
|
||||
out_type = _infer_type(node)
|
||||
return get_const_tuple(out_type.checked_type.shape)
|
||||
|
||||
def _get_param(params, input_node):
|
||||
|
@ -1841,7 +1849,8 @@ class Loop:
|
|||
bind_map = {}
|
||||
for i, var in enumerate(self.loop_vars):
|
||||
if not isinstance(var, _expr.Var):
|
||||
var_type = ir_pass.infer_type(var).checked_type
|
||||
var_chk = _infer_type(var)
|
||||
var_type = var_chk.checked_type
|
||||
else:
|
||||
var_type = var.type_annotation
|
||||
|
||||
|
@ -2112,7 +2121,7 @@ class GraphProto(object):
|
|||
out.append(out_rnn)
|
||||
|
||||
out = out[0] if len(out) == 1 else _expr.Tuple(out)
|
||||
func = _expr.Function(ir_pass.free_vars(out), out)
|
||||
func = _expr.Function(analysis.free_vars(out), out)
|
||||
self._mod[self._mod.entry_func] = func
|
||||
return self._mod, self._params
|
||||
|
||||
|
@ -2329,7 +2338,8 @@ class GraphProto(object):
|
|||
else:
|
||||
if node_name_prefix not in self._branches:
|
||||
self._branches[node_name_prefix] = Branch()
|
||||
self._branches[node_name_prefix].cond = ir_pass.infer_type(op[0])
|
||||
chk_op = _infer_type(op[0])
|
||||
self._branches[node_name_prefix].cond = chk_op
|
||||
elif node.op == "NextIteration":
|
||||
op = self._nodes[node.input[0]]
|
||||
assert len(op) == 1
|
||||
|
|
|
@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs
|
|||
import math
|
||||
import numpy as np
|
||||
import tvm
|
||||
from .. import ir_pass
|
||||
from .. import analysis
|
||||
from .. import expr as _expr
|
||||
from .. import module as _module
|
||||
from .. import op as _op
|
||||
|
@ -914,5 +914,5 @@ def from_tflite(model, shape_dict, dtype_dict):
|
|||
params = {k:_nd.array(np.array(v)) for k, v in exp_tab.params.items()}
|
||||
outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs]
|
||||
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
|
||||
func = _expr.Function(ir_pass.free_vars(outputs), outputs)
|
||||
func = _expr.Function(analysis.free_vars(outputs), outputs)
|
||||
return _module.Module.from_expr(func), params
|
||||
|
|
|
@ -79,15 +79,6 @@ class Module(RelayNode):
|
|||
if isinstance(val, _expr.Expr):
|
||||
if isinstance(var, _base.string_types):
|
||||
var = _expr.GlobalVar(var)
|
||||
|
||||
# TODO(@jroesch): Port this logic to C++.
|
||||
if not isinstance(val, _expr.Function):
|
||||
if isinstance(val, _expr.GlobalVar):
|
||||
val = ir_pass.eta_expand(val, self)
|
||||
else:
|
||||
val = _expr.Function([], val)
|
||||
|
||||
|
||||
_make.Module_Add(self, var, val, update)
|
||||
else:
|
||||
assert isinstance(val, _ty.Type)
|
||||
|
|
|
@ -22,7 +22,7 @@ import numpy as np
|
|||
from . import _quantize
|
||||
from .. import expr as _expr
|
||||
from .. import module as _module
|
||||
from .. import ir_pass as _ir_pass
|
||||
from .. import analysis as _analysis
|
||||
from .. import transform as _transform
|
||||
from .. import op as _op
|
||||
from ... import make as _make
|
||||
|
@ -250,7 +250,7 @@ def calibrate(graph, mod=None, ctx=None):
|
|||
const_params[nclip_min] = _make_const(- (valid_range - 1))
|
||||
const_params[nclip_max] = _make_const((valid_range - 1))
|
||||
|
||||
_ir_pass.post_order_visit(graph, visit_func)
|
||||
_analysis.post_order_visit(graph, visit_func)
|
||||
return _expr.bind(graph, const_params)
|
||||
|
||||
|
||||
|
|
|
@ -81,7 +81,7 @@ def get_net(batch_size, random_len=100, oshape=(3, 64, 64), ngf=128, code=None,
|
|||
dc32, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4), name="g5_deconv")
|
||||
tanh = relay.tanh(dc64)
|
||||
|
||||
args = relay.ir_pass.free_vars(tanh)
|
||||
args = relay.analysis.free_vars(tanh)
|
||||
return relay.Function(args, tanh)
|
||||
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ def _make_dense_net(num_init_features, growth_rate, block_config,
|
|||
|
||||
ret = layers.dense_add_bias(flat, units=classes, name='dense')
|
||||
|
||||
return relay.Function(relay.ir_pass.free_vars(ret), ret)
|
||||
return relay.Function(relay.analysis.free_vars(ret), ret)
|
||||
|
||||
def get_workload(densenet_size=121, classes=1000, batch_size=4,
|
||||
image_shape=(3, 224, 224), dtype='float32'):
|
||||
|
|
|
@ -54,7 +54,7 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"
|
|||
relu4 = relay.nn.relu(dense1)
|
||||
dense2 = layers.dense_add_bias(relu4, units=num_actions, name="dense2")
|
||||
|
||||
args = relay.ir_pass.free_vars(dense2)
|
||||
args = relay.analysis.free_vars(dense2)
|
||||
return relay.Function(args, dense2)
|
||||
|
||||
|
||||
|
|
|
@ -266,7 +266,7 @@ def get_net(batch_size,
|
|||
fc1 = relay.nn.dense(flatten, relay.var("fc1_weight"), units=num_classes)
|
||||
fc1 = relay.nn.bias_add(fc1, relay.var("fc2_bias"), axis=-1)
|
||||
inception_v3 = relay.nn.softmax(data=fc1)
|
||||
args = relay.ir_pass.free_vars(inception_v3)
|
||||
args = relay.analysis.free_vars(inception_v3)
|
||||
return relay.Function(args, inception_v3)
|
||||
|
||||
def get_workload(batch_size=1, num_classes=1000,
|
||||
|
|
|
@ -150,10 +150,11 @@ def create_workload(net, initializer=None, seed=0):
|
|||
params : dict of str to NDArray
|
||||
The parameters.
|
||||
"""
|
||||
net = relay.ir_pass.infer_type(net)
|
||||
mod = relay.Module.from_expr(net)
|
||||
mod = relay.transform.InferType()(mod)
|
||||
net = mod[mod.entry_func]
|
||||
shape_dict = {
|
||||
v.name_hint : v.checked_type for v in net.params}
|
||||
net.astext()
|
||||
np.random.seed(seed)
|
||||
initializer = initializer if initializer else Xavier()
|
||||
params = {}
|
||||
|
|
|
@ -154,7 +154,7 @@ def get_net(iterations, num_hidden, batch_size=1, dtype="float32"):
|
|||
|
||||
builder.ret(out)
|
||||
body = builder.get()
|
||||
args = relay.ir_pass.free_vars(body)
|
||||
args = relay.analysis.free_vars(body)
|
||||
return relay.Function(args, body, input_type)
|
||||
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ def get_net(batch_size,
|
|||
fc3 = relay.nn.dense(act2, relay.var("fc3_weight"), units=num_classes)
|
||||
fc3 = relay.nn.bias_add(fc3, relay.var("fc3_bias"), axis=-1)
|
||||
mlp = relay.nn.softmax(data=fc3)
|
||||
args = relay.ir_pass.free_vars(mlp)
|
||||
args = relay.analysis.free_vars(mlp)
|
||||
return relay.Function(args, mlp)
|
||||
|
||||
|
||||
|
|
|
@ -108,7 +108,7 @@ def mobile_net(num_classes=1000, data_shape=(1, 3, 224, 224),
|
|||
weight = relay.var('fc_weight')
|
||||
fc = relay.nn.dense(data=flatten, weight=weight, units=num_classes)
|
||||
softmax = relay.nn.softmax(data=fc)
|
||||
return relay.Function(relay.ir_pass.free_vars(softmax), softmax)
|
||||
return relay.Function(relay.analysis.free_vars(softmax), softmax)
|
||||
|
||||
|
||||
def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224), dtype='float32'):
|
||||
|
|
|
@ -169,7 +169,7 @@ def resnet(units,
|
|||
flat = relay.nn.batch_flatten(data=pool1)
|
||||
fc1 = layers.dense_add_bias(data=flat, units=num_classes, name='fc1')
|
||||
net = relay.nn.softmax(data=fc1)
|
||||
return relay.Function(relay.ir_pass.free_vars(net), net)
|
||||
return relay.Function(relay.analysis.free_vars(net), net)
|
||||
|
||||
|
||||
def get_net(batch_size,
|
||||
|
|
|
@ -119,7 +119,7 @@ def get_net(batch_size, image_shape, num_classes, version, dtype):
|
|||
net = relay.nn.global_avg_pool2d(net)
|
||||
net = relay.nn.batch_flatten(net)
|
||||
net = relay.nn.softmax(net)
|
||||
args = relay.ir_pass.free_vars(net)
|
||||
args = relay.analysis.free_vars(net)
|
||||
return relay.Function(args, net)
|
||||
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ def get_net(batch_size, image_shape, num_classes, dtype, num_layers=11, batch_no
|
|||
feature = get_feature(data, layers, filters, batch_norm)
|
||||
classifier = get_classifier(feature, num_classes)
|
||||
symbol = relay.nn.softmax(data=classifier)
|
||||
args = relay.ir_pass.free_vars(symbol)
|
||||
args = relay.analysis.free_vars(symbol)
|
||||
return relay.Function(args, symbol)
|
||||
|
||||
|
||||
|
|
|
@ -277,6 +277,40 @@ def FoldScaleAxis():
|
|||
return _transform.FoldScaleAxis()
|
||||
|
||||
|
||||
def BackwardFoldScaleAxis():
|
||||
"""Backward fold axis scaling into weights of conv2d/dense.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ret : tvm.relay.Pass
|
||||
The registered pass to backward fold expressions.
|
||||
|
||||
Note
|
||||
----
|
||||
It is recommended to call backward_fold_scale_axis
|
||||
before using forward_fold_scale_axis.
|
||||
As backward folding targets common conv-bn pattern.
|
||||
"""
|
||||
return _transform.BackwardFoldScaleAxis()
|
||||
|
||||
|
||||
def ForwardFoldScaleAxis():
|
||||
"""Fold the scaling of axis into weights of conv2d/dense.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ret : tvm.relay.Pass
|
||||
The registered pass to forward fold expressions.
|
||||
|
||||
Note
|
||||
----
|
||||
It is recommended to call backward_fold_scale_axis
|
||||
before using forward_fold_scale_axis.
|
||||
As backward folding targets common conv-bn pattern.
|
||||
"""
|
||||
return _transform.ForwardFoldScaleAxis()
|
||||
|
||||
|
||||
def SimplifyInference():
|
||||
"""Simplify the data-flow graph for inference phase. An simplified expression
|
||||
which is semantically equal to the input expression will be returned.
|
||||
|
@ -406,7 +440,7 @@ def ToANormalForm():
|
|||
|
||||
Returns
|
||||
-------
|
||||
ret: tvm.relay.Pass
|
||||
ret: Union[tvm.relay.Pass, tvm.relay.Expr]
|
||||
The registered pass that transforms an expression into A Normal Form.
|
||||
"""
|
||||
return _transform.ToANormalForm()
|
||||
|
@ -454,6 +488,21 @@ def EliminateCommonSubexpr(fskip=None):
|
|||
def PartialEvaluate():
|
||||
"""Evaluate the static fragment of the code.
|
||||
|
||||
Note
|
||||
----
|
||||
This transformation could be either `Module -> Module` or `Expr -> Expr`.
|
||||
It will directly transform the input expression to a new one if the target
|
||||
expression is provided. Otherwise, it will rely on the pass manager to
|
||||
carry out transformation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : Optional[tvm.relay.Expr]
|
||||
The input expression.
|
||||
|
||||
mod : Optional[tvm.relay.Module]
|
||||
The global module.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ret: tvm.relay.Pass
|
||||
|
@ -461,6 +510,7 @@ def PartialEvaluate():
|
|||
"""
|
||||
return _transform.PartialEvaluate()
|
||||
|
||||
|
||||
def CanonicalizeCast():
|
||||
"""
|
||||
Canonicalize cast expressions to make operator fusion more efficient.
|
||||
|
@ -473,28 +523,35 @@ def CanonicalizeCast():
|
|||
return _transform.CanonicalizeCast()
|
||||
|
||||
|
||||
def OptimizeOnExpr(expr, passes):
|
||||
"""Perform optimization passes on an expressioin.
|
||||
def gradient(expr, mod=None, mode='higher_order'):
|
||||
"""
|
||||
Transform the input function,
|
||||
returning a function that calculate the original result,
|
||||
paired with gradient of the input.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr: tvm.relay.Expr
|
||||
The expression for optimization.
|
||||
expr : tvm.relay.Expr
|
||||
The input expression, which is a Function or a GlobalVar.
|
||||
|
||||
passes: Union[Pass, List[Pass]]
|
||||
The list of optimizations to be applied.
|
||||
mod : Optional[tvm.relay.Module]
|
||||
|
||||
mode : Optional[String]
|
||||
The mode of the automatic differentiation algorithm.
|
||||
'first_order' only works on first order code, but will not produce
|
||||
reference nor closure.
|
||||
'higher_order' works on all code using reference and closure.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ret: tvm.relay.Expr
|
||||
The optimized expression.
|
||||
expr : tvm.relay.Expr
|
||||
The transformed expression.
|
||||
"""
|
||||
if isinstance(passes, Pass):
|
||||
passes = [passes]
|
||||
if not isinstance(passes, (list, tuple)):
|
||||
raise TypeError("passes must be a pass or a list of pass objects.")
|
||||
|
||||
return _transform.OptimizeOnExpr(expr, passes)
|
||||
if mode == 'first_order':
|
||||
return _transform.first_order_gradient(expr, mod)
|
||||
if mode == 'higher_order':
|
||||
return _transform.gradient(expr, mod)
|
||||
raise Exception('unknown mode')
|
||||
|
||||
|
||||
def _wrap_class_module_pass(pass_cls, pass_info):
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
* \file relay/backend/build_module.cc
|
||||
* \brief Code generation for TVM's graph runtime.
|
||||
*/
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/build_module.h>
|
||||
#include <tvm/runtime/device_api.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include <tvm/operation.h>
|
||||
#include <tvm/runtime/registry.h>
|
||||
#include <tvm/relay/attrs/device_copy.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/op_attr_types.h>
|
||||
#include <topi/tags.h>
|
||||
|
|
|
@ -27,8 +27,9 @@
|
|||
#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
|
||||
|
||||
#include <tvm/lowered_func.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
*/
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include "../../common/arena.h"
|
||||
|
||||
namespace tvm {
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/pattern_functor.h>
|
||||
#include <tvm/relay/interpreter.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/attrs/debug.h>
|
||||
#include "compile_engine.h"
|
||||
|
||||
|
|
|
@ -27,7 +27,6 @@
|
|||
|
||||
#include <dmlc/json.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/type.h>
|
||||
#include <tvm/tvm.h>
|
||||
#include <tvm/build_module.h>
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include <tvm/relay/expr.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/logging.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include <tvm/runtime/vm.h>
|
||||
#include <iostream>
|
||||
|
|
|
@ -28,17 +28,18 @@
|
|||
#include <tvm/logging.h>
|
||||
#include <tvm/relay/module.h>
|
||||
#include <tvm/runtime/vm.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
namespace vm {
|
||||
|
||||
runtime::vm::VirtualMachine CompileModule(const Module& mod);
|
||||
|
||||
using tvm::runtime::Object;
|
||||
using tvm::runtime::ObjectTag;
|
||||
using tvm::runtime::vm::VirtualMachine;
|
||||
|
||||
|
||||
VirtualMachine FromModule(const Module& module, const std::vector<TVMContext>& ctxs) {
|
||||
auto vm = CompileModule(module);
|
||||
vm.Init(ctxs);
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/pattern_functor.h>
|
||||
#include <tvm/runtime/ndarray.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include "type_functor.h"
|
||||
#include "../../lang/attr_functor.h"
|
||||
|
||||
|
|
|
@ -345,7 +345,7 @@ void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
|
|||
ExprApplyVisit(fvisit).VisitExpr(e);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.post_order_visit")
|
||||
TVM_REGISTER_API("relay._analysis.post_order_visit")
|
||||
.set_body_typed<void(Expr, PackedFunc)>([](Expr expr, PackedFunc f) {
|
||||
PostOrderVisit(expr, [f](const Expr& n) {
|
||||
f(n);
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/pattern_functor.h>
|
||||
#include <tvm/runtime/ndarray.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/attrs.h>
|
||||
#include "type_functor.h"
|
||||
#include "../../lang/attr_functor.h"
|
||||
|
@ -412,12 +412,12 @@ size_t StructuralHash::operator()(const Expr& expr) const {
|
|||
return RelayHashHandler().ExprHash(expr);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass._expr_hash")
|
||||
TVM_REGISTER_API("relay._analysis._expr_hash")
|
||||
.set_body_typed<int64_t(NodeRef)>([](NodeRef ref) {
|
||||
return static_cast<int64_t>(RelayHashHandler().Hash(ref));
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass._type_hash")
|
||||
TVM_REGISTER_API("relay._analysis._type_hash")
|
||||
.set_body_typed<int64_t(Type)>([](Type type) {
|
||||
return static_cast<int64_t>(RelayHashHandler().TypeHash(type));
|
||||
});
|
||||
|
|
|
@ -23,7 +23,8 @@
|
|||
* \brief The global module in Relay.
|
||||
*/
|
||||
#include <tvm/relay/module.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include <sstream>
|
||||
|
||||
namespace tvm {
|
||||
|
@ -184,7 +185,26 @@ TVM_REGISTER_API("relay._make.Module")
|
|||
.set_body_typed(ModuleNode::make);
|
||||
|
||||
TVM_REGISTER_API("relay._make.Module_Add")
|
||||
.set_body_method<Module>(&ModuleNode::Add);
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
Module mod = args[0];
|
||||
GlobalVar var = args[1];
|
||||
NodeRef val = args[2];
|
||||
bool update = args[3];
|
||||
CHECK(val->derived_from<ExprNode>());
|
||||
if (val->derived_from<FunctionNode>()) {
|
||||
mod->Add(var, Downcast<Function>(val), update);
|
||||
} else if (val->derived_from<GlobalVarNode>()) {
|
||||
GlobalVar gv = Downcast<GlobalVar>(val);
|
||||
auto mod_copy = Module(make_node<ModuleNode>(*mod.operator->()));
|
||||
mod_copy = transform::EtaExpand()(mod_copy);
|
||||
auto func = mod_copy->Lookup(gv->name_hint);
|
||||
mod->Add(var, Downcast<Function>(func), update);
|
||||
} else {
|
||||
auto func = FunctionNode::make({}, Downcast<Expr>(val), Type(nullptr), {});
|
||||
mod->Add(var, func, update);
|
||||
}
|
||||
*ret = mod;
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._module.Module_AddDef")
|
||||
.set_body_method<Module>(&ModuleNode::AddDef);
|
||||
|
@ -197,39 +217,39 @@ TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar")
|
|||
|
||||
TVM_REGISTER_API("relay._module.Module_Lookup")
|
||||
.set_body_typed<Function(Module, GlobalVar)>([](Module mod, GlobalVar var) {
|
||||
return mod->Lookup(var);
|
||||
});
|
||||
return mod->Lookup(var);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._module.Module_Lookup_str")
|
||||
.set_body_typed<Function(Module, std::string)>([](Module mod, std::string var) {
|
||||
return mod->Lookup(var);
|
||||
});
|
||||
return mod->Lookup(var);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._module.Module_LookupDef")
|
||||
.set_body_typed<TypeData(Module, GlobalTypeVar)>([](Module mod, GlobalTypeVar var) {
|
||||
return mod->LookupDef(var);
|
||||
});
|
||||
return mod->LookupDef(var);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._module.Module_LookupDef_str")
|
||||
.set_body_typed<TypeData(Module, std::string)>([](Module mod, std::string var) {
|
||||
return mod->LookupDef(var);
|
||||
});
|
||||
return mod->LookupDef(var);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._module.Module_FromExpr")
|
||||
.set_body_typed<Module(Expr)>([](Expr e) {
|
||||
return ModuleNode::FromExpr(e);
|
||||
return ModuleNode::FromExpr(e);
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._module.Module_Update")
|
||||
.set_body_typed<void(Module, Module)>([](Module mod, Module from) {
|
||||
mod->Update(from);
|
||||
});
|
||||
mod->Update(from);
|
||||
});
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||
.set_dispatch<ModuleNode>(
|
||||
[](const ModuleNode *node, tvm::IRPrinter *p) {
|
||||
p->stream << "ModuleNode( " << node->functions << ")";
|
||||
});
|
||||
[](const ModuleNode *node, tvm::IRPrinter *p) {
|
||||
p->stream << "ModuleNode( " << node->functions << ")";
|
||||
});
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
|
|
@ -24,7 +24,8 @@
|
|||
other expressions. This pass can be used for computing convolution in
|
||||
custom layouts or other general weight pre-transformation.
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include <tvm/relay/op_attr_types.h>
|
||||
#include <tvm/relay/attrs/transform.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
|
@ -348,9 +349,6 @@ Expr AlterOpLayout(const Expr& expr) {
|
|||
return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.AlterOpLayout")
|
||||
.set_body_typed(AlterOpLayout);
|
||||
|
||||
} // namespace alter_op_layout
|
||||
|
||||
namespace transform {
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
* \file canonicalize_cast.cc
|
||||
* \brief Canonicalize cast expressions to make operator fusion more efficient.
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/attrs/nn.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
* \brief Canonicalize special operators to basic operators.
|
||||
This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.)
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/attrs/nn.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
|
@ -61,9 +61,6 @@ Expr CanonicalizeOps(const Expr& e) {
|
|||
return BiasAddSimplifier().Mutate(e);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.canonicalize_ops")
|
||||
.set_body_typed(CanonicalizeOps);
|
||||
|
||||
namespace transform {
|
||||
|
||||
Pass CanonicalizeOps() {
|
||||
|
|
|
@ -33,7 +33,7 @@
|
|||
* convolution branches, such as Inception block.
|
||||
*/
|
||||
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/attrs/nn.h>
|
||||
#include <tvm/relay/attrs/transform.h>
|
||||
|
@ -355,9 +355,6 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
|
|||
return ParallelConv2DCombiner(min_num_branches).Combine(expr);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
|
||||
.set_body_typed(CombineParallelConv2D);
|
||||
|
||||
namespace transform {
|
||||
|
||||
Pass CombineParallelConv2D(uint64_t min_num_branches) {
|
||||
|
|
|
@ -28,8 +28,9 @@
|
|||
* CalcDep turn an expr into a dependency graph of expr,
|
||||
* GenLet turn the dependency graph into a let list, taking only the used value.
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include "let_list.h"
|
||||
|
||||
namespace tvm {
|
||||
|
|
|
@ -34,7 +34,6 @@
|
|||
#include <tvm/relay/attrs/annotation.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
|
||||
#include <memory>
|
||||
|
@ -559,13 +558,13 @@ Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) {
|
|||
return AnnotatationVisitor::GetAnnotations(expr);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.CollectDeviceInfo")
|
||||
TVM_REGISTER_API("relay._analysis.CollectDeviceInfo")
|
||||
.set_body_typed(CollectDeviceInfo);
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation")
|
||||
TVM_REGISTER_API("relay._analysis.RewriteDeviceAnnotation")
|
||||
.set_body_typed(RewriteAnnotatedOps);
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps")
|
||||
TVM_REGISTER_API("relay._analysis.CollectDeviceAnnotationOps")
|
||||
.set_body_typed(CollectDeviceAnnotationOps);
|
||||
|
||||
namespace transform {
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
* to replace an expression with a previously appeared expression with the same input and
|
||||
* attributes. The fskip callback argument allows us to skip specific expressions.
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include <unordered_map>
|
||||
|
@ -85,9 +85,6 @@ Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) {
|
|||
return CommonSubexprEliminator(callback)(expr);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.eliminate_common_subexpr")
|
||||
.set_body_typed<Expr(Expr, PackedFunc)>(EliminateCommonSubexpr);
|
||||
|
||||
namespace transform {
|
||||
|
||||
Pass EliminateCommonSubexpr(PackedFunc fskip) {
|
||||
|
|
|
@ -25,7 +25,8 @@
|
|||
* \brief Add abstraction over a function. For example, abs will become (fun x -> abs x).
|
||||
*
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/type.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
@ -44,10 +45,8 @@ Expr EtaExpand(const Expr& e, const Module& mod) {
|
|||
original_type_params = func->type_params;
|
||||
ret_type = func->ret_type;
|
||||
} else {
|
||||
auto inferred = InferType(e, mod);
|
||||
CHECK(inferred->is_type<FunctionNode>());
|
||||
|
||||
auto func = GetRef<Function>(inferred.as_derived<FunctionNode>());
|
||||
CHECK(e->is_type<FunctionNode>());
|
||||
auto func = GetRef<Function>(e.as_derived<FunctionNode>());
|
||||
original_params = func->params;
|
||||
original_type_params = func->type_params;
|
||||
ret_type = func->ret_type;
|
||||
|
@ -62,19 +61,18 @@ Expr EtaExpand(const Expr& e, const Module& mod) {
|
|||
auto new_func =
|
||||
FunctionNode::make(args, CallNode::make(e, params), ret_type, original_type_params);
|
||||
|
||||
return InferType(new_func, mod);
|
||||
return new_func;
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.eta_expand").set_body_typed(EtaExpand);
|
||||
|
||||
namespace transform {
|
||||
|
||||
Pass EtaExpand() {
|
||||
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
|
||||
[=](Function f, Module m, PassContext pc) {
|
||||
return Downcast<Function>(EtaExpand(f, m));
|
||||
};
|
||||
return CreateFunctionPass(pass_func, 1, "EtaExpand", {});
|
||||
return Downcast<Function>(EtaExpand(f, m));
|
||||
};
|
||||
Pass expanded = CreateFunctionPass(pass_func, 1, "EtaExpand", {});
|
||||
return Sequential({expanded, InferType()});
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._transform.EtaExpand")
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
* \brief Detect features used in Expr/Module
|
||||
*/
|
||||
#include <tvm/relay/feature.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/module.h>
|
||||
|
@ -97,7 +97,7 @@ Array<Integer> PyDetectFeature(const Expr& expr, const Module& mod) {
|
|||
return static_cast<Array<Integer>>(fs);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.detect_feature")
|
||||
TVM_REGISTER_API("relay._analysis.detect_feature")
|
||||
.set_body_typed(PyDetectFeature);
|
||||
|
||||
} // namespace relay
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
* Copyright (c) 2018 by Contributors
|
||||
* \file constant_folding.cc
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/op_attr_types.h>
|
||||
#include <tvm/relay/interpreter.h>
|
||||
|
@ -156,9 +156,13 @@ class ConstantFolder : public ExprMutator {
|
|||
}
|
||||
// Constant evaluate a expression.
|
||||
Expr ConstEvaluate(Expr expr) {
|
||||
expr = InferType(expr, Module(nullptr));
|
||||
expr = FuseOps(expr, 0, Module(nullptr));
|
||||
expr = InferType(expr, Module(nullptr));
|
||||
std::vector<transform::Pass> passes = {transform::FuseOps(0),
|
||||
transform::InferType()};
|
||||
auto mod = ModuleNode::FromExpr(expr);
|
||||
auto seq = transform::Sequential(passes);
|
||||
mod = seq(mod);
|
||||
auto entry_func = mod->Lookup(mod->entry_func);
|
||||
expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
|
||||
return ValueToExpr(executor_(expr));
|
||||
}
|
||||
// Evaluate shape_of op
|
||||
|
@ -213,9 +217,6 @@ Expr FoldConstant(const Expr& expr) {
|
|||
Module(nullptr), ctx, target)).Mutate(expr);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.FoldConstant")
|
||||
.set_body_typed(FoldConstant);
|
||||
|
||||
namespace transform {
|
||||
|
||||
Pass FoldConstant() {
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
* conv/dense operators.
|
||||
*/
|
||||
#include <tvm/data_layout.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/attrs/nn.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
|
@ -545,10 +545,6 @@ Expr ForwardFoldScaleAxis(const Expr& data) {
|
|||
data, "FScaleAxisForwardRewrite", fcontext);
|
||||
}
|
||||
|
||||
// Expose the FoldScaleAxisFoward
|
||||
TVM_REGISTER_API("relay._ir_pass.forward_fold_scale_axis")
|
||||
.set_body_typed<Expr(Expr)>(ForwardFoldScaleAxis);
|
||||
|
||||
//----------------------------------------
|
||||
// Implement backward transformations.
|
||||
//----------------------------------------
|
||||
|
@ -947,9 +943,6 @@ Expr BackwardFoldScaleAxis(const Expr& data) {
|
|||
return make_node<BackwardTransformerNode>()->Fold(data);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.backward_fold_scale_axis")
|
||||
.set_body_typed<Expr(Expr)>(BackwardFoldScaleAxis);
|
||||
|
||||
} // namespace fold_scale_axis
|
||||
|
||||
namespace transform {
|
||||
|
@ -964,6 +957,9 @@ Pass ForwardFoldScaleAxis() {
|
|||
{ir::StringImm::make("InferType")});
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._transform.ForwardFoldScaleAxis")
|
||||
.set_body_typed(ForwardFoldScaleAxis);
|
||||
|
||||
Pass BackwardFoldScaleAxis() {
|
||||
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
|
||||
[=](Function f, Module m, PassContext pc) {
|
||||
|
@ -974,6 +970,9 @@ Pass BackwardFoldScaleAxis() {
|
|||
{ir::StringImm::make("InferType")});
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._transform.BackwardFoldScaleAxis")
|
||||
.set_body_typed(BackwardFoldScaleAxis);
|
||||
|
||||
Pass FoldScaleAxis() {
|
||||
// FoldScaleAxis pass contains the following three passes. Therefore, we can
|
||||
// register it as a sequential pass.
|
||||
|
|
|
@ -23,9 +23,9 @@
|
|||
* \file forward_rewrite.cc
|
||||
* \brief Apply rewriting rules in a forward fashion.
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/op_attr_types.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include "pass_util.h"
|
||||
|
||||
namespace tvm {
|
||||
|
@ -206,37 +206,5 @@ Expr ForwardRewrite(const Expr& expr,
|
|||
return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr);
|
||||
}
|
||||
|
||||
namespace transform {
|
||||
|
||||
using std::function;
|
||||
|
||||
Pass ForwardRewrite(const std::string& rewrite_map_attr_name,
|
||||
function<NodeRef(const Call&)> fcontext,
|
||||
function<Expr(const Expr&)> fmulti_ref_trigger) {
|
||||
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
|
||||
[=](Function f, Module m, PassContext pc) {
|
||||
return Downcast<Function>(ForwardRewrite(f,
|
||||
rewrite_map_attr_name,
|
||||
fcontext,
|
||||
fmulti_ref_trigger));
|
||||
};
|
||||
return CreateFunctionPass(pass_func, 1, "ForwardRewrite", {});
|
||||
}
|
||||
|
||||
Pass ForwardRewrite(const FForwardRewrite& rewrite_func,
|
||||
function<NodeRef(const Call&)> fcontext,
|
||||
function<Expr(const Expr&)> fmulti_ref_trigger) {
|
||||
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
|
||||
[=](Function f, Module m, PassContext pc) {
|
||||
return Downcast<Function>(ForwardRewrite(f,
|
||||
rewrite_func,
|
||||
fcontext,
|
||||
fmulti_ref_trigger));
|
||||
};
|
||||
return CreateFunctionPass(pass_func, 1, "ForwardRewriteFunc", {});
|
||||
}
|
||||
|
||||
} // namespace transform
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
* Fuse necessary ops into a single one.
|
||||
*/
|
||||
#include <tvm/expr_operator.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/op_attr_types.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
|
@ -963,9 +963,6 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) {
|
|||
}
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.FuseOps")
|
||||
.set_body_typed(FuseOps);
|
||||
|
||||
namespace transform {
|
||||
|
||||
Pass FuseOps(int fuse_opt_level) {
|
||||
|
|
|
@ -26,7 +26,8 @@
|
|||
#include <tvm/lowered_func.h>
|
||||
#include <tvm/operation.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include "pattern_util.h"
|
||||
#include "let_list.h"
|
||||
#include "../ir/type_functor.h"
|
||||
|
@ -246,7 +247,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
|
|||
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
|
||||
TVM_REGISTER_API("relay._analysis.first_order_gradient")
|
||||
.set_body_typed(FirstOrderGradient);
|
||||
|
||||
struct ReverseADType : TypeMutator {
|
||||
|
@ -351,7 +352,7 @@ Expr Gradient(const Expr& re, const Module& mod) {
|
|||
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.gradient")
|
||||
TVM_REGISTER_API("relay._transform.gradient")
|
||||
.set_body_typed(Gradient);
|
||||
|
||||
} // namespace relay
|
||||
|
|
|
@ -32,7 +32,7 @@
|
|||
* We check this by ensuring the `dtype` field of a Tensor always
|
||||
* contains a data type such as `int`, `float`, `uint`.
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/error.h>
|
||||
#include "../ir/type_functor.h"
|
||||
|
||||
|
@ -183,7 +183,7 @@ Kind KindCheck(const Type& t, const Module& mod) {
|
|||
return kc.Check(t);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.check_kind")
|
||||
TVM_REGISTER_API("relay._analysis.check_kind")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
if (args.size() == 1) {
|
||||
*ret = KindCheck(args[0], ModuleNode::make({}, {}));
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
#include <tvm/relay/op.h>
|
||||
#include <tvm/relay/attrs/nn.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/data_layout.h>
|
||||
#include "pattern_util.h"
|
||||
|
||||
|
@ -188,7 +188,7 @@ int64_t GetTotalMacNumber(const Expr& expr) {
|
|||
return MacCounter::GetTotalMacNumber(expr);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.GetTotalMacNumber")
|
||||
TVM_REGISTER_API("relay._analysis.GetTotalMacNumber")
|
||||
.set_body_typed(GetTotalMacNumber);
|
||||
|
||||
} // namespace mac_count
|
||||
|
|
|
@ -32,7 +32,6 @@
|
|||
#include <tvm/relay/error.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/pattern_functor.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <stack>
|
||||
|
||||
namespace tvm {
|
||||
|
@ -236,15 +235,15 @@ Array<Pattern> UnmatchedCases(const Match& match, const Module& mod) {
|
|||
}
|
||||
|
||||
// expose for testing only
|
||||
TVM_REGISTER_API("relay._ir_pass.unmatched_cases")
|
||||
.set_body_typed<Array<Pattern>(const Match&,
|
||||
const Module&)>([](const Match& match,
|
||||
const Module& mod_ref) {
|
||||
Module call_mod = mod_ref;
|
||||
if (!call_mod.defined()) {
|
||||
call_mod = ModuleNode::make({}, {});
|
||||
}
|
||||
return UnmatchedCases(match, call_mod);
|
||||
});
|
||||
TVM_REGISTER_API("relay._analysis.unmatched_cases")
|
||||
.set_body_typed<Array<Pattern>(const Match&, const Module&)>(
|
||||
[](const Match& match, const Module& mod_ref) {
|
||||
Module call_mod = mod_ref;
|
||||
if (!call_mod.defined()) {
|
||||
call_mod = ModuleNode::make({}, {});
|
||||
}
|
||||
return UnmatchedCases(match, call_mod);
|
||||
});
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
|
|
@ -91,7 +91,8 @@
|
|||
*
|
||||
* These assumptions do not affect the correctness of the algorithm, however.
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/pattern_functor.h>
|
||||
#include <tvm/relay/interpreter.h>
|
||||
|
@ -740,9 +741,14 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
|||
|
||||
// Constant evaluate a expression.
|
||||
PStatic ConstEvaluate(const Expr& expr, LetList* ll) {
|
||||
Expr infered = InferType(expr, Module(nullptr));
|
||||
Expr fused = FuseOps(infered, 0, Module(nullptr));
|
||||
Expr fused_infered = InferType(fused, Module(nullptr));
|
||||
std::vector<transform::Pass> passes = {transform::FuseOps(0),
|
||||
transform::InferType()};
|
||||
auto mod = ModuleNode::FromExpr(expr);
|
||||
auto seq = transform::Sequential(passes);
|
||||
mod = seq(mod);
|
||||
auto entry_func = mod->Lookup(mod->entry_func);
|
||||
auto fused_infered =
|
||||
expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
|
||||
return Reify(executor_(fused_infered), ll);
|
||||
}
|
||||
|
||||
|
|
|
@ -573,18 +573,6 @@ class PassContext::Internal {
|
|||
}
|
||||
};
|
||||
|
||||
Expr OptimizeOnExpr(const Expr& expr, const Array<Pass>& passes) {
|
||||
auto mod = ModuleNode::FromExpr(expr);
|
||||
Sequential seq(passes);
|
||||
auto pass_ctx = PassContext::Create();
|
||||
pass_ctx->opt_level = 3;
|
||||
tvm::With<PassContext> ctx_scope(pass_ctx);
|
||||
mod = seq(mod);
|
||||
CHECK(mod.defined());
|
||||
auto entry_func = mod->Lookup(mod->entry_func);
|
||||
return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._transform.GetCurrentPassContext")
|
||||
.set_body_typed(PassContext::Current);
|
||||
|
||||
|
@ -594,9 +582,6 @@ TVM_REGISTER_API("relay._transform.EnterPassContext")
|
|||
TVM_REGISTER_API("relay._transform.ExitPassContext")
|
||||
.set_body_typed(PassContext::Internal::ExitScope);
|
||||
|
||||
TVM_REGISTER_API("relay._transform.OptimizeOnExpr")
|
||||
.set_body_typed(OptimizeOnExpr);
|
||||
|
||||
} // namespace transform
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
|
|
@ -27,9 +27,10 @@
|
|||
*/
|
||||
#include <dmlc/thread_local.h>
|
||||
#include <tvm/base.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/op_attr_types.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -259,6 +260,13 @@ Expr QuantizeRealize(const Call& ref_call,
|
|||
return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
|
||||
}
|
||||
|
||||
Expr FoldConstantOpt(const Expr& expr) {
|
||||
auto mod = ModuleNode::FromExpr(expr);
|
||||
mod = transform::FoldConstant()(mod);
|
||||
auto entry_func = mod->Lookup(mod->entry_func);
|
||||
return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
|
||||
}
|
||||
|
||||
RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
|
||||
.set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
|
||||
|
||||
|
@ -290,7 +298,8 @@ Expr Conv2dRealize(const Call& ref_call,
|
|||
|
||||
Expr ret = CallNode::make(ref_call->op,
|
||||
{ldata, rdata}, Attrs(attrs), ref_call->type_args);
|
||||
Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale));
|
||||
Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
|
||||
Expr dom_scale = FoldConstantOpt(mul);
|
||||
return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
|
||||
}
|
||||
|
||||
|
@ -323,7 +332,8 @@ Expr DenseRealize(const Call& ref_call,
|
|||
|
||||
Expr ret = CallNode::make(ref_call->op,
|
||||
{ldata, rdata}, Attrs(attrs), ref_call->type_args);
|
||||
Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale));
|
||||
Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
|
||||
Expr dom_scale = FoldConstantOpt(mul);
|
||||
return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
|
||||
}
|
||||
|
||||
|
@ -356,7 +366,8 @@ Expr MulRealize(const Call& ref_call,
|
|||
}
|
||||
|
||||
Expr ret = ForwardOp(ref_call, {ldata, rdata});
|
||||
Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale));
|
||||
Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
|
||||
Expr dom_scale = FoldConstantOpt(mul);
|
||||
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
|
||||
}
|
||||
CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
* Copyright (c) 2018 by Contributors
|
||||
* \file simplify_inference.cc
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/attrs/nn.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
|
@ -103,9 +103,6 @@ Expr SimplifyInference(const Expr& e) {
|
|||
return InferenceSimplifier().Mutate(e);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.simplify_inference")
|
||||
.set_body_typed(SimplifyInference);
|
||||
|
||||
namespace transform {
|
||||
|
||||
Pass SimplifyInference() {
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
*
|
||||
* \brief Turn implicit sharing into observable sharing.
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
*
|
||||
* \brief Turn A normal form into graph normal form.
|
||||
*/
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include "let_list.h"
|
||||
|
|
|
@ -42,7 +42,7 @@
|
|||
#include <tvm/relay/error.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/pattern_functor.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include "./pass_util.h"
|
||||
#include "type_solver.h"
|
||||
|
@ -813,11 +813,6 @@ Function InferType(const Function& func,
|
|||
return Downcast<Function>(func_ret);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.infer_type")
|
||||
.set_body_typed<Expr(const Expr&, const Module&)>([](const Expr& expr, const Module& mod_ref) {
|
||||
return InferType(expr, mod_ref);
|
||||
});
|
||||
|
||||
namespace transform {
|
||||
|
||||
Pass InferType() {
|
||||
|
|
|
@ -512,7 +512,7 @@ bool TypeSolver::Solve() {
|
|||
}
|
||||
|
||||
// Expose type solver only for debugging purposes.
|
||||
TVM_REGISTER_API("relay._ir_pass._test_type_solver")
|
||||
TVM_REGISTER_API("relay._analysis._test_type_solver")
|
||||
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) {
|
||||
using runtime::PackedFunc;
|
||||
using runtime::TypedPackedFunc;
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <tvm/relay/type.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/error.h>
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
*
|
||||
* \brief Utility functions for Relay.
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/pattern_functor.h>
|
||||
#include "pass_util.h"
|
||||
|
@ -274,10 +274,10 @@ tvm::Array<Var> AllVars(const Expr& expr) {
|
|||
return VarVisitor().All(expr);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.free_vars")
|
||||
TVM_REGISTER_API("relay._analysis.free_vars")
|
||||
.set_body_typed(FreeVars);
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.bound_vars")
|
||||
TVM_REGISTER_API("relay._analysis.bound_vars")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
NodeRef x = args[0];
|
||||
if (x.as_derived<ExprNode>()) {
|
||||
|
@ -287,10 +287,10 @@ TVM_REGISTER_API("relay._ir_pass.bound_vars")
|
|||
}
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.all_vars")
|
||||
TVM_REGISTER_API("relay._analysis.all_vars")
|
||||
.set_body_typed(AllVars);
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.free_type_vars")
|
||||
TVM_REGISTER_API("relay._analysis.free_type_vars")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
NodeRef x = args[0];
|
||||
Module mod = args[1];
|
||||
|
@ -301,7 +301,7 @@ TVM_REGISTER_API("relay._ir_pass.free_type_vars")
|
|||
}
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.bound_type_vars")
|
||||
TVM_REGISTER_API("relay._analysis.bound_type_vars")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
NodeRef x = args[0];
|
||||
Module mod = args[1];
|
||||
|
@ -312,7 +312,7 @@ TVM_REGISTER_API("relay._ir_pass.bound_type_vars")
|
|||
}
|
||||
});
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.all_type_vars")
|
||||
TVM_REGISTER_API("relay._analysis.all_type_vars")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
NodeRef x = args[0];
|
||||
Module mod = args[1];
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
* \file well_formed.cc
|
||||
* \brief check that expression is well formed.
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/pattern_functor.h>
|
||||
#include <unordered_set>
|
||||
|
@ -78,7 +78,7 @@ bool WellFormed(const Expr& e) {
|
|||
return WellFormedChecker().CheckWellFormed(e);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.well_formed")
|
||||
TVM_REGISTER_API("relay._analysis.well_formed")
|
||||
.set_body_typed(WellFormed);
|
||||
|
||||
} // namespace relay
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
#include <tvm/tvm.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <tvm/relay/type.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include <topi/generic/injective.h>
|
||||
#include <tvm/runtime/packed_func.h>
|
||||
#include <tvm/runtime/module.h>
|
||||
|
|
|
@ -21,7 +21,8 @@
|
|||
#include <tvm/tvm.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <tvm/relay/type.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
|
||||
TEST(Relay, SelfReference) {
|
||||
using namespace tvm;
|
||||
|
@ -32,10 +33,9 @@ TEST(Relay, SelfReference) {
|
|||
auto y = relay::VarNode::make("y", tensor_type);
|
||||
auto call = relay::CallNode::make(f, Array<relay::Expr>{ y });
|
||||
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
|
||||
auto empty_module =
|
||||
relay::ModuleNode::make(Map<relay::GlobalVar, relay::Function>{},
|
||||
Map<relay::GlobalTypeVar, relay::TypeData>{});
|
||||
auto type_fx = relay::InferType(fx, empty_module);
|
||||
auto mod = relay::ModuleNode::FromExpr(fx);
|
||||
mod = relay::transform::InferType()(mod);
|
||||
auto type_fx = mod->Lookup(mod->entry_func);
|
||||
|
||||
auto expected = relay::FuncTypeNode::make(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
|
||||
CHECK(AlphaEqual(type_fx->checked_type(), expected));
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
#include <tvm/packed_func_ext.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <tvm/relay/module.h>
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/analysis.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include <tvm/relay/type.h>
|
||||
#include <tvm/runtime/packed_func.h>
|
||||
|
@ -100,7 +100,9 @@ TEST(Relay, Sequential) {
|
|||
relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {});
|
||||
|
||||
// Infer type for the expected function.
|
||||
auto expected = relay::InferType(expected_func, relay::Module(nullptr));
|
||||
auto mod1 = relay::ModuleNode::FromExpr(expected_func);
|
||||
mod1 = relay::transform::InferType()(mod1);
|
||||
auto expected = mod1->Lookup(mod1->entry_func);
|
||||
CHECK(relay::AlphaEqual(f, expected));
|
||||
}
|
||||
|
||||
|
|
|
@ -95,7 +95,7 @@ def get_net(batch_size, image_shape, num_classes, dtype):
|
|||
net = relay.nn.relu(net)
|
||||
net = relay.nn.global_avg_pool2d(net)
|
||||
net = relay.nn.softmax(net, axis=1)
|
||||
args = relay.ir_pass.free_vars(net)
|
||||
args = relay.analysis.free_vars(net)
|
||||
return relay.Function(args, net)
|
||||
|
||||
|
||||
|
|
|
@ -16,13 +16,15 @@
|
|||
# under the License.
|
||||
"""Test graph equality of caffe2 models."""
|
||||
from tvm import relay
|
||||
from tvm.relay import transform
|
||||
from model_zoo import c2_squeezenet, relay_squeezenet
|
||||
|
||||
|
||||
def compare_graph(f1, f2):
|
||||
f1 = relay.ir_pass.infer_type(f1)
|
||||
f2 = relay.ir_pass.infer_type(f2)
|
||||
assert relay.ir_pass.alpha_equal(f1, f2)
|
||||
def compare_graph(lhs_mod, func):
|
||||
rhs_mod = relay.Module.from_expr(func)
|
||||
rhs_mod = transform.InferType()(rhs_mod)
|
||||
assert relay.analysis.alpha_equal(lhs_mod[lhs_mod.entry_func],
|
||||
rhs_mod[rhs_mod.entry_func])
|
||||
|
||||
|
||||
def test_squeeze_net():
|
||||
|
@ -31,7 +33,7 @@ def test_squeeze_net():
|
|||
mod, _, = relay.frontend.from_caffe2(
|
||||
c2_squeezenet.init_net, c2_squeezenet.predict_net, shape_dict, dtype_dict)
|
||||
relay_func, _ = relay_squeezenet()
|
||||
compare_graph(mod[mod.entry_func], relay_func)
|
||||
compare_graph(mod, relay_func)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -16,12 +16,11 @@
|
|||
# under the License.
|
||||
import mxnet as mx
|
||||
from tvm import relay
|
||||
from tvm.relay import transform
|
||||
import model_zoo
|
||||
|
||||
def compare_graph(f1, f2):
|
||||
f1 = relay.ir_pass.infer_type(f1)
|
||||
f2 = relay.ir_pass.infer_type(f2)
|
||||
assert relay.ir_pass.alpha_equal(f1, f2)
|
||||
assert relay.analysis.alpha_equal(f1, f2)
|
||||
|
||||
def test_mlp():
|
||||
shape = {"data": (1, 1, 28, 28)}
|
||||
|
@ -97,7 +96,10 @@ def test_multi_outputs():
|
|||
y = F.var("y", shape=yshape)
|
||||
z = F.split(x, **kwargs)
|
||||
z = F.subtract(F.add(z[0], z[2]), y)
|
||||
return relay.Function(relay.ir_pass.free_vars(z), z)
|
||||
func = relay.Function(relay.analysis.free_vars(z), z)
|
||||
mod = relay.Module.from_expr(func)
|
||||
mod = transform.InferType()(mod)
|
||||
return mod[mod.entry_func]
|
||||
|
||||
mx_sym = mx_compose(mx, num_outputs=3, axis=1)
|
||||
mod, _ = relay.frontend.from_mxnet(
|
||||
|
|
|
@ -20,7 +20,8 @@ import nnvm
|
|||
|
||||
from tvm import relay
|
||||
from tvm import autotvm
|
||||
from tvm.relay.ir_pass import infer_type, alpha_equal
|
||||
from tvm.relay import transform
|
||||
from tvm.relay.analysis import alpha_equal
|
||||
|
||||
|
||||
def test_alter_layout_conv2d():
|
||||
|
@ -57,12 +58,11 @@ def test_alter_layout_conv2d():
|
|||
n15 = relay.reshape(n14, newshape=[1, 1, 3, 3, 224, 224])
|
||||
n16 = relay.transpose(n15, axes=[0, 1, 4, 2, 5, 3])
|
||||
net = relay.reshape(n16, newshape=[1, 1, 672, 672])
|
||||
args = relay.ir_pass.free_vars(net)
|
||||
args = relay.analysis.free_vars(net)
|
||||
return relay.Function(args, net)
|
||||
|
||||
# orig net
|
||||
N = convnet()
|
||||
N = infer_type(N)
|
||||
|
||||
# trigger a test
|
||||
# for each known alter_conv2d
|
||||
|
@ -75,11 +75,12 @@ def test_alter_layout_conv2d():
|
|||
for tgt in targets:
|
||||
with tvm.target.create(tgt) as target:
|
||||
with autotvm.tophub.context(target):
|
||||
O = relay.ir_pass.alter_op_layout(N)
|
||||
O = relay.ir_pass.infer_type(O)
|
||||
mod = relay.Module.from_expr(N)
|
||||
mod = transform.AlterOpLayout()(mod)
|
||||
O = mod[mod.entry_func]
|
||||
|
||||
# graph should differ
|
||||
assert not relay.ir_pass.alpha_equal(N, O)
|
||||
assert not relay.analysis.alpha_equal(N, O)
|
||||
|
||||
if __name__ == "__main__":
|
||||
np.random.seed(42)
|
||||
|
|
|
@ -14,12 +14,10 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import numpy as np
|
||||
import tvm
|
||||
from tvm import relay
|
||||
from tvm.relay.ir_pass import infer_type
|
||||
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
|
||||
from tvm.relay import testing, create_executor
|
||||
from tvm.relay.backend.interpreter import ConstructorValue
|
||||
from tvm.relay import create_executor
|
||||
from tvm.relay.prelude import Prelude
|
||||
from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr
|
||||
|
||||
|
@ -125,8 +123,14 @@ def test_nat_value():
|
|||
|
||||
|
||||
def test_nat_constructor():
|
||||
assert relay.ir_pass.infer_type(z(), mod).checked_type == nat()
|
||||
assert relay.ir_pass.infer_type(s(z()), mod).checked_type == nat()
|
||||
func = relay.Function([], z())
|
||||
test_z = relay.GlobalVar("test_z")
|
||||
mod[test_z] = func
|
||||
assert mod[test_z].body.checked_type == nat()
|
||||
test_sz = relay.GlobalVar("test_sz")
|
||||
func = relay.Function([], s(z()))
|
||||
mod[test_sz] = func
|
||||
assert mod[test_sz].body.checked_type == nat()
|
||||
|
||||
|
||||
def test_double():
|
||||
|
@ -142,8 +146,10 @@ def test_add():
|
|||
|
||||
|
||||
def test_list_constructor():
|
||||
a = relay.TypeVar("a")
|
||||
assert relay.ir_pass.infer_type(cons(z(), nil()), mod).checked_type == l(nat())
|
||||
test_consz = relay.GlobalVar("test_consz")
|
||||
func = relay.Function([], cons(z(), nil()))
|
||||
mod[test_consz] = func
|
||||
assert mod[test_consz].body.checked_type == l(nat())
|
||||
|
||||
def test_hd_tl():
|
||||
expected = list(range(10))
|
||||
|
|
|
@ -26,8 +26,10 @@ def test_compile_engine():
|
|||
x = relay.var("x", shape=shape)
|
||||
y = relay.add(x, x)
|
||||
z = relay.add(y, x)
|
||||
f = relay.ir_pass.infer_type(relay.Function([x], z))
|
||||
return f
|
||||
f = relay.Function([x], z)
|
||||
mod = relay.Module.from_expr(f)
|
||||
mod = relay.transform.InferType()(mod)
|
||||
return mod[mod.entry_func]
|
||||
z1 = engine.lower(get_func((10,)), "llvm")
|
||||
z2 = engine.lower(get_func((10,)), "llvm")
|
||||
z3 = engine.lower(get_func(()), "llvm")
|
||||
|
@ -55,7 +57,7 @@ def test_compile_placeholder_bypass():
|
|||
y = relay.var("y", shape=(2, 3))
|
||||
z = relay.var("z", shape=(2, 3))
|
||||
result = relay.Tuple([x, relay.op.concatenate([y, z], axis=0)])
|
||||
func = relay.Function(relay.ir_pass.free_vars(result), result)
|
||||
func = relay.Function(relay.analysis.free_vars(result), result)
|
||||
with relay.build_config(opt_level=0):
|
||||
graph, lib, params = relay.build(relay.Module.from_expr(func), 'llvm')
|
||||
|
||||
|
|
|
@ -19,7 +19,6 @@ import numpy as np
|
|||
import tvm
|
||||
from tvm import relay
|
||||
from tvm.contrib import graph_runtime
|
||||
from tvm.relay.ir_pass import infer_type
|
||||
from tvm.relay.scope_builder import ScopeBuilder
|
||||
from tvm.relay.op import add
|
||||
from tvm.relay.module import Module
|
||||
|
@ -124,9 +123,9 @@ def test_plan_memory():
|
|||
z = relay.exp(z)
|
||||
z = relay.exp(z)
|
||||
func = relay.Function([x, y], z)
|
||||
func = relay.ir_pass.infer_type(func)
|
||||
func = relay.ir_pass.fuse_ops(func, opt_level=0)
|
||||
func = relay.ir_pass.infer_type(func)
|
||||
mod = relay.Module.from_expr(func)
|
||||
mod = relay.transform.FuseOps(0)(mod)
|
||||
func = mod[mod.entry_func]
|
||||
smap = relay.backend._backend.GraphPlanMemory(func)
|
||||
storage_ids = set()
|
||||
device_types = set()
|
||||
|
|
|
@ -227,7 +227,7 @@ def test_tuple_passing():
|
|||
gv = relay.GlobalVar('fn')
|
||||
mod[gv] = fn
|
||||
mod.entry_func = gv
|
||||
mod[gv] = relay.ir_pass.infer_type(mod[gv], mod=mod)
|
||||
mod = relay.transform.InferType()(mod)
|
||||
|
||||
ctx = tvm.cpu()
|
||||
target = tvm.target.create('llvm')
|
||||
|
|
|
@ -19,7 +19,10 @@ from tvm import relay
|
|||
|
||||
def check_type_err(expr, msg):
|
||||
try:
|
||||
expr = relay.ir_pass.infer_type(expr)
|
||||
mod = relay.Module.from_expr(expr)
|
||||
mod = relay.transform.InferType()(mod)
|
||||
entry = mod[mod.entry_func]
|
||||
expr = entry if isinstance(expr, relay.Function) else entry.body
|
||||
assert False
|
||||
except tvm.TVMError as err:
|
||||
assert msg in str(err)
|
||||
|
|
|
@ -17,7 +17,8 @@
|
|||
|
||||
import tvm
|
||||
from tvm import relay
|
||||
from tvm.relay.ir_pass import detect_feature, gradient
|
||||
from tvm.relay.analysis import detect_feature
|
||||
from tvm.relay.transform import gradient
|
||||
from tvm.relay.feature import Feature
|
||||
from tvm.relay.prelude import Prelude
|
||||
|
||||
|
@ -46,7 +47,9 @@ def test_ad():
|
|||
t = relay.TensorType(shape, dtype)
|
||||
x = relay.var("x", t)
|
||||
func = relay.Function([x], x + x)
|
||||
back_func = relay.ir_pass.infer_type(gradient(func))
|
||||
mod = relay.Module.from_expr(gradient(func))
|
||||
mod = relay.transform.InferType()(mod)
|
||||
back_func = mod[mod.entry_func]
|
||||
feats = detect_feature(back_func)
|
||||
assert feats == set([
|
||||
Feature.fVar,
|
||||
|
|
|
@ -28,11 +28,11 @@ def test_bind_params():
|
|||
fexpected =relay.Function(
|
||||
[y],
|
||||
relay.add(relay.const(1, "float32"), y))
|
||||
assert relay.ir_pass.alpha_equal(fbinded, fexpected)
|
||||
assert relay.analysis.alpha_equal(fbinded, fexpected)
|
||||
|
||||
zbinded = relay.bind(z, {y: x})
|
||||
zexpected = relay.add(x, x)
|
||||
assert relay.ir_pass.alpha_equal(zbinded, zexpected)
|
||||
assert relay.analysis.alpha_equal(zbinded, zexpected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -19,7 +19,7 @@ import tvm
|
|||
from tvm import relay
|
||||
from tvm.expr import *
|
||||
from tvm.relay import op
|
||||
from tvm.relay.ir_pass import graph_equal
|
||||
from tvm.relay.analysis import graph_equal
|
||||
|
||||
|
||||
def check_json_roundtrip(node):
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
# under the License.
|
||||
import tvm
|
||||
from tvm import relay
|
||||
from tvm.relay.ir_pass import alpha_equal
|
||||
from tvm.relay.analysis import alpha_equal
|
||||
from nose.tools import nottest, raises
|
||||
from numpy import isclose
|
||||
from typing import Union
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
# under the License.
|
||||
import tvm
|
||||
from tvm import relay
|
||||
from tvm.relay.ir_pass import well_formed
|
||||
from tvm.relay.analysis import well_formed
|
||||
from tvm.relay.prelude import Prelude
|
||||
|
||||
def test_let():
|
||||
|
|
|
@ -14,16 +14,24 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import tvm
|
||||
import numpy as np
|
||||
import tvm
|
||||
from tvm import relay
|
||||
from tvm.relay.ir_pass import gradient, infer_type
|
||||
from tvm.relay.transform import gradient
|
||||
from tvm.relay.testing import ctx_list
|
||||
|
||||
|
||||
def run_infer_type(expr):
|
||||
mod = relay.Module.from_expr(expr)
|
||||
mod = relay.transform.InferType()(mod)
|
||||
return mod[mod.entry_func]
|
||||
|
||||
|
||||
def sigmoid(x):
|
||||
one = np.ones_like(x)
|
||||
return one / (one + np.exp(-x))
|
||||
|
||||
|
||||
def relu(x):
|
||||
x_copy = np.copy(x)
|
||||
np.maximum(x_copy, 0, x_copy)
|
||||
|
@ -41,7 +49,7 @@ def test_unary_op():
|
|||
data = np.random.rand(*shape).astype(dtype)
|
||||
ref_grad = ref(data)
|
||||
fwd_func = relay.Function([x], y)
|
||||
bwd_func = infer_type(gradient(fwd_func))
|
||||
bwd_func = run_infer_type(gradient(fwd_func))
|
||||
|
||||
for target, ctx in ctx_list():
|
||||
intrp = relay.create_executor(ctx=ctx, target=target)
|
||||
|
@ -73,7 +81,7 @@ def test_binary_op():
|
|||
y_data = np.random.rand(*s).astype(t.dtype)
|
||||
ref_grad0, ref_grad1 = ref(x_data, y_data)
|
||||
fwd_func = relay.Function([x, y], z)
|
||||
bwd_func = infer_type(gradient(fwd_func))
|
||||
bwd_func = run_infer_type(gradient(fwd_func))
|
||||
|
||||
for target, ctx in ctx_list():
|
||||
intrp = relay.create_executor(ctx=ctx, target=target)
|
||||
|
|
|
@ -14,13 +14,19 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import math
|
||||
import tvm
|
||||
import numpy as np
|
||||
import tvm
|
||||
from tvm import relay
|
||||
from tvm.relay import transform
|
||||
from tvm.relay.testing import ctx_list
|
||||
import topi.testing
|
||||
|
||||
def run_infer_type(expr):
|
||||
mod = relay.Module.from_expr(expr)
|
||||
mod = transform.InferType()(mod)
|
||||
entry = mod[mod.entry_func]
|
||||
return entry if isinstance(expr, relay.Function) else entry.body
|
||||
|
||||
def sigmoid(x):
|
||||
one = np.ones_like(x)
|
||||
return one / (one + np.exp(-x))
|
||||
|
@ -44,7 +50,8 @@ def test_unary_op():
|
|||
# test printer
|
||||
assert ("{}(%x)".format(y.op.name)) in y.astext()
|
||||
# test type inference
|
||||
assert relay.ir_pass.infer_type(y).checked_type == tp
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == tp
|
||||
|
||||
if ref is not None:
|
||||
data = np.random.rand(*shape).astype(dtype)
|
||||
|
@ -84,7 +91,8 @@ def test_binary_op():
|
|||
z = opfunc(x, y)
|
||||
# test printer
|
||||
assert ("{}(%x, %y)".format(z.op.name)) in z.astext()
|
||||
assert relay.ir_pass.infer_type(z).checked_type == t1
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == t1
|
||||
|
||||
if ref is not None:
|
||||
t1 = relay.TensorType((5, 10, 5))
|
||||
|
@ -134,7 +142,7 @@ def test_bias_add():
|
|||
x = relay.var("x", shape=xshape)
|
||||
bias = relay.var("bias")
|
||||
z = relay.nn.bias_add(x, bias)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert "axis=" not in zz.astext()
|
||||
assert zz.args[1].checked_type == relay.TensorType(bshape)
|
||||
|
||||
|
@ -153,8 +161,8 @@ def test_expand_dims_infer_type():
|
|||
x = relay.var("x", shape=(n, t, d))
|
||||
y = relay.expand_dims(x, axis=2)
|
||||
assert "axis=2" in y.astext()
|
||||
checked = relay.ir_pass.infer_type(y)
|
||||
assert checked.checked_type == relay.TensorType((n, t, 1, 100))
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((n, t, 1, 100))
|
||||
|
||||
|
||||
def test_softmax():
|
||||
|
@ -162,7 +170,7 @@ def test_softmax():
|
|||
x = relay.var("x", shape=shape)
|
||||
y = relay.nn.softmax(x, axis=1)
|
||||
assert "nn.softmax" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(shape)
|
||||
func = relay.Function([x], y)
|
||||
x_data = np.random.uniform(size=shape).astype("float32")
|
||||
|
@ -178,7 +186,7 @@ def test_log_softmax():
|
|||
x = relay.var("x", shape=shape)
|
||||
y = relay.nn.log_softmax(x, axis=1)
|
||||
assert "nn.log_softmax" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(shape)
|
||||
func = relay.Function([x], y)
|
||||
x_data = np.random.uniform(size=shape).astype("float32")
|
||||
|
@ -195,16 +203,16 @@ def test_concatenate():
|
|||
y = relay.var("y", shape=(n, t, d))
|
||||
z = relay.concatenate((x, y), axis=-1)
|
||||
assert "axis=" in z.astext()
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.TensorType((n, t, 200))
|
||||
|
||||
x = relay.exp(x)
|
||||
z = relay.concatenate((x, y), axis=2)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.TensorType((n, t, 200))
|
||||
|
||||
z = relay.concatenate((x, y), axis=1)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.TensorType((n, t + t, 100))
|
||||
|
||||
x = relay.var("x", shape=(10, 5))
|
||||
|
@ -233,7 +241,7 @@ def test_dropout():
|
|||
x = relay.var("x", input_ty)
|
||||
y = relay.nn.dropout(x, rate=0.75)
|
||||
assert "rate=" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == input_ty
|
||||
|
||||
|
||||
|
@ -246,7 +254,7 @@ def test_batch_norm():
|
|||
moving_var = relay.var("moving_var", relay.TensorType((2,)))
|
||||
y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
|
||||
center=False, scale=False)
|
||||
yy = relay.ir_pass.infer_type(y.astuple())
|
||||
yy = run_infer_type(y.astuple())
|
||||
assert "center=" in yy.astext()
|
||||
assert yy.checked_type == relay.ty.TupleType(tvm.convert([
|
||||
relay.TensorType((3, 2, 1), "float32"),
|
||||
|
@ -261,7 +269,7 @@ def test_batch_norm():
|
|||
|
||||
y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
|
||||
axis=0, center=False, scale=False)
|
||||
yy = relay.ir_pass.infer_type(y.astuple())
|
||||
yy = run_infer_type(y.astuple())
|
||||
assert yy.checked_type == relay.ty.TupleType(tvm.convert([
|
||||
relay.ty.TensorType((3, 2, 1), "float32"),
|
||||
relay.ty.TensorType((3,), "float32"),
|
||||
|
@ -276,7 +284,7 @@ def test_batch_norm():
|
|||
moving_var = relay.var("moving_var", relay.TensorType((3,)))
|
||||
y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var,
|
||||
axis=-1, center=False, scale=False)
|
||||
yy = relay.ir_pass.infer_type(y.astuple())
|
||||
yy = run_infer_type(y.astuple())
|
||||
assert yy.checked_type == relay.ty.TupleType(tvm.convert([
|
||||
relay.ty.TensorType((1, 2, 3), "float32"),
|
||||
relay.ty.TensorType((3,), "float32"),
|
||||
|
@ -290,7 +298,7 @@ def test_dense():
|
|||
w = relay.var("w", relay.TensorType((2, w), "float32"))
|
||||
y = relay.nn.dense(x, w, units=2)
|
||||
"units=2" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32")
|
||||
|
||||
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2
|
||||
|
@ -298,14 +306,14 @@ def test_dense():
|
|||
wh, ww = tvm.var("wh"), tvm.var("ww")
|
||||
w = relay.var("w", relay.TensorType((ww, wh), "float32"))
|
||||
y = relay.nn.dense(x, w)
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((n, c, h, ww), "float32")
|
||||
|
||||
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), 2
|
||||
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
|
||||
w = relay.var("w", relay.IncompleteType())
|
||||
y = relay.nn.dense(x, w, units=2)
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((n, c, h, 2), "float32")
|
||||
|
||||
x = relay.var("x", shape=(10, 5))
|
||||
|
|
|
@ -20,10 +20,17 @@ import numpy as np
|
|||
import tvm
|
||||
import topi.testing
|
||||
from tvm import relay
|
||||
from tvm.relay import transform
|
||||
from tvm.relay.testing import ctx_list
|
||||
import topi
|
||||
import topi.testing
|
||||
|
||||
def run_infer_type(expr):
|
||||
mod = relay.Module.from_expr(expr)
|
||||
mod = transform.InferType()(mod)
|
||||
entry = mod[mod.entry_func]
|
||||
return entry if isinstance(expr, relay.Function) else entry.body
|
||||
|
||||
def test_collapse_sum_like():
|
||||
shape = (3, 4, 5, 6)
|
||||
shape_like = (4, 5, 6)
|
||||
|
@ -31,7 +38,7 @@ def test_collapse_sum_like():
|
|||
x = relay.Var("x", relay.ty.TensorType(shape , dtype))
|
||||
y = relay.Var("y", relay.ty.TensorType(shape_like, dtype))
|
||||
z = relay.collapse_sum_like(x, y)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.ty.TensorType(shape_like, dtype)
|
||||
|
||||
func = relay.Function([x, y], z)
|
||||
|
@ -50,7 +57,7 @@ def test_broadcast_to():
|
|||
dtype = "float32"
|
||||
x = relay.Var("x", relay.ty.TensorType(shape , dtype))
|
||||
z = relay.broadcast_to(x, shape=shape_like)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.ty.TensorType(shape_like, dtype)
|
||||
|
||||
func = relay.Function([x], z)
|
||||
|
@ -69,7 +76,7 @@ def test_broadcast_to_like():
|
|||
x = relay.Var("x", relay.ty.TensorType(shape , dtype))
|
||||
y = relay.Var("y", relay.ty.TensorType(shape_like, dtype))
|
||||
z = relay.broadcast_to_like(x, y)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.ty.TensorType(shape_like, dtype)
|
||||
|
||||
func = relay.Function([x, y], z)
|
||||
|
@ -106,7 +113,7 @@ def verify_slice_like(data, slice_like, axes, output, dtype="float32"):
|
|||
x = relay.var("data", relay.TensorType(data, dtype))
|
||||
y = relay.var("slice_like", relay.TensorType(slice_like, dtype))
|
||||
z = relay.slice_like(x, y, axes)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
if axes:
|
||||
assert "axes" in z.astext()
|
||||
assert zz.checked_type == relay.ty.TensorType(output, dtype)
|
||||
|
@ -144,7 +151,7 @@ def test_reverse_reshape():
|
|||
def verify_reverse_reshape(shape, newshape, oshape):
|
||||
x = relay.var("x", relay.TensorType(shape, "float32"))
|
||||
z = relay.reverse_reshape(x, newshape=newshape)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert "newshape=" in z.astext()
|
||||
assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
|
||||
|
||||
|
@ -166,7 +173,7 @@ def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"):
|
|||
x = relay.var("x", relay.TensorType(x_shape, dtype))
|
||||
y = relay.var("y", relay.TensorType(y_shape, dtype))
|
||||
z = relay.nn.batch_matmul(x, y)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.ty.TensorType(out_shape, dtype)
|
||||
|
||||
func = relay.Function([x, y], z)
|
||||
|
@ -185,7 +192,7 @@ def test_batch_matmul():
|
|||
x = relay.var("x", relay.TensorType((b, m, k), "float32"))
|
||||
y = relay.var("y", relay.TensorType((b, n, k), "float32"))
|
||||
z = relay.nn.batch_matmul(x, y)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.TensorType((b, m, n), "float32")
|
||||
|
||||
verify_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16))
|
||||
|
@ -197,7 +204,7 @@ def test_shape_of():
|
|||
shape = (10, 5, 12)
|
||||
x = relay.var("x", shape=shape)
|
||||
func = relay.Function([x], relay.op.shape_of(x))
|
||||
func = relay.ir_pass.infer_type(func)
|
||||
func = run_infer_type(func)
|
||||
x_data = np.random.rand(*shape).astype('float32')
|
||||
for target, ctx in ctx_list():
|
||||
# Because using graph executor, this op will be optimized after
|
||||
|
@ -256,7 +263,8 @@ def test_sequence_mask():
|
|||
data = relay.var("data", relay.TensorType(data_shape, dtype))
|
||||
valid_length = relay.var("valid_length", relay.TensorType((nbatch,), itype))
|
||||
out = relay.sequence_mask(data, valid_length, mask_value, axis)
|
||||
assert relay.ir_pass.infer_type(out).checked_type == relay.ty.TensorType(data_shape, dtype)
|
||||
checked = run_infer_type(out)
|
||||
assert checked.checked_type == relay.ty.TensorType(data_shape, dtype)
|
||||
func = relay.Function([data, valid_length], out)
|
||||
data_np = np.random.uniform(size=data_shape).astype(dtype)
|
||||
valid_length_np = np.random.randint(0, max_length, size=nbatch).astype(itype)
|
||||
|
|
|
@ -16,12 +16,19 @@
|
|||
# under the License.
|
||||
""" Support level2 operator test cases.
|
||||
"""
|
||||
import numpy as np
|
||||
import tvm
|
||||
from tvm import relay
|
||||
from tvm.relay import transform
|
||||
from tvm.relay.testing import ctx_list
|
||||
import numpy as np
|
||||
import topi.testing
|
||||
|
||||
def run_infer_type(expr):
|
||||
mod = relay.Module.from_expr(expr)
|
||||
mod = transform.InferType()(mod)
|
||||
entry = mod[mod.entry_func]
|
||||
return entry if isinstance(expr, relay.Function) else entry.body
|
||||
|
||||
def test_conv2d_infer_type():
|
||||
# symbolic in batch dimension
|
||||
n, c, h, w = tvm.var("n"), 10, 224, 224
|
||||
|
@ -31,7 +38,7 @@ def test_conv2d_infer_type():
|
|||
kernel_size=(3, 3),
|
||||
padding=(1, 1),
|
||||
channels=2)
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(
|
||||
(n, 2, 224, 224), "float32")
|
||||
assert yy.args[1].checked_type == relay.TensorType(
|
||||
|
@ -44,7 +51,7 @@ def test_conv2d_infer_type():
|
|||
w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8"))
|
||||
y = relay.nn.conv2d(x, w, out_dtype="int32")
|
||||
assert "out_dtype=\"int32\"" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(
|
||||
(n, 2, 222, 222), "int32")
|
||||
|
||||
|
@ -59,7 +66,7 @@ def test_conv2d_infer_type():
|
|||
data_layout="NCHW4n4c",
|
||||
kernel_layout="OIHW4o4i",
|
||||
out_dtype="int32")
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(
|
||||
(1, 4, 224, 224, 4, 4), "int32")
|
||||
assert yy.args[1].checked_type == relay.TensorType(
|
||||
|
@ -75,7 +82,7 @@ def test_conv2d_infer_type():
|
|||
channels=16,
|
||||
data_layout="NHWC",
|
||||
out_dtype="int32")
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(
|
||||
(n, h, w, 16), "int32")
|
||||
|
||||
|
@ -169,7 +176,7 @@ def test_conv2d_transpose_infer_type():
|
|||
padding=(1, 1),
|
||||
channels=15)
|
||||
assert "channels=15" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(
|
||||
(n, 15, 10, 12), "float32")
|
||||
assert yy.args[1].checked_type == relay.TensorType(
|
||||
|
@ -183,7 +190,7 @@ def test_conv2d_transpose_infer_type():
|
|||
output_padding=(1, 1),
|
||||
channels=11,
|
||||
data_layout="NHWC")
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(
|
||||
(n, 15, 15, 11), "float32")
|
||||
|
||||
|
@ -219,12 +226,12 @@ def test_upsampling_infer_type():
|
|||
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
|
||||
y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR")
|
||||
"method=\"BINLINEAR\"" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((n, c, h*2, w*2), "float32")
|
||||
n, c = tvm.var("n"), tvm.var("c")
|
||||
x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32"))
|
||||
y = relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR")
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32")
|
||||
|
||||
|
||||
|
@ -233,7 +240,7 @@ def _test_pool2d(opfunc, reffunc):
|
|||
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
|
||||
y = opfunc(x, pool_size=(1, 1))
|
||||
assert "pool_size=" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((n, 10, 224, 224), "float32")
|
||||
# test execution
|
||||
dtype = "float32"
|
||||
|
@ -253,13 +260,13 @@ def _test_global_pool2d(opfunc, reffunc):
|
|||
n, c, h, w = tvm.var("n"), tvm.var("c"), 224, 224
|
||||
x = relay.var("x", relay.TensorType((n, h, w, c), "float32"))
|
||||
y = opfunc(x, layout="NHWC")
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((n, 1, 1, c), "float32")
|
||||
|
||||
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
|
||||
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
|
||||
y = opfunc(x)
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((n, c, 1, 1), "float32")
|
||||
# test execution
|
||||
dtype = "float32"
|
||||
|
@ -320,17 +327,17 @@ def test_flatten_infer_type():
|
|||
d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
|
||||
x = relay.var("x", relay.TensorType((d1, d2, d3, d4), "float32"))
|
||||
y = relay.nn.batch_flatten(x)
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((d1, ((d2*d3)*d4)), "float32")
|
||||
|
||||
x = relay.var("x", relay.TensorType((3, 2, 4, 3), "float32"))
|
||||
y = relay.nn.batch_flatten(x)
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((3, 24), "float32")
|
||||
|
||||
x = relay.var("x", relay.TensorType((d1, 2, d3, 3), "float32"))
|
||||
y = relay.nn.batch_flatten(x)
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((d1, ((2*d3)*3)), "float32")
|
||||
|
||||
shape = (1, 5, 10, 10)
|
||||
|
@ -338,7 +345,7 @@ def test_flatten_infer_type():
|
|||
dtype = "float32"
|
||||
x = relay.var("x", relay.TensorType(shape, dtype))
|
||||
z = relay.nn.batch_flatten(x)
|
||||
yy = relay.ir_pass.infer_type(z)
|
||||
yy = run_infer_type(z)
|
||||
assert yy.checked_type == relay.TensorType(o_shape, dtype)
|
||||
func = relay.Function([x], z)
|
||||
x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
|
||||
|
@ -358,14 +365,14 @@ def test_pad_infer_type():
|
|||
t = relay.var("t", relay.TensorType((n, c, h, w), "float32"))
|
||||
y = relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4)))
|
||||
"pad_width=" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((3, 6, 9, 12), "float32")
|
||||
|
||||
# some symbolic values
|
||||
n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w")
|
||||
t = relay.var("t", relay.TensorType((n, c, h, w), "float32"))
|
||||
y = relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4)))
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((n + 2, 6, 9, w + 8), "float32")
|
||||
|
||||
def test_pad_run():
|
||||
|
@ -389,7 +396,7 @@ def test_lrn():
|
|||
x = relay.var("x", shape=(n, c , h, w))
|
||||
y = relay.nn.lrn(x, size=10, axis=2, bias=0.5, alpha=.00001, beta=0.75)
|
||||
"alpha=" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((n, c , h, w))
|
||||
|
||||
shape = (1, 5, 10, 10)
|
||||
|
@ -401,7 +408,7 @@ def test_lrn():
|
|||
alpha=.00001
|
||||
beta=0.75
|
||||
z = relay.nn.lrn(x, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta)
|
||||
yy = relay.ir_pass.infer_type(z)
|
||||
yy = run_infer_type(z)
|
||||
assert yy.checked_type == relay.TensorType(shape, dtype)
|
||||
func = relay.Function([x], z)
|
||||
x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
|
||||
|
@ -420,7 +427,7 @@ def test_l2_normalize():
|
|||
x = relay.var("x", shape=(n, c , h, w))
|
||||
y = relay.nn.l2_normalize(x, eps=0.001, axis=[1])
|
||||
"axis=" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((n, c , h, w))
|
||||
|
||||
shape = (1, 5, 10, 10)
|
||||
|
@ -429,7 +436,7 @@ def test_l2_normalize():
|
|||
eps=0.001
|
||||
axis=1
|
||||
z = relay.nn.l2_normalize(x, eps=0.001, axis=[axis])
|
||||
yy = relay.ir_pass.infer_type(z)
|
||||
yy = run_infer_type(z)
|
||||
assert yy.checked_type == relay.TensorType(shape, dtype)
|
||||
func = relay.Function([x], z)
|
||||
x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
|
||||
|
@ -477,7 +484,7 @@ def _test_upsampling(layout, method):
|
|||
ishape, oshape = get_shape()
|
||||
x = relay.var("x", relay.TensorType((n,) + ishape, dtype))
|
||||
y = relay.nn.upsampling(x, scale=scale, layout=layout, method=method)
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((n,) + oshape, dtype)
|
||||
dshape = (1,) + ishape
|
||||
x = relay.var("x", shape=dshape)
|
||||
|
|
|
@ -16,17 +16,23 @@
|
|||
# under the License.
|
||||
""" Support level3 operator test cases.
|
||||
"""
|
||||
import tvm
|
||||
import numpy as np
|
||||
from tvm import relay
|
||||
from tvm.relay import create_executor
|
||||
from tvm.relay.testing import ctx_list
|
||||
from nose.tools import raises
|
||||
import tvm
|
||||
from tvm import relay
|
||||
from tvm.relay import create_executor, transform
|
||||
from tvm.relay.testing import ctx_list
|
||||
|
||||
def run_infer_type(expr):
|
||||
mod = relay.Module.from_expr(expr)
|
||||
mod = transform.InferType()(mod)
|
||||
entry = mod[mod.entry_func]
|
||||
return entry if isinstance(expr, relay.Function) else entry.body
|
||||
|
||||
def test_zeros_ones():
|
||||
for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]:
|
||||
y = op(shape=(124, 50), dtype="float64")
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((124, 50), "float64")
|
||||
intrp = create_executor()
|
||||
intrp_res = intrp.evaluate(y).asnumpy()
|
||||
|
@ -46,7 +52,7 @@ def test_unary_identity():
|
|||
shape = (8, 9, 4)
|
||||
x = relay.var("x", relay.TensorType(shape, "float32"))
|
||||
y = op(x)
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(shape, "float32")
|
||||
|
||||
if ref is not None:
|
||||
|
@ -59,20 +65,20 @@ def test_unary_identity():
|
|||
def test_cast():
|
||||
x = relay.var("x", relay.TensorType((8, 9, 4), "float32"))
|
||||
y = x.astype("int32")
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert "dtype=" in yy.astext()
|
||||
assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")
|
||||
|
||||
x = relay.var("x", relay.TensorType((8, 9, 4), "float32"))
|
||||
y = relay.cast(x, "int32")
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert "dtype=" in yy.astext()
|
||||
assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")
|
||||
|
||||
def test_clip():
|
||||
a = relay.var("a", relay.TensorType((10, 4), "float32"))
|
||||
y = relay.clip(a, 1., 4.)
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((10, 4), "float32")
|
||||
|
||||
data = np.random.rand(10, 4).astype('float32')
|
||||
|
@ -105,13 +111,13 @@ def test_transpose_infer_type():
|
|||
x = relay.var("x", relay.TensorType((n, t, d), "float32"))
|
||||
y = relay.transpose(x, axes=(1, 0, 2))
|
||||
assert "axes=" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(
|
||||
(t, n, 100), "float32")
|
||||
|
||||
y = relay.transpose(x)
|
||||
assert "axes=" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(
|
||||
(100, t, n), "float32")
|
||||
|
||||
|
@ -138,7 +144,7 @@ def test_squeeze_infer_type():
|
|||
x = relay.var("x", relay.TensorType((n, t, d), "float32"))
|
||||
y = relay.squeeze(x, axis=(2,))
|
||||
assert "axis=" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(
|
||||
(1, 4), "float32")
|
||||
|
||||
|
@ -146,7 +152,7 @@ def test_squeeze_infer_type():
|
|||
x = relay.var("x", relay.TensorType((n, t, d), "float32"))
|
||||
y = relay.squeeze(x)
|
||||
assert "axis=" not in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(
|
||||
(4,), "float32")
|
||||
|
||||
|
@ -156,7 +162,7 @@ def test_squeeze_bad_axes_infer_type():
|
|||
n, t, d = 1, 4, 1
|
||||
x = relay.var("x", relay.TensorType((n, t, d), "float32"))
|
||||
y = relay.squeeze(x, axis=(1,))
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
|
||||
|
||||
def test_reshape_infer_type():
|
||||
|
@ -164,7 +170,7 @@ def test_reshape_infer_type():
|
|||
x = relay.var("x", relay.TensorType((n, t, d1, d2), "float32"))
|
||||
y = relay.reshape(x, newshape=(n, t, 2000))
|
||||
assert "newshape=" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(
|
||||
(n, t, 2000), "float32")
|
||||
|
||||
|
@ -172,7 +178,7 @@ def test_reshape():
|
|||
def verify_reshape(shape, newshape, oshape):
|
||||
x = relay.var("x", relay.TensorType(shape, "float32"))
|
||||
z = relay.reshape(x, newshape=newshape)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert "newshape=" in z.astext()
|
||||
assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
|
||||
|
||||
|
@ -205,7 +211,7 @@ def test_reshape_like_infer_type():
|
|||
x = relay.var("x", relay.TensorType((1, 2, 3), "float32"))
|
||||
y = relay.var("y", relay.TensorType((1,6), "float32"))
|
||||
z = relay.reshape_like(x, y)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.TensorType((1, 6), "float32")
|
||||
|
||||
# symbolic shape
|
||||
|
@ -213,7 +219,7 @@ def test_reshape_like_infer_type():
|
|||
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
|
||||
y = relay.var("y", relay.TensorType((1, 8, 8), "float32"))
|
||||
z = relay.reshape_like(x, y)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.TensorType((1, 8, 8), "float32")
|
||||
|
||||
|
||||
|
@ -226,7 +232,7 @@ def test_reshape_like():
|
|||
x = relay.var("x", relay.TensorType(shape, "float32"))
|
||||
y = relay.var("x", relay.TensorType(oshape, "float32"))
|
||||
z = relay.reshape_like(x, y)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32")
|
||||
|
||||
func = relay.Function([x, y], z)
|
||||
|
@ -245,8 +251,7 @@ def test_take_infer_type():
|
|||
x = relay.var("x", relay.TensorType(dshape, "float32"))
|
||||
indices = relay.var("indices", relay.TensorType(indices_shape, "int32"))
|
||||
y = relay.take(x, indices, axis=axis)
|
||||
y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(oshape, "float32")
|
||||
|
||||
d1, d2, d3 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3")
|
||||
|
@ -301,8 +306,7 @@ def test_split_infer_type():
|
|||
def verify_split(dshape, indices_or_sections, ret_type, axis=None):
|
||||
x = relay.var("x", relay.ty.TensorType(dshape, "float32"))
|
||||
y = relay.split(x, indices_or_sections, axis=axis)
|
||||
y.astext()
|
||||
yy = relay.ir_pass.infer_type(y.astuple())
|
||||
yy = run_infer_type(y.astuple())
|
||||
assert yy.checked_type == ret_type
|
||||
|
||||
d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
|
||||
|
@ -347,14 +351,14 @@ def test_full_infer_type():
|
|||
# default settings: match input dtype
|
||||
x = relay.var("x", relay.TensorType((), "int8"))
|
||||
y = relay.full(x, ())
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((), "int8")
|
||||
|
||||
# change the shape and dtype
|
||||
x = relay.var("x", relay.TensorType((), "float32"))
|
||||
y = relay.full(x, (1, 2), "int8")
|
||||
"shape=" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((1, 2), "int8")
|
||||
|
||||
|
||||
|
@ -378,7 +382,7 @@ def test_full_like_infer_type():
|
|||
base = relay.var("base", relay.TensorType((1, 2, 3), "float32"))
|
||||
fill = relay.var("fill", relay.TensorType((), "float32"))
|
||||
y = relay.full_like(base, fill)
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((1, 2, 3), "float32")
|
||||
|
||||
# symbolic shape
|
||||
|
@ -386,7 +390,7 @@ def test_full_like_infer_type():
|
|||
base = relay.var("base", relay.TensorType((n, c, h, w), "float32"))
|
||||
fill = relay.var("fill", relay.TensorType((), "float32"))
|
||||
y = relay.full_like(base, fill)
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")
|
||||
|
||||
|
||||
|
@ -414,7 +418,7 @@ def test_infer_type_leaky_relu():
|
|||
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
|
||||
y = relay.nn.leaky_relu(x, alpha=0.1)
|
||||
"alpha=0.1" in y.astext()
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")
|
||||
|
||||
shape = (1, 5, 10, 10)
|
||||
|
@ -422,8 +426,8 @@ def test_infer_type_leaky_relu():
|
|||
x = relay.var("x", relay.TensorType(shape, dtype))
|
||||
z = relay.nn.leaky_relu(x, alpha=0.1)
|
||||
assert "alpha=0.1" in z.astext()
|
||||
yy = relay.ir_pass.infer_type(z)
|
||||
assert yy.checked_type == relay.TensorType(shape, dtype)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.TensorType(shape, dtype)
|
||||
func = relay.Function([x], z)
|
||||
x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
|
||||
ref_res = np.where(x_data > 0, x_data, x_data * 0.1)
|
||||
|
@ -443,7 +447,7 @@ def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"):
|
|||
else:
|
||||
y = relay.var("alpha", relay.IncompleteType())
|
||||
z = relay.nn.prelu(x, y, axis=axis)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
if axis != 1:
|
||||
assert "axis" in z.astext()
|
||||
assert zz.checked_type == relay.ty.TensorType(output, dtype)
|
||||
|
@ -577,7 +581,7 @@ def test_reverse():
|
|||
def verify_reverse(dshape, axis):
|
||||
x = relay.var("x", relay.TensorType(dshape, "float32"))
|
||||
z = relay.reverse(x, axis=axis)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
|
||||
func = relay.Function([x], z)
|
||||
x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
|
||||
|
|
|
@ -17,9 +17,16 @@
|
|||
import tvm
|
||||
import numpy as np
|
||||
from tvm import relay
|
||||
from tvm.relay import transform
|
||||
from tvm.relay.testing import ctx_list
|
||||
import topi.testing
|
||||
|
||||
def run_infer_type(expr):
|
||||
mod = relay.Module.from_expr(expr)
|
||||
mod = transform.InferType()(mod)
|
||||
entry = mod[mod.entry_func]
|
||||
return entry if isinstance(expr, relay.Function) else entry.body
|
||||
|
||||
def test_binary_op():
|
||||
def check_binary_op(opfunc, ref):
|
||||
n = tvm.var("n")
|
||||
|
@ -30,7 +37,8 @@ def test_binary_op():
|
|||
z = opfunc(x, y)
|
||||
# test printer
|
||||
assert ("{}(%x, %y)".format(z.op.name)) in z.astext()
|
||||
assert relay.ir_pass.infer_type(z).checked_type == t1
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == t1
|
||||
|
||||
if ref is not None:
|
||||
t1 = relay.TensorType((5, 10, 5))
|
||||
|
@ -62,8 +70,7 @@ def test_cmp_type():
|
|||
x = relay.var("x", relay.TensorType((10, 4), "float32"))
|
||||
y = relay.var("y", relay.TensorType((5, 10, 1), "float32"))
|
||||
z = op(x, y)
|
||||
z.astext()
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.TensorType((5, 10, 4), "bool")
|
||||
|
||||
if ref is not None:
|
||||
|
@ -94,7 +101,7 @@ def test_binary_int_broadcast():
|
|||
x = relay.var("x", relay.TensorType((10, 4), "int32"))
|
||||
y = relay.var("y", relay.TensorType((5, 10, 1), "int32"))
|
||||
z = op(x, y)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.TensorType((5, 10, 4), "int32")
|
||||
|
||||
if ref is not None:
|
||||
|
@ -120,7 +127,7 @@ def test_where():
|
|||
x = relay.var("x", relay.TensorType(shape, dtype))
|
||||
y = relay.var("y", relay.TensorType(shape, dtype))
|
||||
z = relay.where(cond, x, y)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.TensorType(shape, dtype)
|
||||
|
||||
func = relay.Function([cond, x, y], z)
|
||||
|
@ -142,7 +149,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
|
|||
|
||||
x = relay.var("x", relay.TensorType(data, dtype))
|
||||
z = test_func(x, axis, keepdims, exclude)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
if axis:
|
||||
assert "axis=" in z.astext()
|
||||
if keepdims:
|
||||
|
@ -224,7 +231,7 @@ def test_strided_slice():
|
|||
x = relay.var("x", relay.TensorType(dshape, "float32"))
|
||||
z = relay.strided_slice(x, begin=begin, end=end, strides=strides)
|
||||
func = relay.Function([x], z)
|
||||
func = relay.ir_pass.infer_type(func)
|
||||
func = run_infer_type(func)
|
||||
text = func.astext()
|
||||
assert "begin=" in text
|
||||
assert "end=" in text
|
||||
|
|
|
@ -20,21 +20,28 @@ import math
|
|||
import numpy as np
|
||||
import tvm
|
||||
from tvm import relay
|
||||
from tvm.relay import transform
|
||||
from tvm.relay.testing import ctx_list
|
||||
import topi.testing
|
||||
|
||||
def run_infer_type(expr):
|
||||
mod = relay.Module.from_expr(expr)
|
||||
mod = transform.InferType()(mod)
|
||||
entry = mod[mod.entry_func]
|
||||
return entry if isinstance(expr, relay.Function) else entry.body
|
||||
|
||||
def test_resize_infer_type():
|
||||
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
|
||||
x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
|
||||
th, tw = tvm.var("th"), tvm.var("tw")
|
||||
z = relay.image.resize(x, (th, tw))
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8")
|
||||
|
||||
x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
|
||||
z= relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False)
|
||||
assert "size=" in z.astext()
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8")
|
||||
|
||||
def test_resize():
|
||||
|
@ -52,7 +59,7 @@ def test_resize():
|
|||
x = relay.var("x", relay.TensorType(dshape, "float32"))
|
||||
z = relay.image.resize(x, size, layout, method, False)
|
||||
assert "size=" in z.astext()
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.TensorType(ref_res.shape, "float32")
|
||||
func = relay.Function([x], z)
|
||||
|
||||
|
@ -109,7 +116,7 @@ def test_multibox_prior():
|
|||
check_type_only=False):
|
||||
|
||||
z = relay.vision.multibox_prior(x, sizes, ratios, steps, offsets, clip)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
if check_size:
|
||||
assert "sizes=" in z.astext()
|
||||
assert zz.checked_type == relay.TensorType(
|
||||
|
@ -121,7 +128,7 @@ def test_multibox_prior():
|
|||
|
||||
data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
|
||||
func = relay.Function([x], z)
|
||||
func = relay.ir_pass.infer_type(func)
|
||||
func = run_infer_type(func)
|
||||
for target, ctx in ctx_list():
|
||||
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
|
||||
op_res1 = intrp1.evaluate(func)(data)
|
||||
|
@ -176,7 +183,7 @@ def test_get_valid_counts():
|
|||
z = relay.vision.get_valid_counts(x, score_threshold, id_index, score_index)
|
||||
assert "score_threshold" in z.astext()
|
||||
func = relay.Function([x], z.astuple())
|
||||
func = relay.ir_pass.infer_type(func)
|
||||
func = run_infer_type(func)
|
||||
for target, ctx in ctx_list():
|
||||
if target == 'cuda':
|
||||
return
|
||||
|
@ -205,8 +212,8 @@ def test_non_max_suppression():
|
|||
top_k = top_k)
|
||||
assert "iou_threshold" in z.astext()
|
||||
assert "iou_threshold" in z_indices.astext()
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz_indices = relay.ir_pass.infer_type(z_indices)
|
||||
zz = run_infer_type(z)
|
||||
zz_indices = run_infer_type(z_indices)
|
||||
assert zz.checked_type == relay.ty.TensorType(dshape, "float32")
|
||||
assert zz_indices.checked_type == relay.ty.TensorType((dshape[0], dshape[1]), "int32")
|
||||
|
||||
|
@ -214,9 +221,9 @@ def test_non_max_suppression():
|
|||
return
|
||||
|
||||
func = relay.Function([x0, x1], z)
|
||||
func = relay.ir_pass.infer_type(func)
|
||||
func = run_infer_type(func)
|
||||
func_indices = relay.Function([x0, x1], z_indices)
|
||||
func_indices = relay.ir_pass.infer_type(func_indices)
|
||||
func_indices = run_infer_type(func_indices)
|
||||
for target, ctx in ctx_list():
|
||||
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
|
||||
op_res1 = intrp1.evaluate(func)(x0_data, x1_data)
|
||||
|
@ -288,7 +295,7 @@ def test_multibox_transform_loc():
|
|||
|
||||
mtl = relay.vision.multibox_transform_loc(
|
||||
cls_prob=cls_prob, loc_pred=loc_pred, anchor=anchors)
|
||||
ret = relay.ir_pass.infer_type(mtl.astuple())
|
||||
ret = run_infer_type(mtl.astuple())
|
||||
ref_type = relay.ty.TupleType(
|
||||
tvm.convert([
|
||||
relay.ty.TensorType((1, num_anchors, 6), "float32"),
|
||||
|
@ -299,7 +306,7 @@ def test_multibox_transform_loc():
|
|||
|
||||
nms = relay.vision.non_max_suppression(mtl[0], mtl[1], return_indices=False)
|
||||
func = relay.Function([cls_prob, loc_pred, anchors], nms)
|
||||
func = relay.ir_pass.infer_type(func)
|
||||
func = run_infer_type(func)
|
||||
for target, ctx in ctx_list():
|
||||
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
|
||||
op_res1 = intrp1.evaluate(func)(np_cls_prob, np_loc_preds,
|
||||
|
@ -330,7 +337,7 @@ def test_multibox_transform_loc():
|
|||
anchor=anchors,
|
||||
threshold=threshold,
|
||||
variances=variances)
|
||||
ret = relay.ir_pass.infer_type(ret.astuple())
|
||||
ret = run_infer_type(ret.astuple())
|
||||
ref_type = relay.ty.TupleType(
|
||||
tvm.convert([
|
||||
relay.ty.TensorType((n, num_anchors, 6), "float32"),
|
||||
|
@ -349,15 +356,14 @@ def test_roi_align():
|
|||
z = relay.vision.roi_align(data, rois, pooled_size=(pooled_size, pooled_size),
|
||||
spatial_scale=spatial_scale, sample_ratio=sample_ratio,
|
||||
layout="NCHW")
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
|
||||
zz = run_infer_type(z)
|
||||
batch, channel, in_size, _ = data_shape
|
||||
num_roi = rois_shape[0]
|
||||
assert zz.checked_type == relay.ty.TensorType(
|
||||
(num_roi, channel, pooled_size, pooled_size), "float32")
|
||||
|
||||
func = relay.Function([data, rois], z)
|
||||
func = relay.ir_pass.infer_type(func)
|
||||
func = run_infer_type(func)
|
||||
np_data = np.random.uniform(size=data_shape).astype("float32")
|
||||
np_rois = np.random.uniform(size=rois_shape).astype('float32') * in_size
|
||||
np_rois[:, 0] = np.random.randint(low = 0, high = batch, size = num_roi)
|
||||
|
@ -382,15 +388,14 @@ def test_roi_pool():
|
|||
rois = relay.var("rois", relay.ty.TensorType(rois_shape, "float32"))
|
||||
z = relay.vision.roi_pool(data, rois, pooled_size=(pooled_size, pooled_size),
|
||||
spatial_scale=spatial_scale, layout="NCHW")
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
|
||||
zz = run_infer_type(z)
|
||||
batch, channel, in_size, _ = data_shape
|
||||
num_roi = rois_shape[0]
|
||||
assert zz.checked_type == relay.ty.TensorType(
|
||||
(num_roi, channel, pooled_size, pooled_size), "float32")
|
||||
|
||||
func = relay.Function([data, rois], z)
|
||||
func = relay.ir_pass.infer_type(func)
|
||||
func = run_infer_type(func)
|
||||
np_data = np.random.uniform(size=data_shape).astype("float32")
|
||||
np_rois = np.random.uniform(size=rois_shape).astype('float32') * in_size
|
||||
np_rois[:, 0] = np.random.randint(low = 0, high = batch, size = num_roi).astype('float32')
|
||||
|
@ -414,12 +419,11 @@ def test_proposal():
|
|||
bbox_pred = relay.var("bbox_pred", relay.ty.TensorType(np_bbox_pred.shape, "float32"))
|
||||
im_info = relay.var("im_info", relay.ty.TensorType(np_im_info.shape, "float32"))
|
||||
z = relay.vision.proposal(cls_prob, bbox_pred, im_info, **attrs)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
|
||||
zz = run_infer_type(z)
|
||||
assert zz.checked_type == relay.ty.TensorType(np_out.shape, "float32")
|
||||
|
||||
func = relay.Function([cls_prob, bbox_pred, im_info], z)
|
||||
func = relay.ir_pass.infer_type(func)
|
||||
func = run_infer_type(func)
|
||||
for target in ['cuda']:
|
||||
if not tvm.module.enabled(target):
|
||||
print("Skip test because %s is not enabled." % target)
|
||||
|
@ -478,7 +482,7 @@ def test_yolo_reorg_infer_shape():
|
|||
def verify_yolo_reorg(shape, stride, out_shape):
|
||||
x = relay.var("x", relay.TensorType(shape, "float32"))
|
||||
z = relay.vision.yolo_reorg(x, stride=stride)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert "stride=" in z.astext()
|
||||
assert zz.checked_type == relay.ty.TensorType(out_shape, "float32")
|
||||
|
||||
|
@ -493,7 +497,7 @@ def test_yolo_reorg():
|
|||
|
||||
x = relay.var("x", relay.TensorType(shape, "float32"))
|
||||
z = relay.vision.yolo_reorg(x, stride=stride)
|
||||
zz = relay.ir_pass.infer_type(z)
|
||||
zz = run_infer_type(z)
|
||||
assert "stride=" in z.astext()
|
||||
assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32")
|
||||
|
||||
|
@ -527,7 +531,7 @@ def test_deformable_conv2d():
|
|||
weight_shape = (out_channel, in_channel // groups, kernel_size[0], kernel_size[1])
|
||||
out_shape = (batch, out_channel, size, size)
|
||||
offset_shape = (batch, 2 * kernel_size[0] * kernel_size[1] * deformable_groups, out_shape[2], out_shape[3])
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
yy = run_infer_type(y)
|
||||
assert yy.checked_type == relay.TensorType(out_shape)
|
||||
assert yy.args[1].checked_type == relay.TensorType(offset_shape), yy.args[1].checked_type
|
||||
assert yy.args[2].checked_type == relay.TensorType(weight_shape)
|
||||
|
|
|
@ -14,17 +14,17 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import tvm
|
||||
import numpy as np
|
||||
import tvm
|
||||
from tvm import relay
|
||||
from tvm.relay import ir_pass
|
||||
from tvm.relay import analysis
|
||||
|
||||
def alpha_equal(x, y):
|
||||
"""
|
||||
Wrapper around alpha equality which ensures that
|
||||
the hash function respects equality.
|
||||
"""
|
||||
return ir_pass.alpha_equal(x, y) and ir_pass.structural_hash(x) == ir_pass.structural_hash(y)
|
||||
return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)
|
||||
|
||||
def test_tensor_type_alpha_equal():
|
||||
t1 = relay.TensorType((3, 4), "float32")
|
||||
|
@ -604,14 +604,14 @@ def test_hash_unequal():
|
|||
y2 = relay.var("y2", shape=(10, 10), dtype="float32")
|
||||
func2 = relay.Function([x2, y2], relay.add(x2, y2))
|
||||
|
||||
assert ir_pass.structural_hash(func1) == ir_pass.structural_hash(func2)
|
||||
assert analysis.structural_hash(func1) == analysis.structural_hash(func2)
|
||||
|
||||
# func3 is same as func1 but with different var shapes
|
||||
x3 = relay.var("x3", shape=(20, 10), dtype="float32")
|
||||
y3 = relay.var("y3", shape=(20, 10), dtype="float32")
|
||||
func3 = relay.Function([x3, y3], relay.add(x3, y3))
|
||||
|
||||
assert not ir_pass.structural_hash(func1) == ir_pass.structural_hash(func3)
|
||||
assert not analysis.structural_hash(func1) == analysis.structural_hash(func3)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_tensor_type_alpha_equal()
|
||||
|
|
|
@ -19,7 +19,18 @@ import tvm
|
|||
|
||||
from tvm import relay
|
||||
from tvm.relay.op import register_alter_op_layout
|
||||
from tvm.relay.ir_pass import *
|
||||
from tvm.relay import transform, analysis
|
||||
|
||||
|
||||
def run_opt_pass(expr, passes):
|
||||
passes = passes if isinstance(passes, list) else [passes]
|
||||
mod = relay.Module.from_expr(expr)
|
||||
seq = transform.Sequential(passes)
|
||||
with transform.PassContext(opt_level=3):
|
||||
mod = seq(mod)
|
||||
entry = mod[mod.entry_func]
|
||||
return entry if isinstance(expr, relay.Function) else entry.body
|
||||
|
||||
|
||||
def test_alter_op():
|
||||
"""Test directly replacing an operator with a new one"""
|
||||
|
@ -52,13 +63,10 @@ def test_alter_op():
|
|||
return y
|
||||
|
||||
a = before()
|
||||
a = infer_type(a)
|
||||
a = alter_op_layout(a)
|
||||
a = run_opt_pass(a, transform.AlterOpLayout())
|
||||
b = run_opt_pass(expected(), transform.InferType())
|
||||
|
||||
b = expected()
|
||||
b = infer_type(b)
|
||||
|
||||
assert alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
|
||||
|
||||
def test_alter_return_none():
|
||||
|
@ -77,12 +85,11 @@ def test_alter_return_none():
|
|||
return None
|
||||
|
||||
a = before()
|
||||
a = infer_type(a)
|
||||
a = alter_op_layout(a)
|
||||
a = run_opt_pass(a, transform.AlterOpLayout())
|
||||
|
||||
b = before()
|
||||
b = infer_type(b)
|
||||
assert alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
b = run_opt_pass(b, transform.InferType())
|
||||
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
assert(called[0])
|
||||
|
||||
|
||||
|
@ -102,7 +109,7 @@ def test_alter_layout():
|
|||
y = relay.nn.max_pool2d(y, pool_size=(2, 2))
|
||||
y = relay.cast(y, 'int32')
|
||||
y = relay.nn.batch_flatten(y)
|
||||
y = relay.Function(free_vars(y), y)
|
||||
y = relay.Function(analysis.free_vars(y), y)
|
||||
return y
|
||||
|
||||
@register_alter_op_layout("nn.conv2d", level=102)
|
||||
|
@ -135,20 +142,17 @@ def test_alter_layout():
|
|||
y = relay.cast(y, 'int32')
|
||||
y = relay.layout_transform(y, "NCHW16c", "NCHW")
|
||||
y = relay.nn.batch_flatten(y)
|
||||
y = relay.Function(free_vars(y), y)
|
||||
y = relay.Function(analysis.free_vars(y), y)
|
||||
return y
|
||||
|
||||
a = before()
|
||||
a = infer_type(a)
|
||||
a = canonicalize_ops(a)
|
||||
a = infer_type(a)
|
||||
a = alter_op_layout(a)
|
||||
a = infer_type(a)
|
||||
a = run_opt_pass(a, [transform.CanonicalizeOps(),
|
||||
transform.AlterOpLayout()])
|
||||
|
||||
b = expected()
|
||||
b = infer_type(b)
|
||||
b = run_opt_pass(b, transform.InferType())
|
||||
|
||||
assert alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
|
||||
|
||||
def test_alter_layout_dual_path():
|
||||
|
@ -172,7 +176,7 @@ def test_alter_layout_dual_path():
|
|||
y1 = relay.nn.relu(y1)
|
||||
y2 = relay.nn.batch_flatten(y)
|
||||
ret = relay.Tuple([y1, y2])
|
||||
y = relay.Function(free_vars(ret), ret)
|
||||
y = relay.Function(analysis.free_vars(ret), ret)
|
||||
return y
|
||||
|
||||
@register_alter_op_layout("nn.conv2d", level=103)
|
||||
|
@ -203,18 +207,16 @@ def test_alter_layout_dual_path():
|
|||
y2 = relay.layout_transform(y, "NCHW16c", "NCHW")
|
||||
y2 = relay.nn.batch_flatten(y2)
|
||||
ret = relay.Tuple([y1, y2])
|
||||
y = relay.Function(free_vars(ret), ret)
|
||||
y = relay.Function(analysis.free_vars(ret), ret)
|
||||
return y
|
||||
|
||||
a = before()
|
||||
a = infer_type(a)
|
||||
a = alter_op_layout(a)
|
||||
a = infer_type(a)
|
||||
a = run_opt_pass(a, transform.AlterOpLayout())
|
||||
|
||||
b = expected()
|
||||
b = infer_type(b)
|
||||
b = run_opt_pass(b, transform.InferType())
|
||||
|
||||
assert alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
|
||||
def test_alter_layout_resnet():
|
||||
"""Test alternating the layout of a residual block
|
||||
|
@ -236,7 +238,7 @@ def test_alter_layout_resnet():
|
|||
y2 = relay.nn.relu(y2)
|
||||
y = y + y2
|
||||
y = relay.nn.global_max_pool2d(y)
|
||||
return relay.Function(free_vars(y), y)
|
||||
return relay.Function(analysis.free_vars(y), y)
|
||||
|
||||
@register_alter_op_layout("nn.conv2d", level=104)
|
||||
def alter_conv2d(attrs, inputs, tinfos):
|
||||
|
@ -264,17 +266,15 @@ def test_alter_layout_resnet():
|
|||
y = y + y2
|
||||
y = relay.nn.global_max_pool2d(y, layout="NCHW16c")
|
||||
y = relay.layout_transform(y, "NCHW16c", "NCHW")
|
||||
return relay.Function(free_vars(y), y)
|
||||
return relay.Function(analysis.free_vars(y), y)
|
||||
|
||||
a = before()
|
||||
a = infer_type(a)
|
||||
a = alter_op_layout(a)
|
||||
a = infer_type(a)
|
||||
a = run_opt_pass(a, transform.AlterOpLayout())
|
||||
|
||||
b = expected()
|
||||
b = infer_type(b)
|
||||
b = run_opt_pass(b, transform.InferType())
|
||||
|
||||
assert alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
|
||||
|
||||
def test_alter_layout_broadcast_op():
|
||||
|
@ -287,7 +287,7 @@ def test_alter_layout_broadcast_op():
|
|||
y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
|
||||
y = relay.nn.bias_add(y, bias) # test broadcasting to lhs
|
||||
y = relay.multiply(scale, y) # test broadcasting to rhs
|
||||
y = relay.Function(free_vars(y), y)
|
||||
y = relay.Function(analysis.free_vars(y), y)
|
||||
return y
|
||||
|
||||
@register_alter_op_layout("nn.conv2d", level=105)
|
||||
|
@ -311,20 +311,17 @@ def test_alter_layout_broadcast_op():
|
|||
y = relay.add(y, bias) # test broadcasting to lhs
|
||||
y = relay.multiply(scale, y) # test broadcasting to rhs
|
||||
y = relay.layout_transform(y, "NCHW16c", "NCHW")
|
||||
y = relay.Function(free_vars(y), y)
|
||||
y = relay.Function(analysis.free_vars(y), y)
|
||||
return y
|
||||
|
||||
a = before()
|
||||
a = infer_type(a)
|
||||
a = canonicalize_ops(a)
|
||||
a = infer_type(a)
|
||||
a = alter_op_layout(a)
|
||||
a = infer_type(a)
|
||||
a = run_opt_pass(a, [transform.CanonicalizeOps(),
|
||||
transform.AlterOpLayout()])
|
||||
|
||||
b = expected()
|
||||
b = infer_type(b)
|
||||
b = run_opt_pass(b, transform.InferType())
|
||||
|
||||
assert alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
|
||||
def test_alter_layout_scalar():
|
||||
"""Test alternating the layout of a conv2d.
|
||||
|
@ -335,7 +332,7 @@ def test_alter_layout_scalar():
|
|||
weight = relay.var("weight")
|
||||
y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
|
||||
y = relay.add(y, relay.const(1, "float32"))
|
||||
y = relay.Function(free_vars(y), y)
|
||||
y = relay.Function(analysis.free_vars(y), y)
|
||||
return y
|
||||
|
||||
@register_alter_op_layout("nn.conv2d", level=106)
|
||||
|
@ -358,20 +355,17 @@ def test_alter_layout_scalar():
|
|||
y = relay.add(y, relay.const(1.0, "float32"))
|
||||
|
||||
y = relay.layout_transform(y, "NCHW16c", "NCHW")
|
||||
y = relay.Function(free_vars(y), y)
|
||||
y = relay.Function(analysis.free_vars(y), y)
|
||||
return y
|
||||
|
||||
a = before()
|
||||
a = infer_type(a)
|
||||
a = canonicalize_ops(a)
|
||||
a = infer_type(a)
|
||||
a = alter_op_layout(a)
|
||||
a = infer_type(a)
|
||||
a = run_opt_pass(a, [transform.CanonicalizeOps(),
|
||||
transform.AlterOpLayout()])
|
||||
|
||||
b = expected()
|
||||
b = infer_type(b)
|
||||
b = run_opt_pass(b, transform.InferType())
|
||||
|
||||
assert alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
|
||||
def test_alter_layout_concatenate():
|
||||
""" """
|
||||
|
@ -388,7 +382,7 @@ def test_alter_layout_concatenate():
|
|||
kernel_size=(3, 3),
|
||||
padding=(1, 1))
|
||||
ret = relay.concatenate([y, y1], axis=1)
|
||||
y = relay.Function(free_vars(ret), ret)
|
||||
y = relay.Function(analysis.free_vars(ret), ret)
|
||||
return y
|
||||
|
||||
@register_alter_op_layout("nn.conv2d", level=107)
|
||||
|
@ -415,18 +409,16 @@ def test_alter_layout_concatenate():
|
|||
data_layout='NCHW16c')
|
||||
ret = relay.concatenate([y, y1], axis=1)
|
||||
ret = relay.layout_transform(ret, "NCHW16c", "NCHW")
|
||||
y = relay.Function(free_vars(ret), ret)
|
||||
y = relay.Function(analysis.free_vars(ret), ret)
|
||||
return y
|
||||
|
||||
a = before()
|
||||
a = infer_type(a)
|
||||
a = alter_op_layout(a)
|
||||
a = infer_type(a)
|
||||
a = run_opt_pass(a, transform.AlterOpLayout())
|
||||
|
||||
b = expected()
|
||||
b = infer_type(b)
|
||||
b = run_opt_pass(b, transform.InferType())
|
||||
|
||||
assert alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
|
||||
|
||||
def test_alter_layout_nchw_upsamping_op():
|
||||
|
@ -437,7 +429,7 @@ def test_alter_layout_nchw_upsamping_op():
|
|||
y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
|
||||
y = relay.nn.upsampling(y, scale=2)
|
||||
y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2))
|
||||
y = relay.Function(free_vars(y), y)
|
||||
y = relay.Function(analysis.free_vars(y), y)
|
||||
return y
|
||||
|
||||
@register_alter_op_layout("nn.conv2d", level=108)
|
||||
|
@ -456,21 +448,17 @@ def test_alter_layout_nchw_upsamping_op():
|
|||
y = relay.nn.upsampling(y, scale=2, layout="NCHW16c")
|
||||
y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2), layout='NCHW16c')
|
||||
y = relay.layout_transform(y, "NCHW16c", "NCHW")
|
||||
y = relay.Function(free_vars(y), y)
|
||||
y = relay.Function(analysis.free_vars(y), y)
|
||||
return y
|
||||
|
||||
a = before()
|
||||
a = infer_type(a)
|
||||
a = canonicalize_ops(a)
|
||||
a = infer_type(a)
|
||||
|
||||
a = alter_op_layout(a)
|
||||
a = infer_type(a)
|
||||
a = run_opt_pass(a, [transform.CanonicalizeOps(),
|
||||
transform.AlterOpLayout()])
|
||||
|
||||
b = expected()
|
||||
b = infer_type(b)
|
||||
b = run_opt_pass(b, transform.InferType())
|
||||
|
||||
assert alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
|
||||
|
||||
def test_alter_layout_strided_slice():
|
||||
|
@ -480,7 +468,7 @@ def test_alter_layout_strided_slice():
|
|||
weight = relay.var('weight', shape=(32, 32, 3, 3))
|
||||
y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
|
||||
y = relay.strided_slice(y, begin=[0, 16], end=[None, None])
|
||||
y = relay.Function(free_vars(y), y)
|
||||
y = relay.Function(analysis.free_vars(y), y)
|
||||
return y
|
||||
|
||||
@register_alter_op_layout("nn.conv2d", level=109)
|
||||
|
@ -498,21 +486,17 @@ def test_alter_layout_strided_slice():
|
|||
data_layout="NCHW4c")
|
||||
y = relay.strided_slice(y, begin=[0, 4], end=[None, 8])
|
||||
y = relay.layout_transform(y, "NCHW4c", "NCHW")
|
||||
y = relay.Function(free_vars(y), y)
|
||||
y = relay.Function(analysis.free_vars(y), y)
|
||||
return y
|
||||
|
||||
a = before()
|
||||
a = infer_type(a)
|
||||
a = canonicalize_ops(a)
|
||||
a = infer_type(a)
|
||||
|
||||
a = alter_op_layout(a)
|
||||
a = infer_type(a)
|
||||
a = run_opt_pass(a, [transform.CanonicalizeOps(),
|
||||
transform.AlterOpLayout()])
|
||||
|
||||
b = expected()
|
||||
b = infer_type(b)
|
||||
b = run_opt_pass(b, transform.InferType())
|
||||
|
||||
assert alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
|
||||
|
||||
def test_alter_layout_depthwise_conv2d():
|
||||
"""Test depthwise_conv2d operator"""
|
||||
|
@ -520,7 +504,7 @@ def test_alter_layout_depthwise_conv2d():
|
|||
x = relay.var("x", shape=(1, 32, 56, 56))
|
||||
w = relay.var("w", shape=(32, 1, 3, 3))
|
||||
y = relay.nn.conv2d(x, w, padding=(1, 1), channels=32, kernel_size=(3, 3), groups=32)
|
||||
y = relay.Function(free_vars(y), y)
|
||||
y = relay.Function(analysis.free_vars(y), y)
|
||||
return y
|
||||
|
||||
import topi
|
||||
|
@ -538,20 +522,17 @@ def test_alter_layout_depthwise_conv2d():
|
|||
groups=32, data_layout="NCHW8c", kernel_layout="OIHW1i8o",
|
||||
out_layout="NCHW8c")
|
||||
y = relay.layout_transform(y, "NCHW8c", "NCHW")
|
||||
y = relay.Function(free_vars(y), y)
|
||||
y = relay.Function(analysis.free_vars(y), y)
|
||||
return y
|
||||
|
||||
a = before()
|
||||
a = infer_type(a)
|
||||
a = canonicalize_ops(a)
|
||||
a = infer_type(a)
|
||||
a = alter_op_layout(a)
|
||||
a = infer_type(a)
|
||||
a = run_opt_pass(a, [transform.CanonicalizeOps(),
|
||||
transform.AlterOpLayout()])
|
||||
|
||||
b = expected()
|
||||
b = infer_type(b)
|
||||
b = run_opt_pass(b, transform.InferType())
|
||||
|
||||
assert(alpha_equal(a, b))
|
||||
assert(analysis.alpha_equal(a, b))
|
||||
|
||||
def test_alter_layout_prelu():
|
||||
"""Test PRelu operator"""
|
||||
|
@ -561,7 +542,7 @@ def test_alter_layout_prelu():
|
|||
alpha = relay.var("alpha", relay.IncompleteType())
|
||||
y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
|
||||
y = relay.nn.prelu(y, alpha)
|
||||
y = relay.Function(free_vars(y), y)
|
||||
y = relay.Function(analysis.free_vars(y), y)
|
||||
return y
|
||||
|
||||
@register_alter_op_layout("nn.conv2d", level=111)
|
||||
|
@ -584,20 +565,16 @@ def test_alter_layout_prelu():
|
|||
data_layout="NCHW16c")
|
||||
y = relay.layout_transform(y, "NCHW16c", "NCHW")
|
||||
y = relay.nn.prelu(y, alpha)
|
||||
y = relay.Function(free_vars(y), y)
|
||||
y = relay.Function(analysis.free_vars(y), y)
|
||||
return y
|
||||
|
||||
a = before()
|
||||
a = infer_type(a)
|
||||
a = canonicalize_ops(a)
|
||||
a = infer_type(a)
|
||||
a = alter_op_layout(a)
|
||||
a = infer_type(a)
|
||||
a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()])
|
||||
|
||||
b = expected()
|
||||
b = infer_type(b)
|
||||
b = run_opt_pass(b, transform.InferType())
|
||||
|
||||
assert(alpha_equal(a, b))
|
||||
assert(analysis.alpha_equal(a, b))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше
Загрузка…
Ссылка в новой задаче