[RELAY] Pass infra cleanup (#3336)
This commit is contained in:
Родитель
d6c4aba837
Коммит
c9a2f3da5b
|
@ -202,7 +202,8 @@ class PassInfoNode : public RelayNode {
|
|||
v->Visit("required", &required);
|
||||
}
|
||||
|
||||
TVM_DLL static PassInfo make(int opt_level, std::string name,
|
||||
TVM_DLL static PassInfo make(int opt_level,
|
||||
std::string name,
|
||||
tvm::Array<tvm::Expr> required);
|
||||
|
||||
static constexpr const char* _type_key = "relay.PassInfo";
|
||||
|
@ -467,7 +468,7 @@ TVM_DLL Pass SimplifyInference();
|
|||
* type information filled in, as well as it's checked type field
|
||||
* populated with the result type.
|
||||
*
|
||||
* \return The pass.
|
||||
* \return The pass.
|
||||
*/
|
||||
TVM_DLL Pass InferType();
|
||||
|
||||
|
|
|
@ -14,13 +14,9 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=no-else-return
|
||||
# pylint: disable=unidiomatic-typecheck
|
||||
# pylint: disable=invalid-name
|
||||
"""
|
||||
This file contains the pass manager for Relay which exposes different
|
||||
granularity of interfaces for users to implement and use passes more
|
||||
conveniently.
|
||||
Relay pass transformation infrastructure.
|
||||
"""
|
||||
import types
|
||||
|
||||
|
@ -39,19 +35,19 @@ class PassInfo(RelayNode):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The pass name.
|
||||
|
||||
opt_level : int
|
||||
The optimization level of this pass.
|
||||
|
||||
name : str
|
||||
The pass name.
|
||||
|
||||
required : List[str]
|
||||
The list of passes that are required by a certain pass.
|
||||
"""
|
||||
|
||||
def __init__(self, name, opt_level, required=None):
|
||||
self.__init_handle_by_constructor__(_transform.PassInfo, name, opt_level,
|
||||
required)
|
||||
def __init__(self, opt_level, name, required=None):
|
||||
self.__init_handle_by_constructor__(
|
||||
_transform.PassInfo, opt_level, name, required)
|
||||
|
||||
|
||||
@register_relay_node
|
||||
|
@ -194,7 +190,7 @@ class ModulePass(Pass):
|
|||
`module_pass`, because the design of the `module_pass` API is flexible
|
||||
enough to handle the creation of a module pass in different manners. In
|
||||
addition, all members of a module pass can be accessed from the base class.
|
||||
The same rule applies to FunctionPass and Sequential as well.
|
||||
The same rule applies to FunctionPass as well.
|
||||
"""
|
||||
|
||||
|
||||
|
@ -250,153 +246,6 @@ class Sequential(Pass):
|
|||
passes, opt_level, name, required)
|
||||
|
||||
|
||||
def module_pass(pass_func=None, opt_level=None, name=None, required=None):
|
||||
"""Create a module pass. This function returns a callback when pass_func
|
||||
is provided. Otherwise, it returns the created module level pass using the
|
||||
given optimization function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pass_func : Optional[Callable[(Module/Function, PassContext) ->
|
||||
Module/Function]]
|
||||
The implemented optimization pass.
|
||||
|
||||
opt_level : int
|
||||
The optimization level of this module pass.
|
||||
|
||||
name : Optional[str]
|
||||
The name of the module pass. The name could be empty. In this case, the
|
||||
name of the optimization function will be used as the pass name.
|
||||
|
||||
required : Optional[List[str]]
|
||||
The list of passes that the module pass is dependent on.
|
||||
|
||||
Returns
|
||||
-------
|
||||
create_module_pass : Union[Callable, ModulePass]
|
||||
The callable that will create a module pass is returned when
|
||||
pass_func is not passed in. Otherwise, a ModulePass object will be
|
||||
directly created.
|
||||
|
||||
Examples
|
||||
--------
|
||||
The following code creates a module level pass and adds an abs function to
|
||||
the module.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@relay.transform.module_pass(opt_level=2)
|
||||
def transform(mod, ctx):
|
||||
tp = relay.TensorType((10,), "float32")
|
||||
x = relay.var("x", tp)
|
||||
gv = relay.GlobalVar("var")
|
||||
func = relay.Function([x], relay.abs(x))
|
||||
new_mod = relay.Module({gv: func})
|
||||
new_mod.update(mod)
|
||||
return new_mod
|
||||
|
||||
module_pass = transform
|
||||
assert isinstance(module_pass, transform.ModulePass)
|
||||
assert module_pass.info.opt_level == 2
|
||||
|
||||
# Given a module m, the optimization could be invoked as the follwoing:
|
||||
updated_mod = module_pass(m)
|
||||
# Now a function abs should be added to the module m.
|
||||
"""
|
||||
|
||||
if opt_level is None:
|
||||
raise ValueError("Please provide opt_level for the module pass.")
|
||||
|
||||
required = required if required else []
|
||||
if not isinstance(required, (list, tuple)):
|
||||
raise TypeError("Required is expected to be the type of " +
|
||||
"list/tuple.")
|
||||
|
||||
def create_module_pass(pass_func):
|
||||
"""Internal function that creates a module pass"""
|
||||
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
|
||||
raise TypeError("pass_func must be a callable for Module pass")
|
||||
|
||||
return _transform.CreateModulePass(
|
||||
pass_func, opt_level, name if name else pass_func.__name__,
|
||||
required)
|
||||
|
||||
if pass_func:
|
||||
return create_module_pass(pass_func)
|
||||
return create_module_pass
|
||||
|
||||
|
||||
def function_pass(pass_func=None, opt_level=None, name=None, required=None):
|
||||
"""Create a function pass. This function returns a callback when pass_func
|
||||
is provided. Otherwise, it returns the created function pass using the
|
||||
given optimization function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pass_func : Optional[Callable[(Module/Function, PassContext) ->
|
||||
Module/Function]]
|
||||
The implemented optimization pass.
|
||||
|
||||
opt_level : int
|
||||
The optimization level of this module pass.
|
||||
|
||||
name : Optional[str]
|
||||
The name of the function pass. The name could be empty. In this case, the
|
||||
name of the optimization function will be used as the pass name.
|
||||
|
||||
required : Optional[List[str]]
|
||||
The list of passes that the module pass is dependent on.
|
||||
|
||||
Returns
|
||||
-------
|
||||
create_function_pass : Union[Callable, FunctionPass]
|
||||
The callable that will create a function pass is returned when
|
||||
pass_func is not passed in. Otherwise, a FunctionPass object will be
|
||||
created.
|
||||
|
||||
Examples
|
||||
--------
|
||||
The following code creates a function level pass that performs constant
|
||||
folding.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@relay.transform.function_pass(opt_level=2)
|
||||
def transform(func, ctx):
|
||||
return ir_pass.fold_constant(func)
|
||||
|
||||
function_pass = transform
|
||||
assert isinstance(function_pass, transform.FunctionPass)
|
||||
assert function_pass.info.opt_level == 2
|
||||
|
||||
# Given a module m, the optimization could be invoked as the follwoing:
|
||||
updated_mod = function_pass(m)
|
||||
# Now constant folding should have been applied to every function in
|
||||
# the provided module m. And the updated module will be returned.
|
||||
"""
|
||||
|
||||
if opt_level is None:
|
||||
raise ValueError("Please provide opt_level for the funtion pass.")
|
||||
|
||||
required = required if required else []
|
||||
if not isinstance(required, (list, tuple)):
|
||||
raise TypeError("Required is expected to be the type of " +
|
||||
"list/tuple.")
|
||||
|
||||
def create_function_pass(pass_func):
|
||||
"""Internal function that creates a function pass"""
|
||||
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
|
||||
raise TypeError("pass_func must be a callable for Module pass")
|
||||
|
||||
return _transform.CreateFunctionPass(
|
||||
pass_func, opt_level, name if name else pass_func.__name__,
|
||||
required)
|
||||
|
||||
if pass_func:
|
||||
return create_function_pass(pass_func)
|
||||
return create_function_pass
|
||||
|
||||
|
||||
def InferType():
|
||||
"""Infer the type of an expr.
|
||||
|
||||
|
@ -593,3 +442,150 @@ def PartialEvaluate():
|
|||
The registered pass that performs partial evaluation on an expression.
|
||||
"""
|
||||
return _transform.PartialEvaluate()
|
||||
|
||||
|
||||
def module_pass(pass_func=None, opt_level=None, name=None, required=None):
|
||||
"""Create a module pass. This function returns a callback when pass_func
|
||||
is provided. Otherwise, it returns the created module level pass using the
|
||||
given optimization function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pass_func : Optional[Callable[(Module/Function, PassContext) ->
|
||||
Module/Function]]
|
||||
The implemented optimization pass.
|
||||
|
||||
opt_level : int
|
||||
The optimization level of this module pass.
|
||||
|
||||
name : Optional[str]
|
||||
The name of the module pass. The name could be empty. In this case, the
|
||||
name of the optimization function will be used as the pass name.
|
||||
|
||||
required : Optional[List[str]]
|
||||
The list of passes that the module pass is dependent on.
|
||||
|
||||
Returns
|
||||
-------
|
||||
create_module_pass : Union[Callable, ModulePass]
|
||||
The callable that will create a module pass is returned when
|
||||
pass_func is not passed in. Otherwise, a ModulePass object will be
|
||||
directly created.
|
||||
|
||||
Examples
|
||||
--------
|
||||
The following code creates a module level pass and adds an abs function to
|
||||
the module.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@relay.transform.module_pass(opt_level=2)
|
||||
def transform(mod, ctx):
|
||||
tp = relay.TensorType((10,), "float32")
|
||||
x = relay.var("x", tp)
|
||||
gv = relay.GlobalVar("var")
|
||||
func = relay.Function([x], relay.abs(x))
|
||||
new_mod = relay.Module({gv: func})
|
||||
new_mod.update(mod)
|
||||
return new_mod
|
||||
|
||||
module_pass = transform
|
||||
assert isinstance(module_pass, transform.ModulePass)
|
||||
assert module_pass.info.opt_level == 2
|
||||
|
||||
# Given a module m, the optimization could be invoked as the follwoing:
|
||||
updated_mod = module_pass(m)
|
||||
# Now a function abs should be added to the module m.
|
||||
"""
|
||||
|
||||
if opt_level is None:
|
||||
raise ValueError("Please provide opt_level for the module pass.")
|
||||
|
||||
required = required if required else []
|
||||
if not isinstance(required, (list, tuple)):
|
||||
raise TypeError("Required is expected to be the type of " +
|
||||
"list/tuple.")
|
||||
|
||||
def create_module_pass(pass_func):
|
||||
"""Internal function that creates a module pass"""
|
||||
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
|
||||
raise TypeError("pass_func must be a callable for Module pass")
|
||||
|
||||
fname = name if name else pass_func.__name__
|
||||
info = PassInfo(opt_level, fname, required)
|
||||
return _transform.MakeModulePass(pass_func, info)
|
||||
|
||||
if pass_func:
|
||||
return create_module_pass(pass_func)
|
||||
return create_module_pass
|
||||
|
||||
|
||||
def function_pass(pass_func=None, opt_level=None, name=None, required=None):
|
||||
"""Create a function pass. This function returns a callback when pass_func
|
||||
is provided. Otherwise, it returns the created function pass using the
|
||||
given optimization function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pass_func : Optional[Callable[(Module/Function, PassContext) ->
|
||||
Module/Function]]
|
||||
The implemented optimization pass.
|
||||
|
||||
opt_level : int
|
||||
The optimization level of this module pass.
|
||||
|
||||
name : Optional[str]
|
||||
The name of the function pass. The name could be empty. In this case, the
|
||||
name of the optimization function will be used as the pass name.
|
||||
|
||||
required : Optional[List[str]]
|
||||
The list of passes that the module pass is dependent on.
|
||||
|
||||
Returns
|
||||
-------
|
||||
create_function_pass : Union[Callable, FunctionPass]
|
||||
The callable that will create a function pass is returned when
|
||||
pass_func is not passed in. Otherwise, a FunctionPass object will be
|
||||
created.
|
||||
|
||||
Examples
|
||||
--------
|
||||
The following code creates a function level pass that performs constant
|
||||
folding.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@relay.transform.function_pass(opt_level=2)
|
||||
def transform(func, ctx):
|
||||
return ir_pass.fold_constant(func)
|
||||
|
||||
function_pass = transform
|
||||
assert isinstance(function_pass, transform.FunctionPass)
|
||||
assert function_pass.info.opt_level == 2
|
||||
|
||||
# Given a module m, the optimization could be invoked as the follwoing:
|
||||
updated_mod = function_pass(m)
|
||||
# Now constant folding should have been applied to every function in
|
||||
# the provided module m. And the updated module will be returned.
|
||||
"""
|
||||
|
||||
if opt_level is None:
|
||||
raise ValueError("Please provide opt_level for the funtion pass.")
|
||||
|
||||
required = required if required else []
|
||||
if not isinstance(required, (list, tuple)):
|
||||
raise TypeError("Required is expected to be the type of " +
|
||||
"list/tuple.")
|
||||
|
||||
def create_function_pass(pass_func):
|
||||
"""Internal function that creates a function pass"""
|
||||
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
|
||||
raise TypeError("pass_func must be a callable for Module pass")
|
||||
|
||||
fname = name if name else pass_func.__name__
|
||||
info = PassInfo(opt_level, fname, required)
|
||||
return _transform.MakeFunctionPass(pass_func, info)
|
||||
|
||||
if pass_func:
|
||||
return create_function_pass(pass_func)
|
||||
return create_function_pass
|
||||
|
|
|
@ -465,8 +465,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
|||
|
||||
TVM_REGISTER_NODE_TYPE(ModulePassNode);
|
||||
|
||||
TVM_REGISTER_API("relay._transform.CreateModulePass")
|
||||
.set_body_typed(CreateModulePass);
|
||||
TVM_REGISTER_API("relay._transform.MakeModulePass")
|
||||
.set_body_typed(ModulePassNode::make);
|
||||
|
||||
TVM_REGISTER_API("relay._transform.RunPass")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
|
@ -485,8 +485,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
|||
|
||||
TVM_REGISTER_NODE_TYPE(FunctionPassNode);
|
||||
|
||||
TVM_REGISTER_API("relay._transform.CreateFunctionPass")
|
||||
.set_body_typed(CreateFunctionPass);
|
||||
TVM_REGISTER_API("relay._transform.MakeFunctionPass")
|
||||
.set_body_typed(FunctionPassNode::make);
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||
.set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
|
||||
|
|
|
@ -259,6 +259,12 @@ def test_function_pass():
|
|||
test_pass_run()
|
||||
|
||||
|
||||
def test_pass_info():
|
||||
info = relay.transform.PassInfo(opt_level=1, name="xyz")
|
||||
assert info.opt_level == 1
|
||||
assert info.name == "xyz"
|
||||
|
||||
|
||||
def test_sequential_pass():
|
||||
shape = (10, )
|
||||
dtype = 'float32'
|
||||
|
@ -449,3 +455,4 @@ if __name__ == "__main__":
|
|||
test_function_pass()
|
||||
test_sequential_pass()
|
||||
test_sequential_with_scoping()
|
||||
test_pass_info()
|
||||
|
|
Загрузка…
Ссылка в новой задаче