Add support for generator in serializer (#4465)

Yuge Zhang 2022-01-17 12:14:25 +08:00 committed by GitHub
Parent 253dbfd8a7
Commit 90f96ef553
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
2 changed files: 39 additions and 1 deletion

View file

@@ -220,6 +220,11 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
    If ``kw_only`` is true, try to convert all parameters into kwargs type. This is done by inspecting the argument
    list and types. This can be useful to extract semantics, but can be tricky in some corner cases.

    .. warning::

        Generators will first be expanded into a list, and the resulting list will then be passed to the wrapped function/class.
        This might hang when a generator produces an infinite sequence. We might introduce an API to control this behavior in the future.

    Example:

    .. code-block:: python
@@ -431,6 +436,18 @@ def _argument_processor(arg):
    return arg


def _formulate_single_argument(arg):
    # this is different from argument processor
    # it directly applies the transformation to the stored arguments

    # expand generators into lists
    # Note that some iterables that behave like generators (such as range(10))
    # are not of GeneratorType and will not be expanded here.
    if isinstance(arg, types.GeneratorType):
        arg = list(arg)

    return arg


def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
    # This is to formulate the arguments and make them well-formed.
    if kw_only:
@@ -451,6 +468,9 @@ def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
        args, kwargs = [], full_args

    args = [_formulate_single_argument(arg) for arg in args]
    kwargs = {k: _formulate_single_argument(arg) for k, arg in kwargs.items()}

    return list(args), kwargs
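
For reference, here is a minimal standalone sketch of the expansion rule added above. The helper name `_expand_generator` is illustrative only (the actual change lives in `_formulate_single_argument`); it shows that true generator objects are materialized into lists, while generator-like iterables such as `range` are not of `GeneratorType` and pass through untouched.

    import types

    def _expand_generator(arg):
        # Same check as the new code: only real generator objects are expanded.
        if isinstance(arg, types.GeneratorType):
            return list(arg)
        return arg

    print(_expand_generator(i * i for i in range(3)))  # [0, 1, 4] -- generator is consumed into a list
    print(_expand_generator(range(3)))                 # range(0, 3) -- not a GeneratorType, passed through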

View file

@@ -221,10 +221,28 @@ def test_lightning_earlystop():
    assert any(isinstance(callback, EarlyStopping) for callback in trainer.callbacks)


def test_generator():
    import torch.nn as nn
    import torch.optim as optim

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 10, 1)

        def forward(self, x):
            return self.conv(x)

    model = Net()
    optimizer = nni.trace(optim.Adam)(model.parameters())
    print(optimizer.trace_kwargs)


if __name__ == '__main__':
    # test_simple_class()
    # test_external_class()
    # test_nested_class()
    # test_unserializable()
    # test_basic_unit()
    test_multiprocessing_dataloader()
    test_generator()
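
As a usage note, this is roughly what the new behavior looks like from the caller's side. The sketch below assumes that, with `kw_only` conversion, the generator returned by `model.parameters()` ends up recorded under the `params` key of `trace_kwargs` as an already-expanded list; the key name and the assertion are assumptions for illustration, not part of the test above.

    import nni
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Conv2d(3, 10, 1)

    # model.parameters() returns a generator; with this change the serializer
    # expands it into a list before recording the traced arguments.
    optimizer = nni.trace(optim.Adam)(model.parameters())

    recorded = optimizer.trace_kwargs.get('params')  # 'params' key is an assumption (Adam's first parameter)
    assert isinstance(recorded, list)                # an expanded list rather than an exhausted generator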