[Compression] allow dummy input as dict. (#5440)

This commit is contained in:
Super Daniel 2023-03-15 15:11:45 +08:00 коммит произвёл GitHub
Родитель 43b3d8a34d
Коммит 8693bde00e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 8 добавлений и 1 удалений

Просмотреть файл

@ -32,6 +32,13 @@ from .replacer import Replacer, DefaultReplacer
from .utils import tree_map_zip
def _normalize_input(dummy_input: Any) -> Any:
if isinstance(dummy_input, torch.Tensor):
dummy_input = (dummy_input, )
elif isinstance(dummy_input, list):
dummy_input = tuple(dummy_input)
return dummy_input
@compatibility(is_backward_compatible=True)
class ModelSpeedup(torch.fx.Interpreter):
"""
@ -88,7 +95,7 @@ class ModelSpeedup(torch.fx.Interpreter):
graph_module: GraphModule | None = None,
garbage_collect_values: bool = True,
logger: logging.Logger | None = None):
self.dummy_input = (dummy_input,) if isinstance(dummy_input, torch.Tensor) else tuple(dummy_input)
self.dummy_input = _normalize_input(dummy_input)
self.bound_model = model
self.graph_module = graph_module if isinstance(graph_module, GraphModule) else concrete_trace(model, self.dummy_input)