diff --git a/nni/compression/pytorch/speedup/v2/model_speedup.py b/nni/compression/pytorch/speedup/v2/model_speedup.py index 49a4bae3e..368329b3c 100644 --- a/nni/compression/pytorch/speedup/v2/model_speedup.py +++ b/nni/compression/pytorch/speedup/v2/model_speedup.py @@ -32,6 +32,13 @@ from .replacer import Replacer, DefaultReplacer from .utils import tree_map_zip +def _normalize_input(dummy_input: Any) -> Any: + if isinstance(dummy_input, torch.Tensor): + dummy_input = (dummy_input, ) + elif isinstance(dummy_input, list): + dummy_input = tuple(dummy_input) + return dummy_input + @compatibility(is_backward_compatible=True) class ModelSpeedup(torch.fx.Interpreter): """ @@ -88,7 +95,7 @@ class ModelSpeedup(torch.fx.Interpreter): graph_module: GraphModule | None = None, garbage_collect_values: bool = True, logger: logging.Logger | None = None): - self.dummy_input = (dummy_input,) if isinstance(dummy_input, torch.Tensor) else tuple(dummy_input) + self.dummy_input = _normalize_input(dummy_input) self.bound_model = model self.graph_module = graph_module if isinstance(graph_module, GraphModule) else concrete_trace(model, self.dummy_input)