Mirror of https://github.com/microsoft/nni.git
[Compression] fix tests & bugs (#5444)
This commit is contained in:
Parent
d23dec38e2
Commit
f8d85ce352
@@ -85,14 +85,23 @@ class ConcreteTracer(TracerBase):
_orig_contains: ([], False, None),
_orig_index: ([], False, None),

# force-traced function
# force-traced function (the factory functions of tensor creation)
torch.arange: ([], True, None),
torch.empty: ([], True, None),
torch.eye: ([], True, None),
torch.full: ([], True, None),
torch.linspace: ([], True, None),
torch.logspace: ([], True, None),
torch.ones: ([], True, None),
torch.rand: ([], True, None),
torch.randn: ([], True, None),
torch.randint: ([], True, None),
torch.rand_like: ([], True, None),
torch.randn_like: ([], True, None),
torch.randint_like: ([], True, None),
torch.randn: ([], True, None),
# torch.rand_like: ([], True, None), # seems that xxx_like will not directly call torch._TensorBase.xxx
# torch.randn_like: ([], True, None),
# torch.randint_like: ([], True, None),
torch.randperm: ([], True, None),
torch.tensor: ([], True, None),
torch.zeros: ([], True, None),

# method
Sequential.__getitem__: ([], False, operator.getitem),
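The factory functions listed above build tensors from plain Python arguments, so no proxy ever flows into them and a concrete tracer has to force-wrap them to see the calls at all. A minimal sketch of that idea (hypothetical names, not the NNI implementation):

# Hypothetical sketch: force-record calls to tensor factory functions that
# take no proxy arguments, while still returning the real tensor.
import torch

RECORDED_CALLS = []

def force_trace(fn):
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)                            # run the real factory
        RECORDED_CALLS.append((fn.__name__, args, kwargs))   # record the call as a graph op
        return out
    return wrapper

rand = force_trace(torch.rand)
x = rand(2, 3)   # tensor is created normally, and the call is recorded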
@@ -276,13 +285,6 @@ class ConcreteTracer(TracerBase):
similar to _symbolic_trace.Tracer.create_proxy.
use the 'run_target' to actually execute the code, and store the value in 'value' field.
"""
args_ = self.create_arg(args)
kwargs_ = self.create_arg(kwargs)
assert isinstance(args_, tuple)
assert isinstance(kwargs_, dict)

node = self.create_node(kind, target, args_, kwargs_, name, type_expr)

def upwrapper(obj: Any):
while _orig_isinstance(obj, ep.ConcreteProxy):
obj = obj.value
@@ -293,6 +295,13 @@ class ConcreteTracer(TracerBase):
# real value by execution
value_unwrapped = self.run_target(kind, target, args_unwrapped, kwargs_unwrapped)

args_ = self.create_arg(args)
kwargs_ = self.create_arg(kwargs)
assert isinstance(args_, tuple)
assert isinstance(kwargs_, dict)

node = self.create_node(kind, target, args_, kwargs_, name, type_expr)

proxy = self.proxy(value_unwrapped, node)
self.node_to_originating_module[proxy.node] = self.current_module_qualified_name
return proxy
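The two hunks above move node creation to after the concrete execution: arguments are unwrapped to their real values, the target is actually run, and only then is the graph node built and wrapped into a proxy that carries the result. A toy sketch of that execute-then-record ordering (simplified, not the NNI code):

# Toy sketch of the execute-then-record ordering used above.
class Proxy:
    def __init__(self, value, node):
        self.value, self.node = value, node

def create_proxy_sketch(graph, target, args):
    args_unwrapped = tuple(a.value if isinstance(a, Proxy) else a for a in args)
    value = target(*args_unwrapped)                    # run the real op first
    node = ('call_function', target.__name__, args)    # then record it in the graph
    graph.append(node)
    return Proxy(value, node)                          # the proxy carries the concrete result

graph = []
p = create_proxy_sketch(graph, abs, (-3,))
q = create_proxy_sketch(graph, pow, (p, 2))
print(q.value, len(graph))   # 9 2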
@@ -526,6 +535,19 @@ class ConcreteTracer(TracerBase):
such as '__main__.FooModel' or '__main__.bar_func'. the namespace is
always needed.
"""
# fill default values
args = inspect.getfullargspec(root.forward).args[1:]
defaults = inspect.getfullargspec(root.forward).defaults
defaults = tuple() if defaults is None else defaults
if isinstance(concrete_args, (tuple, list)):
concrete_args = (*concrete_args, *defaults[len(concrete_args) + len(defaults) - len(args):])
else:
kv_default = {k: v for k, v in zip(args[-len(defaults):], defaults)}
concrete_args = {
**concrete_args,
**{n: kv_default[n] for n in args if n not in concrete_args}
}

# preprocess arguments
autowrap_modules = autowrap_modules if autowrap_modules is not None else tuple()
autowrap_leaf_function = autowrap_leaf_function if autowrap_leaf_function is not None else {}
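The new block completes a partial `concrete_args` with the default values declared on `forward`. The slicing logic can be checked with a small standalone example (mirroring the hunk, not taken from the repo):

# Standalone check of the default-filling logic used above.
import inspect

def forward(self, x, y=2, z=3):
    return x + y + z

spec = inspect.getfullargspec(forward)
args = spec.args[1:]                  # ['x', 'y', 'z'] (drop self)
defaults = spec.defaults or tuple()   # (2, 3)

concrete_args = (10,)                 # caller only supplies x
# append the defaults for the parameters that were not supplied
concrete_args = (*concrete_args, *defaults[len(concrete_args) + len(defaults) - len(args):])
assert concrete_args == (10, 2, 3)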
@@ -1396,6 +1418,7 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
fx.GraphModule: a Module created from the recorded operations from ``root``.
"""
tracer = ConcreteTracer()

graph = tracer.trace(root,
autowrap_leaf_function = autowrap_leaf_function,
autowrap_leaf_class = autowrap_leaf_class,
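For context, `concrete_trace` builds a `ConcreteTracer`, traces `root` with real example inputs, and returns an `fx.GraphModule`. A minimal usage sketch follows; the import path is an assumption, check the package for the exact location:

# Hypothetical usage sketch of concrete_trace; the import path is assumed.
import torch
from nni.common.concrete_trace_utils import concrete_trace

class Foo(torch.nn.Module):
    def forward(self, x, scale=2.0):
        return x * scale

model = Foo()
# defaults such as `scale` are filled in automatically (see the hunk above)
traced = concrete_trace(model, concrete_args={'x': torch.rand(3)})
print(traced.graph)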
@@ -6,8 +6,10 @@ if TYPE_CHECKING:
from .concrete_tracer import ConcreteTracer

import ast
import builtins
import inspect
import logging
import platform

from textwrap import dedent
from types import MethodType, FunctionType
@@ -21,6 +23,7 @@ from .utils import (
_orig_len,
_orig_dict,
_orig_zip,
_orig_tuple,
)

_logger = logging.getLogger(__name__)
@@ -235,20 +238,27 @@ class OperatorPatcher:
assert _orig_len(closures) == _orig_len(co_freevars)
closure_dict = _orig_dict(_orig_zip(co_freevars, [c.cell_contents for c in closures]))

var_dict = {}
exec(
# use func.__code__.co_filename to make the new function easily debuggable.
compile(new_tree, func_inner.__code__.co_filename, 'exec'),
{
'patch_run': OperatorPatcherContext.patch_run,
**func_inner.__globals__,
**closure_dict,
},
var_dict)
if the_self is not None:
return var_dict['new_func'].__get__(the_self)
else:
return var_dict['new_func']
tuple_wrapped = tuple
try:
if platform.python_version_tuple() < ('3', '9'):
setattr(builtins, 'tuple', _orig_tuple)
var_dict = {}
exec(
# use func.__code__.co_filename to make the new function easily debuggable.
compile(new_tree, func_inner.__code__.co_filename, 'exec'),
{
'patch_run': OperatorPatcherContext.patch_run,
**func_inner.__globals__,
**closure_dict,
},
var_dict)
if the_self is not None:
return var_dict['new_func'].__get__(the_self)
else:
return var_dict['new_func']
finally:
if platform.python_version_tuple() < ('3', '9'):
setattr(builtins, 'tuple', tuple_wrapped)

class OperatorPatcherContext:
ctx_tracer: Optional['ConcreteTracer'] = None
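The new version wraps the `exec` in `try/finally` so that, on Python < 3.9, the original built-in `tuple` is swapped in only while the generated code is compiled and executed, and the traced wrapper is always restored afterwards, even on error. The general pattern as a self-contained sketch (generic names, not the NNI code):

# Generic sketch of temporarily restoring a patched builtin around exec().
import builtins

def exec_with_original_builtin(name, original, code, env):
    wrapped = getattr(builtins, name)        # whatever wrapper is currently installed
    try:
        setattr(builtins, name, original)    # expose the original during exec
        namespace = {}
        exec(code, env, namespace)
        return namespace
    finally:
        setattr(builtins, name, wrapped)     # always put the wrapper back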
@@ -12,6 +12,7 @@ import operator
import torch
from torch.nn import functional as F
from torch.fx.node import Node
from torch.utils._pytree import tree_flatten, tree_unflatten

from .utils import randomize_tensor_inplace, randomize_if_tensor, tree_map_zip, torch_float_dtype
@@ -398,7 +399,26 @@ class NoChangeMaskUpdater(DefaultMaskUpdater):
assert len(node.args) == 2
input_grad = tree_map_zip(lambda t, m: (t * m).type_as(t) if isinstance(m, torch.Tensor) else t, \
model_speedup.node_infos[node].output_grad, model_speedup.node_infos[node].output_masks)
model_speedup.indirect_pass_grad(node.args[0], input_grad)
arg_1_val = model_speedup.node_infos[node.args[1]].output_randomize if isinstance(node.args[1], Node) else node.args[1]

input_node_info = model_speedup.node_infos[node.args[0]]
flat_args, spec = tree_flatten(input_node_info.output_grad)
flat_grads = [None for _ in range(len(flat_args))]
flat_grads[arg_1_val] = input_grad
input_grads = tree_unflatten(flat_grads, spec)

def add_grad(grad, input_grad):
if isinstance(input_grad, torch.Tensor):
if grad is not None and input_grad is not None:
return grad + input_grad
elif grad is None:
return input_grad
else:
return grad
else:
return grad

model_speedup.node_infos[node].output_grad = tree_map_zip(add_grad, model_speedup.node_infos[node.args[0]].output_grad, input_grads)

def detect(self, model_speedup: 'ModelSpeedup', node: Node) -> bool:
return self.detect_helper(model_speedup, node) is not None
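The new logic flattens the upstream node's gradient structure, drops the masked gradient into the single slot selected by `arg_1_val`, rebuilds the structure, and then accumulates it with `add_grad`. The pytree half of that can be checked in isolation (a standalone sketch, not the repo's code):

# Standalone sketch: place a gradient into one slot of a nested structure.
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

output_grad = {'a': torch.zeros(2), 'b': (torch.zeros(3), torch.zeros(4))}
flat, spec = tree_flatten(output_grad)        # leaves in a fixed order
slots = [None] * len(flat)
slots[1] = torch.ones(3)                      # gradient only for leaf index 1
partial_grads = tree_unflatten(slots, spec)   # {'a': None, 'b': (tensor, None)}

# accumulate, treating None as "no contribution"
def add_grad(grad, new):
    if isinstance(new, torch.Tensor):
        return new if grad is None else grad + new
    return grad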
@@ -45,7 +45,7 @@ class DefaultReplacer(Replacer):

def replace_modules(self, speedup: 'ModelSpeedup'):
for node, node_info in speedup.node_infos.items():
if node.op == 'call_module':
if node.op == 'call_module' and not node_info.replaced:
# module = speedup.fetch_attr(node.target)
# module_type = module._get_name()
module = get_nested_attr(speedup.bound_model, node.target)
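The extra `not node_info.replaced` condition skips modules that were already handled, so a submodule reached through several `call_module` nodes is not replaced twice. The guard pattern in isolation (an illustrative sketch only):

# Illustrative sketch of the "replace once" guard for shared submodules.
def replace_modules_once(targets, do_replace):
    replaced = set()
    for target in targets:           # e.g. FX call_module targets like 'layer1.conv'
        if target in replaced:
            continue                 # an earlier pass already replaced it
        do_replace(target)
        replaced.add(target)

replace_modules_once(['conv1', 'conv1', 'fc'], print)   # prints conv1, fc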
@@ -366,6 +366,7 @@ class ModuleWrapper(torch.nn.Module):
for idx, target in enumerate(outputs):
target_name = f'{OUTPUT_PREFIX}{idx}'
new_outputs.append(self.patch_helper(target_name, target))
new_outputs = type(outputs)(new_outputs)
elif isinstance(outputs, dict):
new_outputs = {}
for output_name, target in outputs.items():
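The added `type(outputs)(new_outputs)` line rebuilds the patched outputs with the same container type as the original, so a module that returns a tuple keeps returning a tuple rather than a list. A tiny illustration:

# Tiny illustration: rebuild a processed sequence with the caller's container type.
def patch_all(outputs, patch):
    new_outputs = [patch(o) for o in outputs]
    return type(outputs)(new_outputs)    # list stays list, tuple stays tuple

assert patch_all((1, 2, 3), lambda x: x * 10) == (10, 20, 30)
assert patch_all([1, 2, 3], lambda x: x * 10) == [10, 20, 30]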
@@ -4,7 +4,6 @@
from __future__ import annotations

from collections import defaultdict
from copy import deepcopy
import logging
from typing import Any, Callable, Dict, List, overload
@@ -97,17 +96,16 @@ class TeacherModelBasedDistiller(Distiller):
target_space.lambda_ = target_space.lambda_ if target_space.lambda_ is not None else 1.

def _register_teacher_wrappers(self):
link2targets = defaultdict(set)
teacher_config_list = []
for _, ts in self._target_spaces.items():
for target_name, target_space in ts.items():
for link in target_space.link:
teacher_config_list.append({
'op_names': [link],
'target_names': [target_name],
'target_settings': {
target_name: deepcopy(target_space.setting)
}
})
link2targets[link].add(target_name)
teacher_config_list = [{
'op_names': [link],
'target_names': list(target_names)
} for link, target_names in link2targets.items()]
return register_wrappers(self.teacher_model, teacher_config_list, mode=self.mode)

def wrap_teacher_model(self):
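Instead of appending one config entry per (link, target) pair, the new code first groups target names by link in a `defaultdict(set)` and then emits one entry per link, which deduplicates targets that share a link. The grouping step on its own (standalone sketch):

# Standalone sketch of grouping target names per link before building configs.
from collections import defaultdict

pairs = [('teacher.fc', 'weight'), ('teacher.fc', 'bias'), ('teacher.fc', 'weight')]

link2targets = defaultdict(set)
for link, target_name in pairs:
    link2targets[link].add(target_name)          # set() drops duplicates

teacher_config_list = [
    {'op_names': [link], 'target_names': sorted(target_names)}
    for link, target_names in link2targets.items()
]
# -> [{'op_names': ['teacher.fc'], 'target_names': ['bias', 'weight']}]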
@@ -117,19 +117,15 @@ class SlimPruner(Pruner):
params = [{"params": scaling_factors}]
optimizer = Adam(params, 1e-2)

def optimizer_task():
optimizer.step()
optimizer.zero_grad()

evaluator.patch_optimizer_step(before_step_tasks=[optimizer_task], after_step_tasks=[])
evaluator.patch_optimizer_step(before_step_tasks=[optimizer.step], after_step_tasks=[optimizer.zero_grad])

def _patch_loss(self, evaluator: Evaluator):
def loss_patch(original_loss, batch):
reg_loss = 0.
reg_loss = torch.tensor(0., device=original_loss.device)
count = 0
for _, target_scaling_factor in self.scaling_factors.items():
for _, scaling_factor in target_scaling_factor.items():
reg_loss += scaling_factor.norm(p=1) # type: ignore
reg_loss = reg_loss + scaling_factor.norm(p=1) # type: ignore
count += 1
if count > 0:
reg_loss = self.regular_scale * reg_loss / count
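Two fixes are visible here: the scaling-factor optimizer is now stepped before and zeroed after the user's step via `patch_optimizer_step`, and the L1 regularization term starts from a tensor on the same device as the loss and is accumulated out of place rather than in place. The regularization pattern on its own (standalone sketch, not the repo's code):

# Standalone sketch of the L1 regularization added to the training loss.
import torch

def add_l1_regularization(original_loss, scaling_factors, regular_scale=1e-4):
    # start from a tensor on the loss's device instead of a Python float
    reg_loss = torch.tensor(0., device=original_loss.device)
    count = 0
    for factor in scaling_factors:
        reg_loss = reg_loss + factor.norm(p=1)   # out-of-place accumulation
        count += 1
    if count > 0:
        reg_loss = regular_scale * reg_loss / count
    return original_loss + reg_loss

loss = torch.tensor(1.0, requires_grad=True)
factors = [torch.ones(4, requires_grad=True)]
total = add_l1_regularization(loss, factors)
total.backward()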
@@ -14,7 +14,7 @@ schedules:
always: true

variables:
filter.modified.globs: 'examples/model_compress/**,nni/algorithms/compression/**,nni/compression/**,pipelines/full-test-compression.yml,test/algo/compression/**,nni/contrib/compression/**'
filter.modified.globs: 'examples/model_compress/**,nni/algorithms/compression/**,nni/compression/**,pipelines/full-test-compression.yml,test/algo/compression/**,nni/contrib/compression/**,nni/common/**'
filter.prbody.heading: '#### Test Options'
filter.prbody.optionIndex: 3
@@ -41,7 +41,7 @@ class SimpleLightningModel(pl.LightningModule):
acc = accuracy(preds, y, 'multiclass', num_classes=10)

if stage:
self.log(f"default", loss, prog_bar=False)
self.log(f"default", acc, prog_bar=False)
self.log(f"{stage}_loss", loss, prog_bar=True)
self.log(f"{stage}_acc", acc, prog_bar=True)
@@ -5,7 +5,12 @@ import unittest

import torch

from nni.compression.pytorch.speedup.jit_translate import parse_aten_schema_version_1_8_x, table_fix_schema, special_treat_dict
from nni.compression.pytorch.speedup.jit_translate import (
parse_aten_schema,
parse_aten_schema_version_1_8_x,
table_fix_schema,
special_treat_dict
)

def parse_aten_schema_origin(schema: str):
positional_num = 0
@@ -40,7 +45,10 @@ class SchemaParserTestCase(unittest.TestCase):
if op_with_overload in table_fix_schema:
continue
positional_num_origin, keyword_list_origin, special_treat_origin = parse_aten_schema_origin(schema)
positional_num_manual, keyword_list_manual, special_treat_manual = parse_aten_schema_version_1_8_x(schema)
if torch.__version__ < '1.9.0':
positional_num_manual, keyword_list_manual, special_treat_manual = parse_aten_schema_version_1_8_x(schema)
else:
positional_num_manual, keyword_list_manual, special_treat_manual = parse_aten_schema(schema)

assert positional_num_origin == positional_num_manual
assert keyword_list_origin == keyword_list_manual
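The test now dispatches on the installed torch version, using the 1.8.x parser only for older releases. Note that comparing version strings lexicographically can misorder multi-digit components ('1.10.0' sorts before '1.9.0'); a common, more robust alternative is sketched below with the packaging library (an assumption, this is not what the test uses):

# Sketch of version-gated dispatch with proper version parsing (assumed alternative).
import torch
from packaging.version import Version

def pick_schema_parser(parse_old, parse_new):
    # Version('1.10.0') > Version('1.9.0') compares numerically, not as strings
    if Version(torch.__version__.split('+')[0]) < Version('1.9.0'):
        return parse_old
    return parse_new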
@@ -41,7 +41,7 @@ class SimpleLightningModel(pl.LightningModule):
acc = accuracy(preds, y, 'multiclass', num_classes=10)

if stage:
self.log(f"default", loss, prog_bar=False)
self.log(f"default", acc, prog_bar=False)
self.log(f"{stage}_loss", loss, prog_bar=True)
self.log(f"{stage}_acc", acc, prog_bar=True)