Fix serializer for complex kinds of arguments (#4487)

This commit is contained in:
Yuge Zhang 2022-01-27 12:32:33 +08:00 коммит произвёл GitHub
Родитель bb0a870006
Коммит 763f2c87de
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 101 добавлений и 14 удалений

Просмотреть файл

@ -219,6 +219,7 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
If ``kw_only`` is true, try to convert all parameters into kwargs type. This is done by inspecting the argument
list and types. This can be useful to extract semantics, but can be tricky in some corner cases.
Therefore, in some cases, some positional arguments will still be kept.
.. warning::
@ -451,27 +452,69 @@ def _formulate_single_argument(arg):
def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
    """Formulate the arguments of a call to ``func`` and make them well-formed.

    When ``kw_only`` is true, positional arguments are matched against the signature of ``func``
    so that keyword arguments are used as much as possible. Mutators don't like positional
    arguments, as positional arguments might not supply enough information.

    Parameters
    ----------
    func
        The callable whose signature is inspected.
    args
        Positional arguments given by the caller.
    kwargs
        Keyword arguments given by the caller. They override any keyword derived from ``args``.
    kw_only
        If true, try to convert positional arguments into keyword arguments.
        If false, the arguments are kept as given.
    is_class_init
        If true, the first parameter in the signature of ``func`` is skipped
        (presumably the ``self`` of an ``__init__`` -- confirm against the caller).

    Returns
    -------
    tuple
        ``(positional_args, keyword_args)``, a list and a dict whose values have all been
        passed through ``_formulate_single_argument``.

    Raises
    ------
    ValueError
        If more positional arguments are given than the signature has parameters, or if a
        positional argument lands on a keyword-only (or ``**kwargs``) parameter.
    """
    if kw_only:
        # Get the parameters declared by the function, in declaration order.
        insp_parameters = inspect.signature(func).parameters
        argname_list = list(insp_parameters.keys())
        if is_class_init:
            argname_list = argname_list[1:]
        positional_args = []
        keyword_args = {}

        # According to https://docs.python.org/3/library/inspect.html#inspect.Parameter,
        # there are five kinds of parameters in Python.
        # We only try to convert POSITIONAL_OR_KEYWORD here.
        #
        # Example:
        # For foo(a, b, *c, **d), a and b and c should be kept.
        # For foo(a, b, /, d), a and b should be kept.
        for i, value in enumerate(args):
            if i >= len(argname_list):
                raise ValueError(f'{func} receives extra argument: {value}.')

            argname = argname_list[i]
            kind = insp_parameters[argname].kind

            if kind == inspect.Parameter.POSITIONAL_ONLY:
                # Positional only. It has to be kept positional.
                positional_args.append(value)

            elif kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
                # This should be the most common case.
                keyword_args[argname] = value

            elif kind == inspect.Parameter.VAR_POSITIONAL:
                # Any previous preprocessing might be wrong. Clean them all.
                # Any parameters that appear before a VAR_POSITIONAL should be kept positional.
                # Otherwise, VAR_POSITIONAL might not work.
                # Parameters that appear after a VAR_POSITIONAL are keyword only, so they can
                # never be reached by this loop. If args is not long enough for the
                # VAR_POSITIONAL to be encountered, the other branches handle everything.
                positional_args = list(args)
                keyword_args = {}
                break

            else:
                # kind has to be one of `KEYWORD_ONLY` and `VAR_KEYWORD`
                raise ValueError(f'{func} receives positional argument: {value}, but the parameter type is found to be keyword only.')

        # use kwargs to override
        keyword_args.update(kwargs)

        if positional_args:
            # Raise a warning if some arguments are not convertible to keyword arguments.
            warnings.warn(f'Found positional arguments {positional_args} while processing parameters of {func}. '
                          'We recommend always using keyword arguments to specify parameters. '
                          'For example: `nn.LSTM(input_size=2, hidden_size=2)` instead of `nn.LSTM(2, 2)`.')

    else:
        # keep them unprocessed
        positional_args, keyword_args = args, kwargs

    # do some extra conversions to the arguments.
    positional_args = [_formulate_single_argument(arg) for arg in positional_args]
    keyword_args = {k: _formulate_single_argument(arg) for k, arg in keyword_args.items()}

    return positional_args, keyword_args
def _is_function(obj: Any) -> bool:

Просмотреть файл

@ -0,0 +1,10 @@
import nni
def test_positional_only():
    """Positional-only arguments stay in ``trace_args``; the keyword one goes to ``trace_kwargs``."""
    def foo(a, b, /, c):
        pass

    traced = nni.trace(foo)(1, 2, c=3)

    # `a` and `b` are declared before `/`, so they cannot be turned into keywords.
    assert traced.trace_args == [1, 2]
    assert traced.trace_kwargs == {'c': 3}

Просмотреть файл

@ -1,5 +1,4 @@
import math
import re
import sys
from pathlib import Path
@ -16,6 +15,10 @@ if True: # prevent auto formatting
sys.path.insert(0, Path(__file__).parent.as_posix())
from imported.model import ImportTest
# this test cannot be directly put in this file. It will cause syntax error for python <= 3.7.
if tuple(sys.version_info) >= (3, 8):
from imported._test_serializer_py38 import test_positional_only
@nni.trace
class SimpleClass:
@ -238,6 +241,36 @@ def test_generator():
print(optimizer.trace_kwargs)
def test_arguments_kind():
    """Check how traced calls split their arguments for several kinds of signatures."""
    def foo(a, b, *c, **d):
        pass

    # Enough positional arguments to reach `*c`: everything stays positional.
    traced = nni.trace(foo)(1, 2, 3, 4)
    assert traced.trace_args == [1, 2, 3, 4]
    assert traced.trace_kwargs == {}

    # Pure keyword call.
    traced = nni.trace(foo)(a=1, b=2)
    assert traced.trace_kwargs == {'a': 1, 'b': 2}

    # Mixed call: this is not perfect, but it's safe.
    traced = nni.trace(foo)(1, b=2)
    assert traced.trace_kwargs == {'a': 1, 'b': 2}

    def foo(a, *, b=3, c=5):
        pass

    # Keyword-only parameters given by name, `a` given by position.
    traced = nni.trace(foo)(1, b=2, c=3)
    assert traced.trace_kwargs == {'a': 1, 'b': 2, 'c': 3}

    import torch.nn as nn

    # Positional construction of a real module: the module works and the args are kept.
    lstm = nni.trace(nn.LSTM)(2, 2)
    assert lstm.input_size == 2
    assert lstm.hidden_size == 2
    assert lstm.trace_args == [2, 2]

    # Keyword construction of the same module.
    lstm = nni.trace(nn.LSTM)(input_size=2, hidden_size=2)
    assert lstm.trace_kwargs == {'input_size': 2, 'hidden_size': 2}
if __name__ == '__main__':
# test_simple_class()
@ -245,4 +278,5 @@ if __name__ == '__main__':
# test_nested_class()
# test_unserializable()
# test_basic_unit()
test_generator()
# test_generator()
test_arguments_kind()