[Compression] fix tests & bugs (#5444)

J-shang 2023-03-16 01:32:48 +08:00 committed by GitHub
Parent d23dec38e2
Commit f8d85ce352
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 104 additions and 48 deletions

View file

@@ -85,14 +85,23 @@ class ConcreteTracer(TracerBase):
_orig_contains: ([], False, None),
_orig_index: ([], False, None),
# force-traced function
# force-traced function (the factory functions of tensor creation)
torch.arange: ([], True, None),
torch.empty: ([], True, None),
torch.eye: ([], True, None),
torch.full: ([], True, None),
torch.linspace: ([], True, None),
torch.logspace: ([], True, None),
torch.ones: ([], True, None),
torch.rand: ([], True, None),
torch.randn: ([], True, None),
torch.randint: ([], True, None),
torch.rand_like: ([], True, None),
torch.randn_like: ([], True, None),
torch.randint_like: ([], True, None),
torch.randn: ([], True, None),
# torch.rand_like: ([], True, None), # seems that xxx_like will not directly call torch._TensorBase.xxx
# torch.randn_like: ([], True, None),
# torch.randint_like: ([], True, None),
torch.randperm: ([], True, None),
torch.tensor: ([], True, None),
torch.zeros: ([], True, None),
# method
Sequential.__getitem__: ([], False, operator.getitem),
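These factory calls take only plain Python values (sizes, dtypes), so no proxy ever flows into them; forcing them to be traced keeps the created tensors in the recorded graph. A minimal hedged sketch of the situation (the module below is illustrative, not from this PR):

import torch

class AddPositions(torch.nn.Module):
    # The forward builds a tensor from a plain int: torch.arange sees no proxy
    # argument, so it must be force-traced or the creation would be lost.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pos = torch.arange(x.shape[1])          # tensor-creation factory call
        return x + pos.view(1, -1, 1).float()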
@@ -276,13 +285,6 @@ class ConcreteTracer(TracerBase):
similar to _symbolic_trace.Tracer.create_proxy.
use the 'run_target' to actually execute the code, and store the value in 'value' field.
"""
args_ = self.create_arg(args)
kwargs_ = self.create_arg(kwargs)
assert isinstance(args_, tuple)
assert isinstance(kwargs_, dict)
node = self.create_node(kind, target, args_, kwargs_, name, type_expr)
def upwrapper(obj: Any):
while _orig_isinstance(obj, ep.ConcreteProxy):
obj = obj.value
@@ -293,6 +295,13 @@ class ConcreteTracer(TracerBase):
# real value by execution
value_unwrapped = self.run_target(kind, target, args_unwrapped, kwargs_unwrapped)
args_ = self.create_arg(args)
kwargs_ = self.create_arg(kwargs)
assert isinstance(args_, tuple)
assert isinstance(kwargs_, dict)
node = self.create_node(kind, target, args_, kwargs_, name, type_expr)
proxy = self.proxy(value_unwrapped, node)
self.node_to_originating_module[proxy.node] = self.current_module_qualified_name
return proxy
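The reordering above makes the tracer execute the target before it records a graph node, so the proxy always carries a real value and nothing is recorded if execution fails. A rough sketch of the resulting order (the helper names here are illustrative, not the tracer's real API):

def run_then_record(run_target, record_node, kind, target, args, kwargs):
    # execute first: if this raises, no graph node is created
    value = run_target(kind, target, args, kwargs)
    # record second: the node is added only after execution succeeded
    node = record_node(kind, target, args, kwargs)
    return value, node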
@@ -526,6 +535,19 @@ class ConcreteTracer(TracerBase):
such as '__main__.FooModel' or '__main__.bar_func'. the namespace is
always needed.
"""
# fill default values
args = inspect.getfullargspec(root.forward).args[1:]
defaults = inspect.getfullargspec(root.forward).defaults
defaults = tuple() if defaults is None else defaults
if isinstance(concrete_args, (tuple, list)):
concrete_args = (*concrete_args, *defaults[len(concrete_args) + len(defaults) - len(args):])
else:
kv_default = {k: v for k, v in zip(args[-len(defaults):], defaults)}
concrete_args = {
**concrete_args,
**{n: kv_default[n] for n in args if n not in concrete_args}
}
# preprocess arguments
autowrap_modules = autowrap_modules if autowrap_modules is not None else tuple()
autowrap_leaf_function = autowrap_leaf_function if autowrap_leaf_function is not None else {}
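The new preamble completes concrete_args with the defaults from the forward signature, so callers only have to pass the non-default inputs. A small hedged example of what the filling produces (the module is made up for illustration):

import inspect
import torch

class M(torch.nn.Module):
    def forward(self, x, scale=1.0, bias=None):
        return x * scale if bias is None else x * scale + bias

spec = inspect.getfullargspec(M.forward)
args, defaults = spec.args[1:], spec.defaults or tuple()
# args     -> ['x', 'scale', 'bias'], defaults -> (1.0, None)
# a positional concrete_args=(t,) is extended to (t, 1.0, None)
# a dict {'x': t} is extended to {'x': t, 'scale': 1.0, 'bias': None}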
@@ -1396,6 +1418,7 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
fx.GraphModule: a Module created from the recorded operations from ``root``.
"""
tracer = ConcreteTracer()
graph = tracer.trace(root,
autowrap_leaf_function = autowrap_leaf_function,
autowrap_leaf_class = autowrap_leaf_class,
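For context, a hedged usage sketch of the traced entry point; the import path and the concrete_args parameter name are assumed from this repository's layout rather than confirmed by the diff:

import torch
from nni.common.concrete_trace_utils import concrete_trace  # path assumed

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
dummy = torch.randn(2, 4)
traced = concrete_trace(model, concrete_args={'input': dummy})  # fx.GraphModule
print(traced.graph)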

View file

@@ -6,8 +6,10 @@ if TYPE_CHECKING:
from .concrete_tracer import ConcreteTracer
import ast
import builtins
import inspect
import logging
import platform
from textwrap import dedent
from types import MethodType, FunctionType
@@ -21,6 +23,7 @@ from .utils import (
_orig_len,
_orig_dict,
_orig_zip,
_orig_tuple,
)
_logger = logging.getLogger(__name__)
@@ -235,20 +238,27 @@ class OperatorPatcher:
assert _orig_len(closures) == _orig_len(co_freevars)
closure_dict = _orig_dict(_orig_zip(co_freevars, [c.cell_contents for c in closures]))
var_dict = {}
exec(
# use func.__code__.co_filename to make the new function easily debuggable.
compile(new_tree, func_inner.__code__.co_filename, 'exec'),
{
'patch_run': OperatorPatcherContext.patch_run,
**func_inner.__globals__,
**closure_dict,
},
var_dict)
if the_self is not None:
return var_dict['new_func'].__get__(the_self)
else:
return var_dict['new_func']
tuple_wrapped = tuple
try:
if platform.python_version_tuple() < ('3', '9'):
setattr(builtins, 'tuple', _orig_tuple)
var_dict = {}
exec(
# use func.__code__.co_filename to make the new function easily debuggable.
compile(new_tree, func_inner.__code__.co_filename, 'exec'),
{
'patch_run': OperatorPatcherContext.patch_run,
**func_inner.__globals__,
**closure_dict,
},
var_dict)
if the_self is not None:
return var_dict['new_func'].__get__(the_self)
else:
return var_dict['new_func']
finally:
if platform.python_version_tuple() < ('3', '9'):
setattr(builtins, 'tuple', tuple_wrapped)
class OperatorPatcherContext:
ctx_tracer: Optional['ConcreteTracer'] = None
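On interpreters older than 3.9 the patcher temporarily restores the original builtin tuple around compile/exec and puts the wrapped one back in a finally block. A general hedged sketch of that save/patch/restore pattern (using sys.version_info for the check instead of platform):

import builtins
import sys

def exec_with_original_tuple(code_str, globals_dict, locals_dict, original_tuple):
    # temporarily swap the wrapped builtin out, run the code, then restore it
    # even if exec raises (mirrors the try/finally added in the diff)
    wrapped = builtins.tuple
    if sys.version_info < (3, 9):
        builtins.tuple = original_tuple
    try:
        exec(compile(code_str, '<patched>', 'exec'), globals_dict, locals_dict)
    finally:
        if sys.version_info < (3, 9):
            builtins.tuple = wrapped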

View file

@@ -12,6 +12,7 @@ import operator
import torch
from torch.nn import functional as F
from torch.fx.node import Node
from torch.utils._pytree import tree_flatten, tree_unflatten
from .utils import randomize_tensor_inplace, randomize_if_tensor, tree_map_zip, torch_float_dtype
@@ -398,7 +399,26 @@ class NoChangeMaskUpdater(DefaultMaskUpdater):
assert len(node.args) == 2
input_grad = tree_map_zip(lambda t, m: (t * m).type_as(t) if isinstance(m, torch.Tensor) else t, \
model_speedup.node_infos[node].output_grad, model_speedup.node_infos[node].output_masks)
model_speedup.indirect_pass_grad(node.args[0], input_grad)
arg_1_val = model_speedup.node_infos[node.args[1]].output_randomize if isinstance(node.args[1], Node) else node.args[1]
input_node_info = model_speedup.node_infos[node.args[0]]
flat_args, spec = tree_flatten(input_node_info.output_grad)
flat_grads = [None for _ in range(len(flat_args))]
flat_grads[arg_1_val] = input_grad
input_grads = tree_unflatten(flat_grads, spec)
def add_grad(grad, input_grad):
if isinstance(input_grad, torch.Tensor):
if grad is not None and input_grad is not None:
return grad + input_grad
elif grad is None:
return input_grad
else:
return grad
else:
return grad
model_speedup.node_infos[node].output_grad = tree_map_zip(add_grad, model_speedup.node_infos[node.args[0]].output_grad, input_grads)
def detect(self, model_speedup: 'ModelSpeedup', node: Node) -> bool:
return self.detect_helper(model_speedup, node) is not None
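The new path flattens the input node's gradient structure, drops the computed grad into the selected slot, rebuilds the structure, and then merges it tensor-by-tensor with add_grad. A tiny hedged demo of the torch.utils._pytree round trip it relies on:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

grads = {'a': torch.ones(2), 'b': (torch.zeros(2), torch.full((2,), 2.))}
flat, spec = tree_flatten(grads)       # [tensor, tensor, tensor], TreeSpec
flat[1] = flat[1] + 1.                 # update a single leaf, e.g. one input's grad
rebuilt = tree_unflatten(flat, spec)   # same nesting as `grads`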

View file

@@ -45,7 +45,7 @@ class DefaultReplacer(Replacer):
def replace_modules(self, speedup: 'ModelSpeedup'):
for node, node_info in speedup.node_infos.items():
if node.op == 'call_module':
if node.op == 'call_module' and not node_info.replaced:
# module = speedup.fetch_attr(node.target)
# module_type = module._get_name()
module = get_nested_attr(speedup.bound_model, node.target)

View file

@@ -366,6 +366,7 @@ class ModuleWrapper(torch.nn.Module):
for idx, target in enumerate(outputs):
target_name = f'{OUTPUT_PREFIX}{idx}'
new_outputs.append(self.patch_helper(target_name, target))
new_outputs = type(outputs)(new_outputs)
elif isinstance(outputs, dict):
new_outputs = {}
for output_name, target in outputs.items():
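The added cast keeps the container kind of the wrapped output instead of always returning a list; a one-line illustration of the idiom:

outputs = (1, 2, 3)
new_outputs = [x * 10 for x in outputs]
new_outputs = type(outputs)(new_outputs)   # stays a tuple, not a list
assert isinstance(new_outputs, tuple)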

View file

@@ -4,7 +4,6 @@
from __future__ import annotations
from collections import defaultdict
from copy import deepcopy
import logging
from typing import Any, Callable, Dict, List, overload
@@ -97,17 +96,16 @@ class TeacherModelBasedDistiller(Distiller):
target_space.lambda_ = target_space.lambda_ if target_space.lambda_ is not None else 1.
def _register_teacher_wrappers(self):
link2targets = defaultdict(set)
teacher_config_list = []
for _, ts in self._target_spaces.items():
for target_name, target_space in ts.items():
for link in target_space.link:
teacher_config_list.append({
'op_names': [link],
'target_names': [target_name],
'target_settings': {
target_name: deepcopy(target_space.setting)
}
})
link2targets[link].add(target_name)
teacher_config_list = [{
'op_names': [link],
'target_names': list(target_names)
} for link, target_names in link2targets.items()]
return register_wrappers(self.teacher_model, teacher_config_list, mode=self.mode)
def wrap_teacher_model(self):
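Grouping by link first means each teacher module gets one wrapper config listing all of its target names, rather than one config per (link, target) pair. A toy hedged illustration of the grouping (module and target names are made up):

from collections import defaultdict

pairs = [('fc1', '_output_0'), ('fc1', 'weight'), ('fc2', '_output_0')]
link2targets = defaultdict(set)
for link, target_name in pairs:
    link2targets[link].add(target_name)
teacher_config_list = [{'op_names': [link], 'target_names': sorted(targets)}
                       for link, targets in link2targets.items()]
# -> one config for 'fc1' covering both targets, one for 'fc2'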

View file

@@ -117,19 +117,15 @@ class SlimPruner(Pruner):
params = [{"params": scaling_factors}]
optimizer = Adam(params, 1e-2)
def optimizer_task():
optimizer.step()
optimizer.zero_grad()
evaluator.patch_optimizer_step(before_step_tasks=[optimizer_task], after_step_tasks=[])
evaluator.patch_optimizer_step(before_step_tasks=[optimizer.step], after_step_tasks=[optimizer.zero_grad])
def _patch_loss(self, evaluator: Evaluator):
def loss_patch(original_loss, batch):
reg_loss = 0.
reg_loss = torch.tensor(0., device=original_loss.device)
count = 0
for _, target_scaling_factor in self.scaling_factors.items():
for _, scaling_factor in target_scaling_factor.items():
reg_loss += scaling_factor.norm(p=1) # type: ignore
reg_loss = reg_loss + scaling_factor.norm(p=1) # type: ignore
count += 1
if count > 0:
reg_loss = self.regular_scale * reg_loss / count
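The regularizer is now accumulated out-of-place into a tensor created on the same device as the training loss, which avoids in-place updates on a plain Python float and CPU/GPU mismatches. A standalone hedged sketch of the same pattern (function and argument names are illustrative):

import torch

def add_l1_regularization(original_loss, scaling_factors, regular_scale=1e-4):
    # accumulate ||s||_1 over all scaling factors on the loss's device
    reg_loss = torch.tensor(0., device=original_loss.device)
    count = 0
    for factor in scaling_factors:
        reg_loss = reg_loss + factor.norm(p=1)   # out-of-place keeps the graph intact
        count += 1
    if count > 0:
        reg_loss = regular_scale * reg_loss / count
    return original_loss + reg_loss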

View file

@@ -14,7 +14,7 @@ schedules:
always: true
variables:
filter.modified.globs: 'examples/model_compress/**,nni/algorithms/compression/**,nni/compression/**,pipelines/full-test-compression.yml,test/algo/compression/**,nni/contrib/compression/**'
filter.modified.globs: 'examples/model_compress/**,nni/algorithms/compression/**,nni/compression/**,pipelines/full-test-compression.yml,test/algo/compression/**,nni/contrib/compression/**,nni/common/**'
filter.prbody.heading: '#### Test Options'
filter.prbody.optionIndex: 3

View file

@@ -41,7 +41,7 @@ class SimpleLightningModel(pl.LightningModule):
acc = accuracy(preds, y, 'multiclass', num_classes=10)
if stage:
self.log(f"default", loss, prog_bar=False)
self.log(f"default", acc, prog_bar=False)
self.log(f"{stage}_loss", loss, prog_bar=True)
self.log(f"{stage}_acc", acc, prog_bar=True)

View file

@@ -5,7 +5,12 @@ import unittest
import torch
from nni.compression.pytorch.speedup.jit_translate import parse_aten_schema_version_1_8_x, table_fix_schema, special_treat_dict
from nni.compression.pytorch.speedup.jit_translate import (
parse_aten_schema,
parse_aten_schema_version_1_8_x,
table_fix_schema,
special_treat_dict
)
def parse_aten_schema_origin(schema: str):
positional_num = 0
@@ -40,7 +45,10 @@ class SchemaParserTestCase(unittest.TestCase):
if op_with_overload in table_fix_schema:
continue
positional_num_origin, keyword_list_origin, special_treat_origin = parse_aten_schema_origin(schema)
positional_num_manual, keyword_list_manual, special_treat_manual = parse_aten_schema_version_1_8_x(schema)
if torch.__version__ < '1.9.0':
positional_num_manual, keyword_list_manual, special_treat_manual = parse_aten_schema_version_1_8_x(schema)
else:
positional_num_manual, keyword_list_manual, special_treat_manual = parse_aten_schema(schema)
assert positional_num_origin == positional_num_manual
assert keyword_list_origin == keyword_list_manual
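The test now picks the parser matching the installed torch. Comparing torch.__version__ as a plain string can misorder releases such as '1.13.0' vs '1.9.0'; a hedged alternative using packaging (assumed to be installed) would be:

from packaging.version import Version
import torch

if Version(torch.__version__.split('+')[0]) < Version('1.9.0'):
    parse = parse_aten_schema_version_1_8_x
else:
    parse = parse_aten_schema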

View file

@@ -41,7 +41,7 @@ class SimpleLightningModel(pl.LightningModule):
acc = accuracy(preds, y, 'multiclass', num_classes=10)
if stage:
self.log(f"default", loss, prog_bar=False)
self.log(f"default", acc, prog_bar=False)
self.log(f"{stage}_loss", loss, prog_bar=True)
self.log(f"{stage}_acc", acc, prog_bar=True)