зеркало из https://github.com/microsoft/nni.git
Fix serializer for complex kinds of arguments (#4487)
This commit is contained in:
Родитель
bb0a870006
Коммит
763f2c87de
|
@ -219,6 +219,7 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
|
|||
|
||||
If ``kw_only`` is true, try to convert all parameters into kwargs type. This is done by inspecting the argument
|
||||
list and types. This can be useful to extract semantics, but can be tricky in some corner cases.
|
||||
Therefore, in some cases, some positional arguments will still be kept.
|
||||
|
||||
.. warning::
|
||||
|
||||
|
@ -451,27 +452,69 @@ def _formulate_single_argument(arg):
|
|||
def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
|
||||
# This is to formulate the arguments and make them well-formed.
|
||||
if kw_only:
|
||||
# Match arguments with given arguments, so that we can use keyword arguments as much as possible.
|
||||
# Mutators don't like positional arguments. Positional arguments might not supply enough information.
|
||||
|
||||
# get arguments passed to a function, and save it as a dict
|
||||
argname_list = list(inspect.signature(func).parameters.keys())
|
||||
insp_parameters = inspect.signature(func).parameters
|
||||
argname_list = list(insp_parameters.keys())
|
||||
if is_class_init:
|
||||
argname_list = argname_list[1:]
|
||||
full_args = {}
|
||||
positional_args = []
|
||||
keyword_args = {}
|
||||
|
||||
# match arguments with given arguments
|
||||
# args should be longer than given list, because args can be used in a kwargs way
|
||||
assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
|
||||
for argname, value in zip(argname_list, args):
|
||||
full_args[argname] = value
|
||||
# According to https://docs.python.org/3/library/inspect.html#inspect.Parameter, there are five kinds of parameters
|
||||
# in Python. We only try to handle POSITIONAL_ONLY and POSITIONAL_OR_KEYWORD here.
|
||||
|
||||
# Example:
|
||||
# For foo(a, b, *c, **d), a and b and c should be kept.
|
||||
# For foo(a, b, /, d), a and b should be kept.
|
||||
|
||||
for i, value in enumerate(args):
|
||||
if i >= len(argname_list):
|
||||
raise ValueError(f'{func} receives extra argument: {value}.')
|
||||
|
||||
argname = argname_list[i]
|
||||
if insp_parameters[argname].kind == inspect.Parameter.POSITIONAL_ONLY:
|
||||
# positional only. have to be kept.
|
||||
positional_args.append(value)
|
||||
|
||||
elif insp_parameters[argname].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
|
||||
# this should be the most common case
|
||||
keyword_args[argname] = value
|
||||
|
||||
elif insp_parameters[argname].kind == inspect.Parameter.VAR_POSITIONAL:
|
||||
# Any previous preprocessing might be wrong. Clean them all.
|
||||
# Any parameters that appear before a VAR_POSITIONAL should be kept positional.
|
||||
# Otherwise, VAR_POSITIONAL might not work.
|
||||
# For the cases I've tested, any parameters that appear after a VAR_POSITIONAL are considered keyword only.
|
||||
# But, if args is not long enough for VAR_POSITIONAL to be encountered, they should be handled by other if-branches.
|
||||
positional_args = args
|
||||
keyword_args = {}
|
||||
break
|
||||
|
||||
else:
|
||||
# kind has to be one of `KEYWORD_ONLY` and `VAR_KEYWORD`
|
||||
raise ValueError(f'{func} receives positional argument: {value}, but the parameter type is found to be keyword only.')
|
||||
|
||||
# use kwargs to override
|
||||
full_args.update(kwargs)
|
||||
keyword_args.update(kwargs)
|
||||
|
||||
args, kwargs = [], full_args
|
||||
if positional_args:
|
||||
# Raise a warning if some arguments are not convertible to keyword arguments.
|
||||
warnings.warn(f'Found positional arguments {positional_args} should processing parameters of {func}. '
|
||||
'We recommend always using keyword arguments to specify parameters. '
|
||||
'For example: `nn.LSTM(input_size=2, hidden_size=2)` instead of `nn.LSTM(2, 2)`.')
|
||||
|
||||
args = [_formulate_single_argument(arg) for arg in args]
|
||||
kwargs = {k: _formulate_single_argument(arg) for k, arg in kwargs.items()}
|
||||
else:
|
||||
# keep them unprocessed
|
||||
positional_args, keyword_args = args, kwargs
|
||||
|
||||
return list(args), kwargs
|
||||
# do some extra conversions to the arguments.
|
||||
positional_args = [_formulate_single_argument(arg) for arg in positional_args]
|
||||
keyword_args = {k: _formulate_single_argument(arg) for k, arg in keyword_args.items()}
|
||||
|
||||
return positional_args, keyword_args
|
||||
|
||||
|
||||
def _is_function(obj: Any) -> bool:
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
import nni
|
||||
|
||||
|
||||
def test_positional_only():
|
||||
def foo(a, b, /, c):
|
||||
pass
|
||||
|
||||
d = nni.trace(foo)(1, 2, c=3)
|
||||
assert d.trace_args == [1, 2]
|
||||
assert d.trace_kwargs == dict(c=3)
|
|
@ -1,5 +1,4 @@
|
|||
import math
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -16,6 +15,10 @@ if True: # prevent auto formatting
|
|||
sys.path.insert(0, Path(__file__).parent.as_posix())
|
||||
from imported.model import ImportTest
|
||||
|
||||
# this test cannot be directly put in this file. It will cause syntax error for python <= 3.7.
|
||||
if tuple(sys.version_info) >= (3, 8):
|
||||
from imported._test_serializer_py38 import test_positional_only
|
||||
|
||||
|
||||
@nni.trace
|
||||
class SimpleClass:
|
||||
|
@ -238,6 +241,36 @@ def test_generator():
|
|||
print(optimizer.trace_kwargs)
|
||||
|
||||
|
||||
def test_arguments_kind():
|
||||
def foo(a, b, *c, **d):
|
||||
pass
|
||||
|
||||
d = nni.trace(foo)(1, 2, 3, 4)
|
||||
assert d.trace_args == [1, 2, 3, 4]
|
||||
assert d.trace_kwargs == {}
|
||||
|
||||
d = nni.trace(foo)(a=1, b=2)
|
||||
assert d.trace_kwargs == dict(a=1, b=2)
|
||||
|
||||
d = nni.trace(foo)(1, b=2)
|
||||
# this is not perfect, but it's safe
|
||||
assert d.trace_kwargs == dict(a=1, b=2)
|
||||
|
||||
def foo(a, *, b=3, c=5):
|
||||
pass
|
||||
|
||||
d = nni.trace(foo)(1, b=2, c=3)
|
||||
assert d.trace_kwargs == dict(a=1, b=2, c=3)
|
||||
|
||||
import torch.nn as nn
|
||||
lstm = nni.trace(nn.LSTM)(2, 2)
|
||||
assert lstm.input_size == 2
|
||||
assert lstm.hidden_size == 2
|
||||
assert lstm.trace_args == [2, 2]
|
||||
|
||||
lstm = nni.trace(nn.LSTM)(input_size=2, hidden_size=2)
|
||||
assert lstm.trace_kwargs == {'input_size': 2, 'hidden_size': 2}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_simple_class()
|
||||
|
@ -245,4 +278,5 @@ if __name__ == '__main__':
|
|||
# test_nested_class()
|
||||
# test_unserializable()
|
||||
# test_basic_unit()
|
||||
test_generator()
|
||||
# test_generator()
|
||||
test_arguments_kind()
|
||||
|
|
Загрузка…
Ссылка в новой задаче