Mirror of https://github.com/microsoft/nni.git
[Compression] fix typehints (#4800)
This commit is contained in:
Parent: d49864ce28
Commit: cbac2c5c0f
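Note on the recurring pattern in this commit: parameters and attributes that may legitimately be unset are re-annotated as Optional (or Union), default sentinels are adjusted to match, and call sites narrow the value with an assert before it is used, so a static checker such as pyright accepts the access. A minimal illustrative sketch of that pattern (the class and names below are hypothetical, not from the NNI code base):

# --- illustrative sketch, not part of the commit ---
from typing import Dict, List, Optional

class CompressorSketch:
    def __init__(self, model: Optional[object] = None, config_list: Optional[List[Dict]] = None):
        # attributes that can start out unset are declared Optional, not bare types
        self.bound_model: Optional[object] = model
        self.config_list: Optional[List[Dict]] = config_list

    def compress(self) -> object:
        # narrow the Optional before use so the type checker (and the reader) knows it is set
        assert self.bound_model is not None, 'reset() must be called before compress().'
        return self.bound_model
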
@@ -1,9 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import collections
import logging
from typing import List, Dict, Optional, Tuple, Any
from typing import Any, List, Dict, Optional, Tuple

import torch
from torch.nn import Module

@@ -29,7 +28,33 @@ def _setattr(model: Module, name: str, module: Module):
        name_list = name.split(".")
        setattr(parent_module, name_list[-1], module)
    else:
        raise '{} not exist.'.format(name)
        raise Exception('{} not exist.'.format(name))


class ModuleWrapper(Module):
    """
    Wrap a module to enable data parallel, forward method customization and buffer registration.

    Parameters
    ----------
    module
        The module user wants to compress.
    config
        The configurations that users specify for compression.
    module_name
        The name of the module to compress, wrapper module shares the same name.
    """

    def __init__(self, module: Module, module_name: str, config: Dict):
        super().__init__()
        # origin layer information
        self.module = module
        self.name = module_name
        # config information
        self.config = config

    def forward(self, *inputs):
        raise NotImplementedError


class Compressor:

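The `raise '{} not exist.'` fix above is more than cosmetic: in Python 3, raising an object that does not derive from BaseException is itself a TypeError, so the intended message never reaches the caller. Wrapping the string in Exception, as the hunk does, is the minimal correct form. A small sketch of the behaviour:

# --- illustrative sketch, not part of the commit ---
try:
    raise 'layer1 not exist.'  # TypeError: exceptions must derive from BaseException
except TypeError as err:
    print(type(err), err)

try:
    raise Exception('{} not exist.'.format('layer1'))  # a real exception carrying the message
except Exception as err:
    print(type(err), err)
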
@@ -46,7 +71,7 @@ class Compressor:

    def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]]):
        self.is_wrapped = False
        if model is not None:
        if model is not None and config_list is not None:
            self.reset(model=model, config_list=config_list)
        else:
            _logger.warning('This compressor is not set model and config_list, waiting for reset() or pass this to scheduler.')

@@ -63,6 +88,7 @@ class Compressor:
            The config list used by compressor, usually specifies the 'op_types' or 'op_names' that want to compress.
        """
        assert isinstance(model, Module), 'Only support compressing pytorch Module, but the type of model is {}.'.format(type(model))

        self.bound_model = model
        self.config_list = config_list
        self.validate_config(model=model, config_list=config_list)

@@ -70,7 +96,7 @@ class Compressor:
        self._unwrap_model()

        self._modules_to_compress = None
        self.modules_wrapper = collections.OrderedDict()
        self.modules_wrapper = {}
        for layer, config in self._detect_modules_to_compress():
            wrapper = self._wrap_modules(layer, config)
            self.modules_wrapper[layer.name] = wrapper

@@ -93,6 +119,8 @@ class Compressor:
        Detect all modules should be compressed, and save the result in `self._modules_to_compress`.
        The model will be instrumented and user should never edit it after calling this method.
        """
        assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'

        if self._modules_to_compress is None:
            self._modules_to_compress = []
            for name, module in self.bound_model.named_modules():

@@ -118,6 +146,8 @@ class Compressor:
        Optional[Dict]
            The retrieved configuration for this layer, if None, this layer should not be compressed.
        """
        assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'

        ret = None
        for config in self.config_list:
            config = config.copy()

@@ -142,32 +172,26 @@ class Compressor:
                return None
        return ret

    def get_modules_wrapper(self) -> Dict[str, Module]:
    def get_modules_wrapper(self) -> Dict[str, ModuleWrapper]:
        """
        Returns
        -------
        OrderedDict[str, Module]
            An ordered dict, key is the name of the module, value is the wrapper of the module.
        Dict[str, ModuleWrapper]
            A dict, key is the name of the module, value is the wrapper of the module.
        """
        return self.modules_wrapper
        raise NotImplementedError

    def _wrap_model(self):
        """
        Wrap all modules that needed to be compressed.
        """
        if not self.is_wrapped:
            for _, wrapper in reversed(self.get_modules_wrapper().items()):
                _setattr(self.bound_model, wrapper.name, wrapper)
            self.is_wrapped = True
        raise NotImplementedError

    def _unwrap_model(self):
        """
        Unwrap all modules that needed to be compressed.
        """
        if self.is_wrapped:
            for _, wrapper in self.get_modules_wrapper().items():
                _setattr(self.bound_model, wrapper.name, wrapper.module)
            self.is_wrapped = False
        raise NotImplementedError

    def set_wrappers_attribute(self, name: str, value: Any):
        """

@@ -182,7 +206,7 @@ class Compressor:
        value
            Value of the variable.
        """
        for wrapper in self.get_modules_wrapper():
        for wrapper in self.get_modules_wrapper().values():
            if isinstance(value, torch.Tensor):
                wrapper.register_buffer(name, value.clone())
            else:

@@ -216,8 +240,10 @@ class Compressor:
        Dict[int, List[str]]
            A dict. The key is the config idx in config_list, the value is the module name list. i.e., {1: ['layer.0', 'layer.2']}.
        """
        self._unwrap_model()
        assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
        assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'

        self._unwrap_model()
        module_groups = {}
        for name, module in self.bound_model.named_modules():
            if module == self.bound_model:

@@ -259,7 +285,7 @@ class Compressor:
        """
        raise NotImplementedError()

    def _wrap_modules(self, layer: LayerInfo, config: Dict):
    def _wrap_modules(self, layer: LayerInfo, config: Dict) -> ModuleWrapper:
        """
        This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`

@@ -297,4 +323,6 @@ class Compressor:
        torch.nn.Module
            model with specified modules compressed.
        """
        assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
        assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
        return self.bound_model

@@ -2,11 +2,12 @@
# Licensed under the MIT license.

import logging
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, OrderedDict

import torch
from torch import Tensor
from torch.nn import Module, Parameter
from torch.nn import Module
from torch.nn.parameter import Parameter

from .compressor import Compressor, LayerInfo, _setattr

@@ -37,15 +38,15 @@ class PrunerModuleWrapper(Module):
        # config information
        self.config = config

        self.weight = Parameter(torch.empty(self.module.weight.size()))

        # register buffer for mask
        self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
        if hasattr(self.module, 'bias') and self.module.bias is not None:
            self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
            self.bias = Parameter(torch.empty(self.module.bias.size()))
        else:
            self.register_buffer("bias_mask", None)
        pruning_target_names = ['weight', 'bias']
        for pruning_target_name in pruning_target_names:
            pruning_target_mask_name = '{}_mask'.format(pruning_target_name)
            pruning_target = getattr(self.module, pruning_target_name, None)
            if hasattr(self.module, pruning_target_name) and pruning_target is not None:
                setattr(self, pruning_target_name, Parameter(torch.empty(pruning_target.shape)))
                self.register_buffer(pruning_target_mask_name, torch.ones(pruning_target.shape))
            else:
                self.register_buffer(pruning_target_mask_name, None)

    def _weight2buffer(self):
        """

@@ -89,7 +90,17 @@ class Pruner(Compressor):
    def reset(self, model: Optional[Module] = None, config_list: Optional[List[Dict]] = None):
        super().reset(model=model, config_list=config_list)

    def _wrap_modules(self, layer: LayerInfo, config: Dict):
    def get_modules_wrapper(self) -> OrderedDict[str, PrunerModuleWrapper]:
        """
        Returns
        -------
        OrderedDict[str, PrunerModuleWrapper]
            An ordered dict, key is the name of the module, value is the wrapper of the module.
        """
        assert self.modules_wrapper is not None, 'Bound model has not be wrapped.'
        return self.modules_wrapper

    def _wrap_modules(self, layer: LayerInfo, config: Dict) -> PrunerModuleWrapper:
        """
        Create a wrapper module to replace the original one.

@@ -99,6 +110,11 @@ class Pruner(Compressor):
            The layer to instrument the mask.
        config
            The configuration for generating the mask.

        Returns
        -------
        PrunerModuleWrapper
            The wrapper of the module in layerinfo.
        """
        _logger.debug("Module detected to compress : %s.", layer.name)
        wrapper = PrunerModuleWrapper(layer.module, layer.name, config)

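The `Pruner.get_modules_wrapper` override above is annotated with `typing.OrderedDict`, which, unlike `collections.OrderedDict` on Python 3.7/3.8, can be subscripted inside annotations (`OrderedDict[str, PrunerModuleWrapper]`). A hedged sketch of the same pattern with made-up names:

# --- illustrative sketch, not part of the commit ---
from collections import OrderedDict
from typing import OrderedDict as OrderedDictType  # subscriptable alias, Python 3.7.2+

class DummyWrapper:
    pass

def get_wrappers() -> OrderedDictType[str, DummyWrapper]:
    # the runtime object is still a collections.OrderedDict
    wrappers: OrderedDictType[str, DummyWrapper] = OrderedDict()
    wrappers['conv1'] = DummyWrapper()
    return wrappers
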
@@ -114,8 +130,10 @@ class Pruner(Compressor):
        Wrap all modules that needed to be compressed.
        Different from the parent function, call `wrapper._weight2buffer()` after replace the origin module to wrapper.
        """
        assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'

        if not self.is_wrapped:
            for _, wrapper in reversed(self.get_modules_wrapper().items()):
            for _, wrapper in reversed(list(self.get_modules_wrapper().items())):
                _setattr(self.bound_model, wrapper.name, wrapper)
                wrapper._weight2buffer()
            self.is_wrapped = True

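The `reversed(list(...))` change above matters because built-in dict views only gained `__reversed__` in Python 3.8, and the base class now stores `modules_wrapper` as a plain dict (see the `self.modules_wrapper = {}` hunk earlier). Materializing the view into a list first keeps the iteration working on every supported Python version. A tiny sketch:

# --- illustrative sketch, not part of the commit ---
wrappers = {'layer1': 'w1', 'layer2': 'w2'}  # stand-in for get_modules_wrapper()

# reversed(wrappers.items()) raises TypeError on Python < 3.8;
# wrapping the view in list() works everywhere.
for name, wrapper in reversed(list(wrappers.items())):
    print(name, wrapper)
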
@@ -125,8 +143,10 @@ class Pruner(Compressor):
        Unwrap all modules that needed to be compressed.
        Different from the parent function, call `wrapper._weight2parameter()` after replace the wrapper to origin module.
        """
        assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'

        if self.is_wrapped:
            for _, wrapper in self.get_modules_wrapper().items():
            for wrapper in self.get_modules_wrapper().values():
                _setattr(self.bound_model, wrapper.name, wrapper.module)
                wrapper._weight2parameter()
            self.is_wrapped = False

@@ -191,7 +211,7 @@ class Pruner(Compressor):
        dim
            The pruned dim.
        """
        for _, wrapper in self.get_modules_wrapper().items():
        for wrapper in self.get_modules_wrapper().values():
            weight_mask = wrapper.weight_mask
            mask_size = weight_mask.size()
            if len(mask_size) == 1:

@@ -5,7 +5,7 @@ import gc
import logging
import os
from pathlib import Path
from typing import List, Dict, Tuple, Optional
from typing import List, Dict, Tuple, Optional, Union

import json_tricks
import torch

@@ -19,7 +19,7 @@ class Task:
    # NOTE: If we want to support multi-thread, this part need to refactor, maybe use file and lock to sync.
    _reference_counter = {}

    def __init__(self, task_id: int, model_path: str, masks_path: str, config_list_path: str,
    def __init__(self, task_id: int, model_path: Union[str, Path], masks_path: Union[str, Path], config_list_path: Union[str, Path],
                 speedup: Optional[bool] = True, finetune: Optional[bool] = True, evaluate: Optional[bool] = True):
        """
        Parameters

@@ -87,7 +87,7 @@ class Task:
            config_list = json_tricks.load(f)
        return model, masks, config_list

    def referenced_paths(self) -> List[str]:
    def referenced_paths(self) -> List[Union[str, Path]]:
        """
        Return the path list that need to count reference in this task.
        """

@@ -111,7 +111,7 @@ class Task:


class TaskResult:
    def __init__(self, task_id: int, compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]],
    def __init__(self, task_id: Union[int, str], compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]],
                 pruner_generated_masks: Dict[str, Dict[str, Tensor]], score: Optional[float]) -> None:
        """
        Parameters

@@ -82,12 +82,13 @@ class AMCTaskGenerator(TaskGenerator):

    def generate_tasks(self, task_result: TaskResult) -> List[Task]:
        # append experience & update agent policy
        if task_result.task_id != 'origin':
            if self.action is not None:
                action, reward, observation, done = self.env.step(self.action, task_result.compact_model)
                self.T.append([reward, self.observation, observation, self.action, done])
                self.observation = observation.copy()

                if done:
                    assert task_result.score is not None, 'task_result.score should not be None if environment is done.'
                    final_reward = task_result.score - 1
                    # agent observe and update policy
                    for _, s_t, s_t1, a_t, d_t in self.T:

@@ -46,7 +46,9 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
    def allocate_sparsity(self, new_config_list: List[Dict], model: Module, masks: Dict[str, Dict[str, Tensor]]):
        self._iterative_pruner_reset(model, new_config_list, masks)
        self.iterative_pruner.compress()
        _, _, _, _, config_list = self.iterative_pruner.get_best_result()
        best_result = self.iterative_pruner.get_best_result()
        assert best_result is not None, 'Best result does not exist, iterative pruner may not start pruning.'
        _, _, _, _, config_list = best_result
        return config_list

@@ -149,7 +151,7 @@ class AutoCompressPruner(IterativePruner):
    def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
                 sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False,
                 dummy_input: Optional[Tensor] = None, evaluator: Callable[[Module], float] = None):
                 dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None):
        task_generator = AutoCompressTaskGenerator(total_iteration=total_iteration,
                                                   origin_model=model,
                                                   origin_config_list=config_list,

@@ -8,7 +8,7 @@ from typing import List, Dict, Tuple, Callable, Optional
from schema import And, Or, Optional as SchemaOptional, SchemaError
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module
from torch.optim import Optimizer

@@ -77,10 +77,10 @@ INTERNAL_SCHEMA = {


class BasicPruner(Pruner):
    def __init__(self, model: Module, config_list: List[Dict]):
        self.data_collector: DataCollector = None
        self.metrics_calculator: MetricsCalculator = None
        self.sparsity_allocator: SparsityAllocator = None
    def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]]):
        self.data_collector: Optional[DataCollector] = None
        self.metrics_calculator: Optional[MetricsCalculator] = None
        self.sparsity_allocator: Optional[SparsityAllocator] = None

        super().__init__(model, config_list)

@@ -114,6 +114,8 @@ class BasicPruner(Pruner):
        Tuple[Module, Dict]
            Return the wrapped model and mask.
        """
        assert self.bound_model is not None and self.config_list is not None, 'Model and/or config_list are not set in this pruner, please set them by reset() before compress().'
        assert self.data_collector is not None and self.metrics_calculator is not None and self.sparsity_allocator is not None
        data = self.data_collector.collect()
        _logger.debug('Collected Data:\n%s', data)
        metrics = self.metrics_calculator.calculate_metrics(data)

@@ -553,8 +555,8 @@ class SlimPruner(BasicPruner):
    def criterion_patch(self, criterion: Callable[[Tensor, Tensor], Tensor]) -> Callable[[Tensor, Tensor], Tensor]:
        def patched_criterion(input_tensor: Tensor, target: Tensor):
            sum_l1 = 0
            for _, wrapper in self.get_modules_wrapper().items():
                sum_l1 += torch.norm(wrapper.module.weight, p=1)
            for wrapper in self.get_modules_wrapper().values():
                sum_l1 += torch.norm(wrapper.module.weight, p=1)  # type: ignore
            return criterion(input_tensor, target) + self._scale * sum_l1
        return patched_criterion

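The SlimPruner hunk above patches the training criterion so that every wrapped layer's weight contributes an L1 penalty scaled by `self._scale`, letting the sparsity regularization ride along with the ordinary task loss. A self-contained sketch of the same idea (names and the scale value are illustrative, not NNI API):

# --- illustrative sketch, not part of the commit ---
import torch
import torch.nn.functional as F

def make_l1_patched_criterion(criterion, weights, scale=1e-4):
    # Wrap an existing loss so it also penalizes the L1 norm of the given weights.
    def patched_criterion(output, target):
        l1 = sum(w.abs().sum() for w in weights)
        return criterion(output, target) + scale * l1
    return patched_criterion

weights = [torch.randn(8, 4, requires_grad=True)]
criterion = make_l1_patched_criterion(F.mse_loss, weights)
loss = criterion(torch.randn(2, 4), torch.randn(2, 4))
loss.backward()  # gradients now include the sparsity penalty on `weights`
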
@@ -654,11 +656,11 @@ class ActivationPruner(BasicPruner):

    def _choose_activation(self, activation: str = 'relu') -> Callable:
        if activation == 'relu':
            return nn.functional.relu
            return F.relu
        elif activation == 'relu6':
            return nn.functional.relu6
            return F.relu6
        else:
            raise 'Unsupported activatoin {}'.format(activation)
            raise Exception('Unsupported activatoin {}'.format(activation))

    def _collector(self, buffer: List) -> Callable[[Module, Tensor, Tensor], None]:
        assert len(buffer) == 0, 'Buffer pass to activation pruner collector is not empty.'

@@ -684,7 +686,7 @@ class ActivationPruner(BasicPruner):
            self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
                                                                      1, collector_infos=[collector_info])
        else:
            self.data_collector.reset(collector_infos=[collector_info])
            self.data_collector.reset(collector_infos=[collector_info])  # type: ignore
        if self.metrics_calculator is None:
            self.metrics_calculator = self._get_metrics_calculator()
        if self.sparsity_allocator is None:

@@ -999,13 +1001,13 @@ class TaylorFOWeightPruner(BasicPruner):
        return (weight_tensor.detach() * grad.detach()).data.pow(2)

    def reset_tools(self):
        hook_targets = {name: wrapper.weight for name, wrapper in self.get_modules_wrapper().items()}
        collector_info = HookCollectorInfo(hook_targets, 'tensor', self._collector)
        hook_targets = {name: wrapper.weight for name, wrapper in self.get_modules_wrapper().items()}  # type: ignore
        collector_info = HookCollectorInfo(hook_targets, 'tensor', self._collector)  # type: ignore
        if self.data_collector is None:
            self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
                                                                      1, collector_infos=[collector_info])
        else:
            self.data_collector.reset(collector_infos=[collector_info])
            self.data_collector.reset(collector_infos=[collector_info])  # type: ignore
        if self.metrics_calculator is None:
            self.metrics_calculator = MultiDataNormMetricsCalculator(p=1, dim=0)
        if self.sparsity_allocator is None:

@@ -1095,24 +1097,26 @@ class ADMMPruner(BasicPruner):
    For detailed example please refer to :githublink:`examples/model_compress/pruning/admm_pruning_torch.py <examples/model_compress/pruning/admm_pruning_torch.py>`
    """

    def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
    def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]], trainer: Callable[[Module, Optimizer, Callable], None],
                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], iterations: int,
                 training_epochs: int, granularity: str = 'fine-grained'):
        self.trainer = trainer
        if isinstance(traced_optimizer, OptimizerConstructHelper):
            self.optimizer_helper = traced_optimizer
        else:
            assert model is not None, 'Model is required if traced_optimizer is provided.'
            self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
        self.criterion = criterion
        self.iterations = iterations
        self.training_epochs = training_epochs
        assert granularity in ['fine-grained', 'coarse-grained']
        self.granularity = granularity
        self.Z, self.U = {}, {}
        super().__init__(model, config_list)

    def reset(self, model: Optional[Module], config_list: Optional[List[Dict]]):
    def reset(self, model: Module, config_list: List[Dict]):
        super().reset(model, config_list)
        self.Z = {name: wrapper.module.weight.data.clone().detach() for name, wrapper in self.get_modules_wrapper().items()}
        self.Z = {name: wrapper.module.weight.data.clone().detach() for name, wrapper in self.get_modules_wrapper().items()}  # type: ignore
        self.U = {name: torch.zeros_like(z).to(z.device) for name, z in self.Z.items()}

    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):

@@ -1156,6 +1160,8 @@ class ADMMPruner(BasicPruner):
        Tuple[Module, Dict]
            Return the wrapped model and mask.
        """
        assert self.bound_model is not None
        assert self.data_collector is not None and self.metrics_calculator is not None and self.sparsity_allocator is not None
        for i in range(self.iterations):
            _logger.info('======= ADMM Iteration %d Start =======', i)
            data = self.data_collector.collect()

@@ -1169,11 +1175,10 @@ class ADMMPruner(BasicPruner):
                self.Z[name] = self.Z[name].mul(mask['weight'])
                self.U[name] = self.U[name] + data[name] - self.Z[name]

        self.Z = None
        self.U = None
        self.Z, self.U = {}, {}
        torch.cuda.empty_cache()

        metrics = self.metrics_calculator.calculate_metrics(data)
        metrics = self.metrics_calculator.calculate_metrics(data)  # type: ignore
        masks = self.sparsity_allocator.generate_sparsity(metrics)

        self.load_masks(masks)

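Resetting the ADMM auxiliary variables to empty dicts instead of None, as the hunk above does, keeps the attributes' declared type (a dict of tensors) stable for the type checker while still dropping the references to the old tensors; the following torch.cuda.empty_cache() call can then hand that memory back to the allocator. A small hedged sketch of the same pattern (the class below is illustrative only):

# --- illustrative sketch, not part of the commit ---
from typing import Dict
import torch

class AdmmStateSketch:
    def __init__(self) -> None:
        # Z holds the auxiliary variables, U the scaled dual variables.
        self.Z: Dict[str, torch.Tensor] = {}
        self.U: Dict[str, torch.Tensor] = {}

    def release(self) -> None:
        # Re-binding to empty dicts drops the tensor references without
        # turning the attributes into Optional/None.
        self.Z, self.U = {}, {}
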
@@ -2,7 +2,7 @@
# Licensed under the MIT license.

from copy import deepcopy
from typing import Dict, List, Tuple, Callable, Optional
from typing import Dict, List, Tuple, Callable, Optional, Union

import torch
from torch import Tensor

@@ -36,8 +36,8 @@ class PruningScheduler(BasePruningScheduler):
    reset_weight
        If set True, the model weight will reset to the origin model weight at the end of each iteration step.
    """
    def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Callable[[Module], None] = None,
                 speedup: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None,
    def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Optional[Callable[[Module], None]] = None,
                 speedup: bool = False, dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None,
                 reset_weight: bool = False):
        self.pruner = pruner
        self.task_generator = task_generator

@@ -155,5 +155,5 @@ class PruningScheduler(BasePruningScheduler):
        torch.cuda.empty_cache()
        return result

    def get_best_result(self) -> Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]:
    def get_best_result(self) -> Optional[Tuple[Union[int, str], Module, Dict[str, Dict[str, Tensor]], Optional[float], List[Dict]]]:
        return self.task_generator.get_best_result()

@@ -2,7 +2,8 @@
# Licensed under the MIT license.

import logging
from typing import Dict, List, Callable, Optional
from pathlib import Path
from typing import Dict, List, Callable, Optional, Union

from torch import Tensor
from torch.nn import Module

@@ -293,9 +294,9 @@ class SimulatedAnnealingPruner(IterativePruner):

    Parameters
    ----------
    model : Module
    model : Optional[Module]
        The origin unwrapped pytorch model to be pruned.
    config_list : List[Dict]
    config_list : Optional[List[Dict]]
        The origin config list provided by the user.
    evaluator : Callable[[Module], float]
        Evaluate the pruned model and give a score.

@@ -312,7 +313,7 @@ class SimulatedAnnealingPruner(IterativePruner):
        This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
    pruning_params : Dict
        If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
    log_dir : str
    log_dir : Union[str, Path]
        The log directory use to saving the result, you can find the best result under this folder.
    keep_intermediate_result : bool
        If keeping the intermediate result, including intermediate model and masks during each iteration.

@@ -337,9 +338,9 @@ class SimulatedAnnealingPruner(IterativePruner):
    For detailed example please refer to :githublink:`examples/model_compress/pruning/simulated_anealing_pruning_torch.py <examples/model_compress/pruning/simulated_anealing_pruning_torch.py>`
    """

    def __init__(self, model: Module, config_list: List[Dict], evaluator: Callable[[Module], float], start_temperature: float = 100,
    def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]], evaluator: Callable[[Module], float], start_temperature: float = 100,
                 stop_temperature: float = 20, cool_down_rate: float = 0.9, perturbation_magnitude: float = 0.35,
                 pruning_algorithm: str = 'level', pruning_params: Dict = {}, log_dir: str = '.', keep_intermediate_result: bool = False,
                 pruning_algorithm: str = 'level', pruning_params: Dict = {}, log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None):
        task_generator = SimulatedAnnealingTaskGenerator(origin_model=model,
                                                         origin_config_list=config_list,

@@ -350,7 +351,7 @@ class SimulatedAnnealingPruner(IterativePruner):
                                                         log_dir=log_dir,
                                                         keep_intermediate_result=keep_intermediate_result)
        if 'traced_optimizer' in pruning_params:
            pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
            pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])  # type: ignore
        pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
                         evaluator=evaluator, reset_weight=False)

@@ -7,7 +7,8 @@ from typing import Dict, List, Tuple, Callable

import torch
from torch import autograd, Tensor
from torch.nn import Module, Parameter
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.optim import Optimizer, Adam

from nni.algorithms.compression.v2.pytorch.base import PrunerModuleWrapper, LayerInfo

@@ -41,15 +42,15 @@ class PrunerScoredModuleWrapper(PrunerModuleWrapper):
    """
    def __init__(self, module: Module, module_name: str, config: Dict):
        super().__init__(module, module_name, config)
        self.weight_score = Parameter(torch.empty(self.weight.size()))
        self.weight_score = Parameter(torch.empty(self.weight.size()))  # type: ignore
        torch.nn.init.constant_(self.weight_score, val=0.0)

    def forward(self, *inputs):
        # apply mask to weight, bias
        # NOTE: I don't know why training getting slower and slower if only `self.weight_mask` without `detach_()`
        self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask.detach_()))
        # NOTE: I don't know why training getting slower and slower if only `self.weight_mask` without `detach()`
        self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask.detach()))  # type: ignore
        if hasattr(self.module, 'bias') and self.module.bias is not None:
            self.module.bias = torch.mul(self.bias, self.bias_mask)
            self.module.bias = torch.mul(self.bias, self.bias_mask)  # type: ignore
        return self.module(*inputs)


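The switch from `weight_mask.detach_()` to `weight_mask.detach()` in the forward above is subtle: `detach_()` detaches the stored buffer in place, mutating it on every forward pass, whereas `detach()` returns a new view that shares storage but is cut out of the autograd graph, leaving the registered buffer untouched. A minimal illustration of the difference:

# --- illustrative sketch, not part of the commit ---
import torch

mask = torch.ones(3, requires_grad=True)

view = mask.detach()        # new view, `mask` itself is unchanged
print(mask.requires_grad)   # True
print(view.requires_grad)   # False

mask.detach_()              # in-place: `mask` itself is detached from the graph
print(mask.requires_grad)   # False
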
@@ -58,7 +59,7 @@ class _StraightThrough(autograd.Function):
    Straight through the gradient to the score, then the score = initial_score + sum(-lr * grad(weight) * weight).
    """
    @staticmethod
    def forward(self, score, masks):
    def forward(ctx, score, masks):
        return masks

    @staticmethod

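`torch.autograd.Function` subclasses implement forward and backward as static methods, so the first argument is a context object (conventionally `ctx`) used to stash state for the backward pass, not an instance `self`; the rename above only changes the parameter name, but it makes the signature read correctly. A hedged, self-contained sketch of a straight-through style Function:

# --- illustrative sketch, not part of the commit ---
import torch
from torch import autograd

class StraightThroughSketch(autograd.Function):
    @staticmethod
    def forward(ctx, score, mask):
        # The forward output ignores `score`; ctx could save tensors here if needed.
        return mask

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the incoming gradient straight through to `score`, none to `mask`.
        return grad_output, None

score = torch.zeros(4, requires_grad=True)
mask = torch.ones(4)
out = StraightThroughSketch.apply(score, mask)
out.sum().backward()
print(score.grad)  # ones: the gradient flowed straight through to the score
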
@@ -71,12 +72,13 @@ class WeightScoreTrainerBasedDataCollector(TrainerBasedDataCollector):
    Collect all weight_score in wrappers as data used to calculate metrics.
    """
    def collect(self) -> Dict[str, Tensor]:
        assert self.compressor.bound_model is not None
        for _ in range(self.training_epochs):
            self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)

        data = {}
        for _, wrapper in self.compressor.get_modules_wrapper().items():
            data[wrapper.name] = wrapper.weight_score.data
            data[wrapper.name] = wrapper.weight_score.data  # type: ignore
        return data


@@ -193,6 +195,7 @@ class MovementPruner(BasicPruner):
            self.sparsity_allocator = NormalSparsityAllocator(self, continuous_mask=False)

        # use Adam to update the weight_score
        assert self.bound_model is not None
        params = [{"params": [p for n, p in self.bound_model.named_parameters() if "weight_score" in n and p.requires_grad]}]
        optimizer = Adam(params, 1e-2)
        self.step_counter = 0

@@ -205,10 +208,10 @@ class MovementPruner(BasicPruner):
            if self.step_counter > self.warm_up_step:
                self.cubic_schedule(self.step_counter)
                data = {}
                for _, wrapper in self.get_modules_wrapper().items():
                for wrapper in self.get_modules_wrapper().values():
                    data[wrapper.name] = wrapper.weight_score.data
                metrics = self.metrics_calculator.calculate_metrics(data)
                masks = self.sparsity_allocator.generate_sparsity(metrics)
                metrics = self.metrics_calculator.calculate_metrics(data)  # type: ignore
                masks = self.sparsity_allocator.generate_sparsity(metrics)  # type: ignore
                self.load_masks(masks)

        if self.data_collector is None:

@@ -232,15 +235,15 @@ class MovementPruner(BasicPruner):
        wrapper = PrunerScoredModuleWrapper(layer.module, layer.name, config)
        assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
        # move newly registered buffers to the same device of weight
        wrapper.to(layer.module.weight.device)
        wrapper.to(layer.module.weight.device)  # type: ignore
        return wrapper

    def compress(self) -> Tuple[Module, Dict]:
        # sparsity grow from 0
        for _, wrapper in self.get_modules_wrapper().items():
        for wrapper in self.get_modules_wrapper().values():
            wrapper.config['total_sparsity'] = 0
        result = super().compress()
        # del weight_score
        for _, wrapper in self.get_modules_wrapper().items():
        for wrapper in self.get_modules_wrapper().values():
            wrapper.weight_score = None
        return result

@@ -13,7 +13,7 @@ from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from nni.algorithms.compression.v2.pytorch.base import Compressor, LayerInfo, Task, TaskResult
from nni.algorithms.compression.v2.pytorch.base import Pruner, LayerInfo, Task, TaskResult
from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper

_logger = logging.getLogger(__name__)

@@ -29,7 +29,7 @@ class DataCollector:
        The compressor binded with this DataCollector.
    """

    def __init__(self, compressor: Compressor):
    def __init__(self, compressor: Pruner):
        self.compressor = compressor

    def reset(self):

@@ -76,10 +76,10 @@ class TrainerBasedDataCollector(DataCollector):
    This class includes some trainer based util functions, i.e., patch optimizer or criterion, add hooks.
    """

    def __init__(self, compressor: Compressor, trainer: Callable[[Module, Optimizer, Callable], None], optimizer_helper: OptimizerConstructHelper,
    def __init__(self, compressor: Pruner, trainer: Callable[[Module, Optimizer, Callable], None], optimizer_helper: OptimizerConstructHelper,
                 criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int,
                 opt_before_tasks: List = [], opt_after_tasks: List = [],
                 collector_infos: List[HookCollectorInfo] = [], criterion_patch: Callable[[Callable], Callable] = None):
                 collector_infos: List[HookCollectorInfo] = [], criterion_patch: Optional[Callable[[Callable], Callable]] = None):
        """
        Parameters
        ----------

@@ -165,6 +165,7 @@ class TrainerBasedDataCollector(DataCollector):

    def _reset_optimizer(self):
        parameter_name_map = self.compressor.get_origin2wrapped_parameter_name_map()
        assert self.compressor.bound_model is not None
        self.optimizer = self.optimizer_helper.call(self.compressor.bound_model, parameter_name_map)

    def _patch_optimizer(self):

@@ -187,11 +188,11 @@ class TrainerBasedDataCollector(DataCollector):
        self._hook_buffer[self._hook_id] = {}

        if collector_info.hook_type == 'forward':
            self._add_forward_hook(self._hook_id, collector_info.targets, collector_info.collector)
            self._add_forward_hook(self._hook_id, collector_info.targets, collector_info.collector)  # type: ignore
        elif collector_info.hook_type == 'backward':
            self._add_backward_hook(self._hook_id, collector_info.targets, collector_info.collector)
            self._add_backward_hook(self._hook_id, collector_info.targets, collector_info.collector)  # type: ignore
        elif collector_info.hook_type == 'tensor':
            self._add_tensor_hook(self._hook_id, collector_info.targets, collector_info.collector)
            self._add_tensor_hook(self._hook_id, collector_info.targets, collector_info.collector)  # type: ignore
        else:
            _logger.warning('Skip unsupported hook type: %s', collector_info.hook_type)

@@ -210,7 +211,7 @@ class TrainerBasedDataCollector(DataCollector):
        assert all(isinstance(layer_info, LayerInfo) for layer_info in layers)
        for layer in layers:
            self._hook_buffer[hook_id][layer.name] = []
            handle = layer.module.register_backward_hook(collector(self._hook_buffer[hook_id][layer.name]))
            handle = layer.module.register_backward_hook(collector(self._hook_buffer[hook_id][layer.name]))  # type: ignore
            self._hook_handles[hook_id][layer.name] = handle

    def _add_tensor_hook(self, hook_id: int, tensors: Dict[str, Tensor],

@@ -286,7 +287,7 @@ class MetricsCalculator:
                self.block_sparse_size = [1] * len(self.dim)
        if self.dim is not None:
            assert all(i >= 0 for i in self.dim)
            self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))
            self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))  # type: ignore

    def calculate_metrics(self, data: Dict) -> Dict[str, Tensor]:
        """

@@ -334,7 +335,7 @@ class SparsityAllocator:
        Inherit the mask already in the wrapper if set True.
    """

    def __init__(self, pruner: Compressor, dim: Optional[Union[int, List[int]]] = None,
    def __init__(self, pruner: Pruner, dim: Optional[Union[int, List[int]]] = None,
                 block_sparse_size: Optional[Union[int, List[int]]] = None, continuous_mask: bool = True):
        self.pruner = pruner
        self.dim = dim if not isinstance(dim, int) else [dim]

@@ -345,7 +346,7 @@ class SparsityAllocator:
                self.block_sparse_size = [1] * len(self.dim)
        if self.dim is not None:
            assert all(i >= 0 for i in self.dim)
            self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))
            self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))  # type: ignore
        self.continuous_mask = continuous_mask

    def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:

@@ -384,7 +385,7 @@ class SparsityAllocator:
            weight_mask = weight_mask.expand(expand_size).reshape(reshape_size)

        wrapper = self.pruner.get_modules_wrapper()[name]
        weight_size = wrapper.weight.data.size()
        weight_size = wrapper.weight.data.size()  # type: ignore

        if self.dim is None:
            assert weight_mask.size() == weight_size

@@ -401,7 +402,7 @@ class SparsityAllocator:
            expand_mask = {'weight': weight_mask.expand(weight_size).clone()}
            # NOTE: assume we only mask output, so the mask and bias have a one-to-one correspondence.
            # If we support more kind of masks, this place need refactor.
            if wrapper.bias_mask is not None and weight_mask.size() == wrapper.bias_mask.size():
            if wrapper.bias_mask is not None and weight_mask.size() == wrapper.bias_mask.size():  # type: ignore
                expand_mask['bias'] = weight_mask.clone()
        return expand_mask

@@ -463,7 +464,7 @@ class TaskGenerator:
        If keeping the intermediate result, including intermediate model and masks during each iteration.
    """
    def __init__(self, origin_model: Optional[Module], origin_masks: Optional[Dict[str, Dict[str, Tensor]]] = {},
                 origin_config_list: Optional[List[Dict]] = [], log_dir: str = '.', keep_intermediate_result: bool = False):
                 origin_config_list: Optional[List[Dict]] = [], log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False):
        self._log_dir = log_dir
        self._keep_intermediate_result = keep_intermediate_result

@@ -486,7 +487,7 @@ class TaskGenerator:
            self._save_data('origin', model, masks, config_list)

        self._task_id_candidate = 0
        self._tasks: Dict[int, Task] = {}
        self._tasks: Dict[Union[int, str], Task] = {}
        self._pending_tasks: List[Task] = self.init_pending_tasks()

        self._best_score = None

@@ -560,7 +561,7 @@ class TaskGenerator:
        self._dump_tasks_info()
        return task

    def get_best_result(self) -> Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]:
    def get_best_result(self) -> Optional[Tuple[Union[int, str], Module, Dict[str, Dict[str, Tensor]], Optional[float], List[Dict]]]:
        """
        Returns
        -------

@@ -34,6 +34,7 @@ class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
    """

    def collect(self) -> Dict[str, Tensor]:
        assert self.compressor.bound_model is not None
        for _ in range(self.training_epochs):
            self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)

@@ -50,6 +51,7 @@ class SingleHookTrainerBasedDataCollector(TrainerBasedDataCollector):
    """

    def collect(self) -> Dict[str, List[Tensor]]:
        assert self.compressor.bound_model is not None
        for _ in range(self.training_epochs):
            self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)

@@ -70,7 +70,7 @@ class NormMetricsCalculator(MetricsCalculator):
            if len(across_dim) == 0:
                metrics[name] = tensor.abs()
            else:
                metrics[name] = tensor.norm(p=self.p, dim=across_dim)
                metrics[name] = tensor.norm(p=self.p, dim=across_dim)  # type: ignore
        return metrics


@@ -142,7 +142,7 @@ class DistMetricsCalculator(MetricsCalculator):
            if len(across_dim) == 0:
                dist_sum = torch.abs(reorder_tensor - other).sum()
            else:
                dist_sum = torch.norm((reorder_tensor - other), p=self.p, dim=across_dim).sum()
                dist_sum = torch.norm((reorder_tensor - other), p=self.p, dim=across_dim).sum()  # type: ignore
            # NOTE: this place need refactor when support layer level pruning.
            tmp_metric = metric
            for i in idx[:-1]:

@@ -141,7 +141,7 @@ class DDPG(nn.Module):
        ])

        target_q_batch = to_tensor(reward_batch) + \
            self.discount * to_tensor(terminal_batch.astype(np.float)) * next_q_values
            self.discount * to_tensor(terminal_batch.astype(np.float32)) * next_q_values

        # Critic update
        self.critic.zero_grad()

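The `np.float` change above is a deprecation fix: `np.float` was only an alias for the builtin `float`, deprecated since NumPy 1.20 and removed in 1.24, so naming an explicit dtype such as `np.float32` both silences the warning and states the intended precision. A tiny illustration:

# --- illustrative sketch, not part of the commit ---
import numpy as np

terminal_batch = np.array([True, False, True])

# np.float is deprecated/removed in recent NumPy; name the dtype explicitly.
as_float32 = terminal_batch.astype(np.float32)
print(as_float32.dtype)  # float32
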
@@ -38,8 +38,8 @@ class AMCEnv:
        assert target in ['flops', 'params']
        self.target = target

        self.origin_target, self.origin_params_num, self.origin_statistics = count_flops_params(model, dummy_input, verbose=False)
        self.origin_statistics = {result['name']: result for result in self.origin_statistics}
        self.origin_target, self.origin_params_num, origin_statistics = count_flops_params(model, dummy_input, verbose=False)
        self.origin_statistics = {result['name']: result for result in origin_statistics}

        self.under_pruning_target = sum([self.origin_statistics[name][self.target] for name in self.pruning_op_names])
        self.excepted_pruning_target = self.total_sparsity * self.under_pruning_target

@@ -3,6 +3,7 @@

from __future__ import absolute_import
from collections import deque, namedtuple
from typing import Any, List
import warnings
import random

@@ -31,7 +32,7 @@ def sample_batch_indexes(low, high, size):
            'Not enough entries to sample without replacement. '
            'Consider increasing your warm-up phase to avoid oversampling!')
        batch_idxs = np.random.random_integers(low, high - 1, size=size)
    assert len(batch_idxs) == size
    assert len(batch_idxs) == size  # type: ignore
    return batch_idxs


@@ -147,14 +148,14 @@ class SequentialMemory(Memory):
                # Skip this transition because the environment was reset here. Select a new, random
                # transition and use this instead. This may cause the batch to contain the same
                # transition twice.
                idx = sample_batch_indexes(1, self.nb_entries, size=1)[0]
                idx = sample_batch_indexes(1, self.nb_entries, size=1)[0]  # type: ignore
                terminal0 = self.terminals[idx - 2] if idx >= 2 else False
            assert 1 <= idx < self.nb_entries

            # This code is slightly complicated by the fact that subsequent observations might be
            # from different episodes. We ensure that an experience never spans multiple episodes.
            # This is probably not that important in practice but it seems cleaner.
            state0 = [self.observations[idx - 1]]
            state0: List[Any] = [self.observations[idx - 1]]
            for offset in range(0, self.window_length - 1):
                current_idx = idx - 2 - offset
                current_terminal = self.terminals[current_idx - 1] if current_idx - 1 > 0 else False

@@ -29,7 +29,7 @@ class NormalSparsityAllocator(SparsityAllocator):
            # We assume the metric value are all positive right now.
            metric = metrics[name]
            if self.continuous_mask:
                metric *= self._compress_mask(wrapper.weight_mask)
                metric *= self._compress_mask(wrapper.weight_mask)  # type: ignore
            prune_num = int(sparsity_rate * metric.numel())
            if prune_num == 0:
                threshold = metric.min() - 1

@@ -64,7 +64,7 @@ class BankSparsityAllocator(SparsityAllocator):
            # We assume the metric value are all positive right now.
            metric = metrics[name]
            if self.continuous_mask:
                metric *= self._compress_mask(wrapper.weight_mask)
                metric *= self._compress_mask(wrapper.weight_mask)  # type: ignore
            n_dim = len(metric.shape)
            assert n_dim >= len(self.balance_gran), 'Dimension of balance_gran should be smaller than metric'
            # make up for balance_gran

@@ -129,15 +129,15 @@ class GlobalSparsityAllocator(SparsityAllocator):

            # We assume the metric value are all positive right now.
            if self.continuous_mask:
                metric = metric * self._compress_mask(wrapper.weight_mask)
                metric = metric * self._compress_mask(wrapper.weight_mask)  # type: ignore

            layer_weight_num = wrapper.weight.data.numel()
            layer_weight_num = wrapper.weight.data.numel()  # type: ignore
            total_weight_num += layer_weight_num
            expend_times = int(layer_weight_num / metric.numel())

            retention_ratio = 1 - max_sparsity_per_layer.get(name, 1)
            retention_numel = math.ceil(retention_ratio * layer_weight_num)
            removed_metric_num = math.ceil(retention_numel / (wrapper.weight_mask.numel() / metric.numel()))
            removed_metric_num = math.ceil(retention_numel / (wrapper.weight_mask.numel() / metric.numel()))  # type: ignore
            stay_metric_num = metric.numel() - removed_metric_num
            if stay_metric_num <= 0:
                sub_thresholds[name] = metric.min().item() - 1

@@ -182,7 +182,7 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
            grouped_metric = {name: metrics[name] for name in names if name in metrics}
            if self.continuous_mask:
                for name, metric in grouped_metric.items():
                    metric *= self._compress_mask(self.pruner.get_modules_wrapper()[name].weight_mask)
                    metric *= self._compress_mask(self.pruner.get_modules_wrapper()[name].weight_mask)  # type: ignore
            if len(grouped_metric) > 0:
                grouped_metrics[idx] = grouped_metric
        for _, group_metric_dict in grouped_metrics.items():

@@ -4,7 +4,7 @@
from copy import deepcopy
import logging
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple, Union
import json_tricks

import numpy as np

@@ -150,9 +150,9 @@ class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):


class SimulatedAnnealingTaskGenerator(TaskGenerator):
    def __init__(self, origin_model: Module, origin_config_list: List[Dict], origin_masks: Dict[str, Dict[str, Tensor]] = {},
    def __init__(self, origin_model: Optional[Module], origin_config_list: Optional[List[Dict]], origin_masks: Dict[str, Dict[str, Tensor]] = {},
                 start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
                 perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermediate_result: bool = False):
                 perturbation_magnitude: float = 0.35, log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False):
        """
        Parameters
        ----------

@@ -196,9 +196,9 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
        self.target_sparsity_list = config_list_canonical(model, config_list)
        self._adjust_target_sparsity()

        self._temp_config_list = None
        self._current_sparsity_list = None
        self._current_score = None
        self._temp_config_list = []
        self._current_sparsity_list = []
        self._current_score = 0.

        super().reset(model, config_list=config_list, masks=masks)

@@ -248,7 +248,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):

        return self._sparsity_to_config_list(rescaled_sparsity, config), rescaled_sparsity

    def _rescale_sparsity(self, random_sparsity: List, target_sparsity: float, op_names: List) -> List:
    def _rescale_sparsity(self, random_sparsity: List, target_sparsity: float, op_names: List) -> Optional[List]:
        assert len(random_sparsity) == len(op_names)

        num_weights = sorted([self.weights_numel[op_name] for op_name in op_names])

@@ -267,7 +267,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
        scale = target_sparsity / (total_weights_pruned / total_weights)

        # rescale the sparsity
        sparsity = np.asarray(sparsity) * scale
        sparsity = list(np.asarray(sparsity) * scale)
        return sparsity

    def _sparsity_to_config_list(self, sparsity: List, config: Dict) -> List[Dict]:

@@ -285,7 +285,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
        # decrease magnitude with current temperature
        magnitude = self.current_temperature / self.start_temperature * self.perturbation_magnitude
        for config, current_sparsity in zip(self.target_sparsity_list, self._current_sparsity_list):
            if len(current_sparsity) == 0:
            if not current_sparsity:
                sub_temp_config_list = [deepcopy(config) for i in range(len(config['op_names']))]
                for temp_config, op_name in zip(sub_temp_config_list, config['op_names']):
                    temp_config.update({'total_sparsity': 0, 'op_names': [op_name]})

@@ -327,11 +327,12 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):

    def generate_tasks(self, task_result: TaskResult) -> List[Task]:
        # initial/update temp config list
        if self._temp_config_list is None:
        if not self._temp_config_list:
            self._init_temp_config_list()
        else:
            score = self._tasks[task_result.task_id].score
            if self._current_sparsity_list is None:
            assert score is not None, 'SimulatedAnnealingTaskGenerator need each score is not None.'
            if not self._current_sparsity_list:
                self._current_sparsity_list = deepcopy(self._temp_sparsity_list)
                self._current_score = score
            else:

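Initializing `_temp_config_list` and friends as empty containers (and the score as 0.) instead of None, as the hunks above do, lets later code use plain truthiness checks (`if not self._temp_config_list:`) and keeps the attribute types non-Optional for the checker; an empty list is falsy, so the control flow is unchanged. A small hedged sketch of the pattern (illustrative names only):

# --- illustrative sketch, not part of the commit ---
from typing import Dict, List

class StateSketch:
    def __init__(self) -> None:
        # Empty containers instead of None: the type stays List[Dict] throughout.
        self.temp_config_list: List[Dict] = []

    def step(self) -> None:
        if not self.temp_config_list:   # true for an empty list, no None check needed
            self.temp_config_list.append({'sparsity': 0.5})
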
@@ -19,7 +19,7 @@ class ConstructHelper:
    def __init__(self, callable_obj: Callable, *args, **kwargs):
        assert callable(callable_obj), '`callable_obj` must be a callable object.'
        self.callable_obj = callable_obj
        self.args = deepcopy(args)
        self.args = deepcopy(list(args))
        self.kwargs = deepcopy(kwargs)

    def call(self):

@@ -149,14 +149,14 @@ def compute_sparsity_compact2origin(origin_model: Module, compact_model: Module,
                continue
            if 'op_names' in config and module_name not in config['op_names']:
                continue
            total_weight_num += module.weight.data.numel()
            total_weight_num += module.weight.data.numel()  # type: ignore
        for module_name, module in compact_model.named_modules():
            module_type = type(module).__name__
            if 'op_types' in config and module_type not in config['op_types']:
                continue
            if 'op_names' in config and module_name not in config['op_names']:
                continue
            left_weight_num += module.weight.data.numel()
            left_weight_num += module.weight.data.numel()  # type: ignore
        compact2origin_sparsity.append(deepcopy(config))
        compact2origin_sparsity[-1]['total_sparsity'] = 1 - left_weight_num / total_weight_num
    return compact2origin_sparsity

@@ -179,7 +179,7 @@ def compute_sparsity_mask2compact(compact_model: Module, compact_model_masks: Di
                continue
            if 'op_names' in config and module_name not in config['op_names']:
                continue
            module_weight_num = module.weight.data.numel()
            module_weight_num = module.weight.data.numel()  # type: ignore
            total_weight_num += module_weight_num
            if module_name in compact_model_masks:
                weight_mask = compact_model_masks[module_name]['weight']

@@ -229,7 +229,7 @@ def compute_sparsity(origin_model: Module, compact_model: Module, compact_model_
    return current2origin_sparsity, compact2origin_sparsity, mask2compact_sparsity


def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}) -> Dict:
def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}) -> Tuple[Dict[str, int], Dict[str, float]]:
    """
    Count the layer weight elements number in config_list.
    If masks is not empty, the masked weight will not be counted.

@@ -248,7 +248,7 @@ def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[
            masked_rate[module_name] = 1 - (weight_mask.sum().item() / weight_mask.numel())
            model_weights_numel[module_name] = round(weight_mask.sum().item())
        else:
            model_weights_numel[module_name] = module.weight.data.numel()
            model_weights_numel[module_name] = module.weight.data.numel()  # type: ignore
    return model_weights_numel, masked_rate

@@ -1,6 +1,11 @@
{
    "ignore": [
        "nni/algorithms",
        "nni/algorithms/compression/pytorch",
        "nni/algorithms/compression/tensorflow",
        "nni/algorithms/compression/v2/pytorch/base/pruner.py",
        "nni/algorithms/feature_engineering",
        "nni/algorithms/hpo",
        "nni/algorithms/nas",
        "nni/common/device.py",
        "nni/common/graph_utils.py",
        "nni/compression",