[RELAY] Pass infra cleanup (#3336)
This commit is contained in:
Родитель
d6c4aba837
Коммит
c9a2f3da5b
|
@ -202,7 +202,8 @@ class PassInfoNode : public RelayNode {
|
||||||
v->Visit("required", &required);
|
v->Visit("required", &required);
|
||||||
}
|
}
|
||||||
|
|
||||||
TVM_DLL static PassInfo make(int opt_level, std::string name,
|
TVM_DLL static PassInfo make(int opt_level,
|
||||||
|
std::string name,
|
||||||
tvm::Array<tvm::Expr> required);
|
tvm::Array<tvm::Expr> required);
|
||||||
|
|
||||||
static constexpr const char* _type_key = "relay.PassInfo";
|
static constexpr const char* _type_key = "relay.PassInfo";
|
||||||
|
@ -467,7 +468,7 @@ TVM_DLL Pass SimplifyInference();
|
||||||
* type information filled in, as well as it's checked type field
|
* type information filled in, as well as it's checked type field
|
||||||
* populated with the result type.
|
* populated with the result type.
|
||||||
*
|
*
|
||||||
* \return The pass.
|
* \return The pass.
|
||||||
*/
|
*/
|
||||||
TVM_DLL Pass InferType();
|
TVM_DLL Pass InferType();
|
||||||
|
|
||||||
|
|
|
@ -14,13 +14,9 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
# pylint: disable=no-else-return
|
|
||||||
# pylint: disable=unidiomatic-typecheck
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
"""
|
"""
|
||||||
This file contains the pass manager for Relay which exposes different
|
Relay pass transformation infrastructure.
|
||||||
granularity of interfaces for users to implement and use passes more
|
|
||||||
conveniently.
|
|
||||||
"""
|
"""
|
||||||
import types
|
import types
|
||||||
|
|
||||||
|
@ -39,19 +35,19 @@ class PassInfo(RelayNode):
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
name : str
|
|
||||||
The pass name.
|
|
||||||
|
|
||||||
opt_level : int
|
opt_level : int
|
||||||
The optimization level of this pass.
|
The optimization level of this pass.
|
||||||
|
|
||||||
|
name : str
|
||||||
|
The pass name.
|
||||||
|
|
||||||
required : List[str]
|
required : List[str]
|
||||||
The list of passes that are required by a certain pass.
|
The list of passes that are required by a certain pass.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name, opt_level, required=None):
|
def __init__(self, opt_level, name, required=None):
|
||||||
self.__init_handle_by_constructor__(_transform.PassInfo, name, opt_level,
|
self.__init_handle_by_constructor__(
|
||||||
required)
|
_transform.PassInfo, opt_level, name, required)
|
||||||
|
|
||||||
|
|
||||||
@register_relay_node
|
@register_relay_node
|
||||||
|
@ -194,7 +190,7 @@ class ModulePass(Pass):
|
||||||
`module_pass`, because the design of the `module_pass` API is flexible
|
`module_pass`, because the design of the `module_pass` API is flexible
|
||||||
enough to handle the creation of a module pass in different manners. In
|
enough to handle the creation of a module pass in different manners. In
|
||||||
addition, all members of a module pass can be accessed from the base class.
|
addition, all members of a module pass can be accessed from the base class.
|
||||||
The same rule applies to FunctionPass and Sequential as well.
|
The same rule applies to FunctionPass as well.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@ -250,153 +246,6 @@ class Sequential(Pass):
|
||||||
passes, opt_level, name, required)
|
passes, opt_level, name, required)
|
||||||
|
|
||||||
|
|
||||||
def module_pass(pass_func=None, opt_level=None, name=None, required=None):
|
|
||||||
"""Create a module pass. This function returns a callback when pass_func
|
|
||||||
is provided. Otherwise, it returns the created module level pass using the
|
|
||||||
given optimization function.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
pass_func : Optional[Callable[(Module/Function, PassContext) ->
|
|
||||||
Module/Function]]
|
|
||||||
The implemented optimization pass.
|
|
||||||
|
|
||||||
opt_level : int
|
|
||||||
The optimization level of this module pass.
|
|
||||||
|
|
||||||
name : Optional[str]
|
|
||||||
The name of the module pass. The name could be empty. In this case, the
|
|
||||||
name of the optimization function will be used as the pass name.
|
|
||||||
|
|
||||||
required : Optional[List[str]]
|
|
||||||
The list of passes that the module pass is dependent on.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
create_module_pass : Union[Callable, ModulePass]
|
|
||||||
The callable that will create a module pass is returned when
|
|
||||||
pass_func is not passed in. Otherwise, a ModulePass object will be
|
|
||||||
directly created.
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
The following code creates a module level pass and adds an abs function to
|
|
||||||
the module.
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
@relay.transform.module_pass(opt_level=2)
|
|
||||||
def transform(mod, ctx):
|
|
||||||
tp = relay.TensorType((10,), "float32")
|
|
||||||
x = relay.var("x", tp)
|
|
||||||
gv = relay.GlobalVar("var")
|
|
||||||
func = relay.Function([x], relay.abs(x))
|
|
||||||
new_mod = relay.Module({gv: func})
|
|
||||||
new_mod.update(mod)
|
|
||||||
return new_mod
|
|
||||||
|
|
||||||
module_pass = transform
|
|
||||||
assert isinstance(module_pass, transform.ModulePass)
|
|
||||||
assert module_pass.info.opt_level == 2
|
|
||||||
|
|
||||||
# Given a module m, the optimization could be invoked as the follwoing:
|
|
||||||
updated_mod = module_pass(m)
|
|
||||||
# Now a function abs should be added to the module m.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if opt_level is None:
|
|
||||||
raise ValueError("Please provide opt_level for the module pass.")
|
|
||||||
|
|
||||||
required = required if required else []
|
|
||||||
if not isinstance(required, (list, tuple)):
|
|
||||||
raise TypeError("Required is expected to be the type of " +
|
|
||||||
"list/tuple.")
|
|
||||||
|
|
||||||
def create_module_pass(pass_func):
|
|
||||||
"""Internal function that creates a module pass"""
|
|
||||||
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
|
|
||||||
raise TypeError("pass_func must be a callable for Module pass")
|
|
||||||
|
|
||||||
return _transform.CreateModulePass(
|
|
||||||
pass_func, opt_level, name if name else pass_func.__name__,
|
|
||||||
required)
|
|
||||||
|
|
||||||
if pass_func:
|
|
||||||
return create_module_pass(pass_func)
|
|
||||||
return create_module_pass
|
|
||||||
|
|
||||||
|
|
||||||
def function_pass(pass_func=None, opt_level=None, name=None, required=None):
|
|
||||||
"""Create a function pass. This function returns a callback when pass_func
|
|
||||||
is provided. Otherwise, it returns the created function pass using the
|
|
||||||
given optimization function.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
pass_func : Optional[Callable[(Module/Function, PassContext) ->
|
|
||||||
Module/Function]]
|
|
||||||
The implemented optimization pass.
|
|
||||||
|
|
||||||
opt_level : int
|
|
||||||
The optimization level of this module pass.
|
|
||||||
|
|
||||||
name : Optional[str]
|
|
||||||
The name of the function pass. The name could be empty. In this case, the
|
|
||||||
name of the optimization function will be used as the pass name.
|
|
||||||
|
|
||||||
required : Optional[List[str]]
|
|
||||||
The list of passes that the module pass is dependent on.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
create_function_pass : Union[Callable, FunctionPass]
|
|
||||||
The callable that will create a function pass is returned when
|
|
||||||
pass_func is not passed in. Otherwise, a FunctionPass object will be
|
|
||||||
created.
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
The following code creates a function level pass that performs constant
|
|
||||||
folding.
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
@relay.transform.function_pass(opt_level=2)
|
|
||||||
def transform(func, ctx):
|
|
||||||
return ir_pass.fold_constant(func)
|
|
||||||
|
|
||||||
function_pass = transform
|
|
||||||
assert isinstance(function_pass, transform.FunctionPass)
|
|
||||||
assert function_pass.info.opt_level == 2
|
|
||||||
|
|
||||||
# Given a module m, the optimization could be invoked as the follwoing:
|
|
||||||
updated_mod = function_pass(m)
|
|
||||||
# Now constant folding should have been applied to every function in
|
|
||||||
# the provided module m. And the updated module will be returned.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if opt_level is None:
|
|
||||||
raise ValueError("Please provide opt_level for the funtion pass.")
|
|
||||||
|
|
||||||
required = required if required else []
|
|
||||||
if not isinstance(required, (list, tuple)):
|
|
||||||
raise TypeError("Required is expected to be the type of " +
|
|
||||||
"list/tuple.")
|
|
||||||
|
|
||||||
def create_function_pass(pass_func):
|
|
||||||
"""Internal function that creates a function pass"""
|
|
||||||
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
|
|
||||||
raise TypeError("pass_func must be a callable for Module pass")
|
|
||||||
|
|
||||||
return _transform.CreateFunctionPass(
|
|
||||||
pass_func, opt_level, name if name else pass_func.__name__,
|
|
||||||
required)
|
|
||||||
|
|
||||||
if pass_func:
|
|
||||||
return create_function_pass(pass_func)
|
|
||||||
return create_function_pass
|
|
||||||
|
|
||||||
|
|
||||||
def InferType():
|
def InferType():
|
||||||
"""Infer the type of an expr.
|
"""Infer the type of an expr.
|
||||||
|
|
||||||
|
@ -593,3 +442,150 @@ def PartialEvaluate():
|
||||||
The registered pass that performs partial evaluation on an expression.
|
The registered pass that performs partial evaluation on an expression.
|
||||||
"""
|
"""
|
||||||
return _transform.PartialEvaluate()
|
return _transform.PartialEvaluate()
|
||||||
|
|
||||||
|
|
||||||
|
def module_pass(pass_func=None, opt_level=None, name=None, required=None):
|
||||||
|
"""Create a module pass. This function returns a callback when pass_func
|
||||||
|
is provided. Otherwise, it returns the created module level pass using the
|
||||||
|
given optimization function.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
pass_func : Optional[Callable[(Module/Function, PassContext) ->
|
||||||
|
Module/Function]]
|
||||||
|
The implemented optimization pass.
|
||||||
|
|
||||||
|
opt_level : int
|
||||||
|
The optimization level of this module pass.
|
||||||
|
|
||||||
|
name : Optional[str]
|
||||||
|
The name of the module pass. The name could be empty. In this case, the
|
||||||
|
name of the optimization function will be used as the pass name.
|
||||||
|
|
||||||
|
required : Optional[List[str]]
|
||||||
|
The list of passes that the module pass is dependent on.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
create_module_pass : Union[Callable, ModulePass]
|
||||||
|
The callable that will create a module pass is returned when
|
||||||
|
pass_func is not passed in. Otherwise, a ModulePass object will be
|
||||||
|
directly created.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
The following code creates a module level pass and adds an abs function to
|
||||||
|
the module.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@relay.transform.module_pass(opt_level=2)
|
||||||
|
def transform(mod, ctx):
|
||||||
|
tp = relay.TensorType((10,), "float32")
|
||||||
|
x = relay.var("x", tp)
|
||||||
|
gv = relay.GlobalVar("var")
|
||||||
|
func = relay.Function([x], relay.abs(x))
|
||||||
|
new_mod = relay.Module({gv: func})
|
||||||
|
new_mod.update(mod)
|
||||||
|
return new_mod
|
||||||
|
|
||||||
|
module_pass = transform
|
||||||
|
assert isinstance(module_pass, transform.ModulePass)
|
||||||
|
assert module_pass.info.opt_level == 2
|
||||||
|
|
||||||
|
# Given a module m, the optimization could be invoked as the follwoing:
|
||||||
|
updated_mod = module_pass(m)
|
||||||
|
# Now a function abs should be added to the module m.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if opt_level is None:
|
||||||
|
raise ValueError("Please provide opt_level for the module pass.")
|
||||||
|
|
||||||
|
required = required if required else []
|
||||||
|
if not isinstance(required, (list, tuple)):
|
||||||
|
raise TypeError("Required is expected to be the type of " +
|
||||||
|
"list/tuple.")
|
||||||
|
|
||||||
|
def create_module_pass(pass_func):
|
||||||
|
"""Internal function that creates a module pass"""
|
||||||
|
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
|
||||||
|
raise TypeError("pass_func must be a callable for Module pass")
|
||||||
|
|
||||||
|
fname = name if name else pass_func.__name__
|
||||||
|
info = PassInfo(opt_level, fname, required)
|
||||||
|
return _transform.MakeModulePass(pass_func, info)
|
||||||
|
|
||||||
|
if pass_func:
|
||||||
|
return create_module_pass(pass_func)
|
||||||
|
return create_module_pass
|
||||||
|
|
||||||
|
|
||||||
|
def function_pass(pass_func=None, opt_level=None, name=None, required=None):
|
||||||
|
"""Create a function pass. This function returns a callback when pass_func
|
||||||
|
is provided. Otherwise, it returns the created function pass using the
|
||||||
|
given optimization function.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
pass_func : Optional[Callable[(Module/Function, PassContext) ->
|
||||||
|
Module/Function]]
|
||||||
|
The implemented optimization pass.
|
||||||
|
|
||||||
|
opt_level : int
|
||||||
|
The optimization level of this module pass.
|
||||||
|
|
||||||
|
name : Optional[str]
|
||||||
|
The name of the function pass. The name could be empty. In this case, the
|
||||||
|
name of the optimization function will be used as the pass name.
|
||||||
|
|
||||||
|
required : Optional[List[str]]
|
||||||
|
The list of passes that the module pass is dependent on.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
create_function_pass : Union[Callable, FunctionPass]
|
||||||
|
The callable that will create a function pass is returned when
|
||||||
|
pass_func is not passed in. Otherwise, a FunctionPass object will be
|
||||||
|
created.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
The following code creates a function level pass that performs constant
|
||||||
|
folding.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@relay.transform.function_pass(opt_level=2)
|
||||||
|
def transform(func, ctx):
|
||||||
|
return ir_pass.fold_constant(func)
|
||||||
|
|
||||||
|
function_pass = transform
|
||||||
|
assert isinstance(function_pass, transform.FunctionPass)
|
||||||
|
assert function_pass.info.opt_level == 2
|
||||||
|
|
||||||
|
# Given a module m, the optimization could be invoked as the follwoing:
|
||||||
|
updated_mod = function_pass(m)
|
||||||
|
# Now constant folding should have been applied to every function in
|
||||||
|
# the provided module m. And the updated module will be returned.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if opt_level is None:
|
||||||
|
raise ValueError("Please provide opt_level for the funtion pass.")
|
||||||
|
|
||||||
|
required = required if required else []
|
||||||
|
if not isinstance(required, (list, tuple)):
|
||||||
|
raise TypeError("Required is expected to be the type of " +
|
||||||
|
"list/tuple.")
|
||||||
|
|
||||||
|
def create_function_pass(pass_func):
|
||||||
|
"""Internal function that creates a function pass"""
|
||||||
|
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
|
||||||
|
raise TypeError("pass_func must be a callable for Module pass")
|
||||||
|
|
||||||
|
fname = name if name else pass_func.__name__
|
||||||
|
info = PassInfo(opt_level, fname, required)
|
||||||
|
return _transform.MakeFunctionPass(pass_func, info)
|
||||||
|
|
||||||
|
if pass_func:
|
||||||
|
return create_function_pass(pass_func)
|
||||||
|
return create_function_pass
|
||||||
|
|
|
@ -465,8 +465,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||||
|
|
||||||
TVM_REGISTER_NODE_TYPE(ModulePassNode);
|
TVM_REGISTER_NODE_TYPE(ModulePassNode);
|
||||||
|
|
||||||
TVM_REGISTER_API("relay._transform.CreateModulePass")
|
TVM_REGISTER_API("relay._transform.MakeModulePass")
|
||||||
.set_body_typed(CreateModulePass);
|
.set_body_typed(ModulePassNode::make);
|
||||||
|
|
||||||
TVM_REGISTER_API("relay._transform.RunPass")
|
TVM_REGISTER_API("relay._transform.RunPass")
|
||||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
@ -485,8 +485,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||||
|
|
||||||
TVM_REGISTER_NODE_TYPE(FunctionPassNode);
|
TVM_REGISTER_NODE_TYPE(FunctionPassNode);
|
||||||
|
|
||||||
TVM_REGISTER_API("relay._transform.CreateFunctionPass")
|
TVM_REGISTER_API("relay._transform.MakeFunctionPass")
|
||||||
.set_body_typed(CreateFunctionPass);
|
.set_body_typed(FunctionPassNode::make);
|
||||||
|
|
||||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||||
.set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
|
.set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
|
||||||
|
|
|
@ -259,6 +259,12 @@ def test_function_pass():
|
||||||
test_pass_run()
|
test_pass_run()
|
||||||
|
|
||||||
|
|
||||||
|
def test_pass_info():
|
||||||
|
info = relay.transform.PassInfo(opt_level=1, name="xyz")
|
||||||
|
assert info.opt_level == 1
|
||||||
|
assert info.name == "xyz"
|
||||||
|
|
||||||
|
|
||||||
def test_sequential_pass():
|
def test_sequential_pass():
|
||||||
shape = (10, )
|
shape = (10, )
|
||||||
dtype = 'float32'
|
dtype = 'float32'
|
||||||
|
@ -449,3 +455,4 @@ if __name__ == "__main__":
|
||||||
test_function_pass()
|
test_function_pass()
|
||||||
test_sequential_pass()
|
test_sequential_pass()
|
||||||
test_sequential_with_scoping()
|
test_sequential_with_scoping()
|
||||||
|
test_pass_info()
|
||||||
|
|
Загрузка…
Ссылка в новой задаче