зеркало из https://github.com/microsoft/nni.git
[Compression] allow dummy input as dict. (#5440)
This commit is contained in:
Родитель
43b3d8a34d
Коммит
8693bde00e
|
@ -32,6 +32,13 @@ from .replacer import Replacer, DefaultReplacer
|
|||
from .utils import tree_map_zip
|
||||
|
||||
|
||||
def _normalize_input(dummy_input: Any) -> Any:
|
||||
if isinstance(dummy_input, torch.Tensor):
|
||||
dummy_input = (dummy_input, )
|
||||
elif isinstance(dummy_input, list):
|
||||
dummy_input = tuple(dummy_input)
|
||||
return dummy_input
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
class ModelSpeedup(torch.fx.Interpreter):
|
||||
"""
|
||||
|
@ -88,7 +95,7 @@ class ModelSpeedup(torch.fx.Interpreter):
|
|||
graph_module: GraphModule | None = None,
|
||||
garbage_collect_values: bool = True,
|
||||
logger: logging.Logger | None = None):
|
||||
self.dummy_input = (dummy_input,) if isinstance(dummy_input, torch.Tensor) else tuple(dummy_input)
|
||||
self.dummy_input = _normalize_input(dummy_input)
|
||||
self.bound_model = model
|
||||
self.graph_module = graph_module if isinstance(graph_module, GraphModule) else concrete_trace(model, self.dummy_input)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче