зеркало из https://github.com/microsoft/nni.git
Add support for generator in serializer (#4465)
This commit is contained in:
Родитель
253dbfd8a7
Коммит
90f96ef553
|
@ -220,6 +220,11 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
|
|||
If ``kw_only`` is true, try to convert all parameters into kwargs type. This is done by inspecting the argument
|
||||
list and types. This can be useful to extract semantics, but can be tricky in some corner cases.
|
||||
|
||||
.. warning::
|
||||
|
||||
Generators will be first expanded into a list, and the resulting list will be further passed into the wrapped function/class.
|
||||
This might hang when generators produce an infinite sequence. We might introduce an API to control this behavior in future.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
@ -431,6 +436,18 @@ def _argument_processor(arg):
|
|||
return arg
|
||||
|
||||
|
||||
def _formulate_single_argument(arg):
|
||||
# this is different from argument processor
|
||||
# it directly apply the transformation on the stored arguments
|
||||
|
||||
# expand generator into list
|
||||
# Note that some types that are generator (such as range(10)) may not be identified as generator here.
|
||||
if isinstance(arg, types.GeneratorType):
|
||||
arg = list(arg)
|
||||
|
||||
return arg
|
||||
|
||||
|
||||
def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
|
||||
# This is to formulate the arguments and make them well-formed.
|
||||
if kw_only:
|
||||
|
@ -451,6 +468,9 @@ def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
|
|||
|
||||
args, kwargs = [], full_args
|
||||
|
||||
args = [_formulate_single_argument(arg) for arg in args]
|
||||
kwargs = {k: _formulate_single_argument(arg) for k, arg in kwargs.items()}
|
||||
|
||||
return list(args), kwargs
|
||||
|
||||
|
||||
|
|
|
@ -221,10 +221,28 @@ def test_lightning_earlystop():
|
|||
assert any(isinstance(callback, EarlyStopping) for callback in trainer.callbacks)
|
||||
|
||||
|
||||
def test_generator():
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(3, 10, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
model = Net()
|
||||
optimizer = nni.trace(optim.Adam)(model.parameters())
|
||||
print(optimizer.trace_kwargs)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_simple_class()
|
||||
# test_external_class()
|
||||
# test_nested_class()
|
||||
# test_unserializable()
|
||||
# test_basic_unit()
|
||||
test_multiprocessing_dataloader()
|
||||
test_generator()
|
||||
|
|
Загрузка…
Ссылка в новой задаче