This commit is contained in:
J-shang 2022-04-27 17:55:53 +08:00 committed by GitHub
Parent d49864ce28
Commit cbac2c5c0f
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 204 additions and 134 deletions

View file

@ -1,9 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections
import logging
from typing import List, Dict, Optional, Tuple, Any
from typing import Any, List, Dict, Optional, Tuple
import torch
from torch.nn import Module
@ -29,7 +28,33 @@ def _setattr(model: Module, name: str, module: Module):
name_list = name.split(".")
setattr(parent_module, name_list[-1], module)
else:
raise '{} not exist.'.format(name)
raise Exception('{} not exist.'.format(name))
class ModuleWrapper(Module):
"""
Wrap a module to enable data parallelism, forward method customization and buffer registration.
Parameters
----------
module
The module user wants to compress.
config
The configurations that users specify for compression.
module_name
The name of the module to compress; the wrapper module shares the same name.
"""
def __init__(self, module: Module, module_name: str, config: Dict):
super().__init__()
# origin layer information
self.module = module
self.name = module_name
# config information
self.config = config
def forward(self, *inputs):
raise NotImplementedError
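For orientation, a concrete wrapper is expected to register its own buffers in __init__ and override forward. A minimal sketch of that pattern, assuming the wrapped module is a Linear layer (MaskedWrapper is hypothetical, not part of this commit):

import torch
import torch.nn.functional as F
from torch.nn import Module

class MaskedWrapper(ModuleWrapper):
    def __init__(self, module: Module, module_name: str, config: dict):
        super().__init__(module, module_name, config)
        # buffers follow the wrapper across devices and show up in state_dict
        self.register_buffer('weight_mask', torch.ones_like(module.weight))

    def forward(self, *inputs):
        # apply the mask on the fly; the original parameter stays untouched
        return F.linear(inputs[0], self.module.weight * self.weight_mask, self.module.bias)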
class Compressor:
@ -46,7 +71,7 @@ class Compressor:
def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]]):
self.is_wrapped = False
if model is not None:
if model is not None and config_list is not None:
self.reset(model=model, config_list=config_list)
else:
_logger.warning('Model and config_list are not set in this compressor, waiting for reset() or for this compressor to be passed to a scheduler.')
@ -63,6 +88,7 @@ class Compressor:
The config list used by the compressor, usually specifying the 'op_types' or 'op_names' to compress.
"""
assert isinstance(model, Module), 'Only support compressing pytorch Module, but the type of model is {}.'.format(type(model))
self.bound_model = model
self.config_list = config_list
self.validate_config(model=model, config_list=config_list)
@ -70,7 +96,7 @@ class Compressor:
self._unwrap_model()
self._modules_to_compress = None
self.modules_wrapper = collections.OrderedDict()
self.modules_wrapper = {}
for layer, config in self._detect_modules_to_compress():
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper[layer.name] = wrapper
@ -93,6 +119,8 @@ class Compressor:
Detect all modules that should be compressed, and save the result in `self._modules_to_compress`.
The model will be instrumented, and the user should never edit it after calling this method.
"""
assert self.bound_model is not None, 'No model bound to this compressor, please use Compressor.reset(model, config_list) to set it.'
if self._modules_to_compress is None:
self._modules_to_compress = []
for name, module in self.bound_model.named_modules():
@ -118,6 +146,8 @@ class Compressor:
Optional[Dict]
The retrieved configuration for this layer; if None, this layer should not be compressed.
"""
assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
ret = None
for config in self.config_list:
config = config.copy()
@ -142,32 +172,26 @@ class Compressor:
return None
return ret
def get_modules_wrapper(self) -> Dict[str, Module]:
def get_modules_wrapper(self) -> Dict[str, ModuleWrapper]:
"""
Returns
-------
OrderedDict[str, Module]
An ordered dict, key is the name of the module, value is the wrapper of the module.
Dict[str, ModuleWrapper]
A dict; the key is the name of the module, the value is the wrapper of the module.
"""
return self.modules_wrapper
raise NotImplementedError
def _wrap_model(self):
"""
Wrap all modules that need to be compressed.
"""
if not self.is_wrapped:
for _, wrapper in reversed(self.get_modules_wrapper().items()):
_setattr(self.bound_model, wrapper.name, wrapper)
self.is_wrapped = True
raise NotImplementedError
def _unwrap_model(self):
"""
Unwrap all modules that need to be compressed.
"""
if self.is_wrapped:
for _, wrapper in self.get_modules_wrapper().items():
_setattr(self.bound_model, wrapper.name, wrapper.module)
self.is_wrapped = False
raise NotImplementedError
def set_wrappers_attribute(self, name: str, value: Any):
"""
@ -182,7 +206,7 @@ class Compressor:
value
Value of the variable.
"""
for wrapper in self.get_modules_wrapper():
for wrapper in self.get_modules_wrapper().values():
if isinstance(value, torch.Tensor):
wrapper.register_buffer(name, value.clone())
else:
@ -216,8 +240,10 @@ class Compressor:
Dict[int, List[str]]
A dict. The key is the config idx in config_list, the value is the module name list, e.g., {1: ['layer.0', 'layer.2']}.
"""
self._unwrap_model()
assert self.bound_model is not None, 'No model bound to this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
self._unwrap_model()
module_groups = {}
for name, module in self.bound_model.named_modules():
if module == self.bound_model:
@ -259,7 +285,7 @@ class Compressor:
"""
raise NotImplementedError()
def _wrap_modules(self, layer: LayerInfo, config: Dict):
def _wrap_modules(self, layer: LayerInfo, config: Dict) -> ModuleWrapper:
"""
This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`
@ -297,4 +323,6 @@ class Compressor:
torch.nn.Module
The model with the specified modules compressed.
"""
assert self.bound_model is not None, 'No model bound to this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
return self.bound_model
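The exception fix at the top of this file lives in `_setattr`, which replaces a submodule addressed by a dotted name. A sketch of the traversal it performs, reconstructed for illustration from the fragment shown (not copied verbatim from the commit):

from torch.nn import Module

def _setattr_sketch(model: Module, name: str, module: Module):
    parent_module = model
    name_list = name.split('.')
    # walk down to the direct parent, e.g. 'layer.0.conv' -> model.layer[0]
    for sub_name in name_list[:-1]:
        parent_module = getattr(parent_module, sub_name)
    if hasattr(parent_module, name_list[-1]):
        setattr(parent_module, name_list[-1], module)
    else:
        raise Exception('{} not exist.'.format(name))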

View file

@ -2,11 +2,12 @@
# Licensed under the MIT license.
import logging
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, OrderedDict
import torch
from torch import Tensor
from torch.nn import Module, Parameter
from torch.nn import Module
from torch.nn.parameter import Parameter
from .compressor import Compressor, LayerInfo, _setattr
@ -37,15 +38,15 @@ class PrunerModuleWrapper(Module):
# config information
self.config = config
self.weight = Parameter(torch.empty(self.module.weight.size()))
# register buffer for mask
self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
if hasattr(self.module, 'bias') and self.module.bias is not None:
self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
self.bias = Parameter(torch.empty(self.module.bias.size()))
else:
self.register_buffer("bias_mask", None)
pruning_target_names = ['weight', 'bias']
for pruning_target_name in pruning_target_names:
pruning_target_mask_name = '{}_mask'.format(pruning_target_name)
pruning_target = getattr(self.module, pruning_target_name, None)
if hasattr(self.module, pruning_target_name) and pruning_target is not None:
setattr(self, pruning_target_name, Parameter(torch.empty(pruning_target.shape)))
self.register_buffer(pruning_target_mask_name, torch.ones(pruning_target.shape))
else:
self.register_buffer(pruning_target_mask_name, None)
def _weight2buffer(self):
"""
@ -89,7 +90,17 @@ class Pruner(Compressor):
def reset(self, model: Optional[Module] = None, config_list: Optional[List[Dict]] = None):
super().reset(model=model, config_list=config_list)
def _wrap_modules(self, layer: LayerInfo, config: Dict):
def get_modules_wrapper(self) -> OrderedDict[str, PrunerModuleWrapper]:
"""
Returns
-------
OrderedDict[str, PrunerModuleWrapper]
An ordered dict, key is the name of the module, value is the wrapper of the module.
"""
assert self.modules_wrapper is not None, 'Bound model has not been wrapped.'
return self.modules_wrapper
def _wrap_modules(self, layer: LayerInfo, config: Dict) -> PrunerModuleWrapper:
"""
Create a wrapper module to replace the original one.
@ -99,6 +110,11 @@ class Pruner(Compressor):
The layer to instrument the mask.
config
The configuration for generating the mask.
Returns
-------
PrunerModuleWrapper
The wrapper of the module in the given LayerInfo.
"""
_logger.debug("Module detected to compress : %s.", layer.name)
wrapper = PrunerModuleWrapper(layer.module, layer.name, config)
@ -114,8 +130,10 @@ class Pruner(Compressor):
Wrap all modules that need to be compressed.
Different from the parent method, `wrapper._weight2buffer()` is called after the origin module is replaced with the wrapper.
"""
assert self.bound_model is not None, 'No model bound to this compressor, please use Compressor.reset(model, config_list) to set it.'
if not self.is_wrapped:
for _, wrapper in reversed(self.get_modules_wrapper().items()):
for _, wrapper in reversed(list(self.get_modules_wrapper().items())):
_setattr(self.bound_model, wrapper.name, wrapper)
wrapper._weight2buffer()
self.is_wrapped = True
@ -125,8 +143,10 @@ class Pruner(Compressor):
Unwrap all modules that need to be compressed.
Different from the parent method, `wrapper._weight2parameter()` is called after the wrapper is replaced with the origin module.
"""
assert self.bound_model is not None, 'No model bound to this compressor, please use Compressor.reset(model, config_list) to set it.'
if self.is_wrapped:
for _, wrapper in self.get_modules_wrapper().items():
for wrapper in self.get_modules_wrapper().values():
_setattr(self.bound_model, wrapper.name, wrapper.module)
wrapper._weight2parameter()
self.is_wrapped = False
@ -191,7 +211,7 @@ class Pruner(Compressor):
dim
The pruned dim.
"""
for _, wrapper in self.get_modules_wrapper().items():
for wrapper in self.get_modules_wrapper().values():
weight_mask = wrapper.weight_mask
mask_size = weight_mask.size()
if len(mask_size) == 1:
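This file's wrapper registers, for every pruning target that exists on the module, a fresh Parameter plus an all-ones mask buffer, and a None mask otherwise. A quick illustration of what that yields for a Linear layer (assuming PrunerModuleWrapper is importable as in this diff):

from torch.nn import Linear

wrapper = PrunerModuleWrapper(Linear(8, 4), 'fc', config={'total_sparsity': 0.5})
print(wrapper.weight.shape)       # torch.Size([4, 8]) - new Parameter mirroring module.weight
print(wrapper.weight_mask.shape)  # torch.Size([4, 8]) - all-ones mask buffer
print(wrapper.bias_mask.shape)    # torch.Size([4])    - registered because Linear has a bias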

View file

@ -5,7 +5,7 @@ import gc
import logging
import os
from pathlib import Path
from typing import List, Dict, Tuple, Optional
from typing import List, Dict, Tuple, Optional, Union
import json_tricks
import torch
@ -19,7 +19,7 @@ class Task:
# NOTE: If we want to support multi-threading, this part needs refactoring, maybe using files and locks to sync.
_reference_counter = {}
def __init__(self, task_id: int, model_path: str, masks_path: str, config_list_path: str,
def __init__(self, task_id: int, model_path: Union[str, Path], masks_path: Union[str, Path], config_list_path: Union[str, Path],
speedup: Optional[bool] = True, finetune: Optional[bool] = True, evaluate: Optional[bool] = True):
"""
Parameters
@ -87,7 +87,7 @@ class Task:
config_list = json_tricks.load(f)
return model, masks, config_list
def referenced_paths(self) -> List[str]:
def referenced_paths(self) -> List[Union[str, Path]]:
"""
Return the list of paths whose references need to be counted in this task.
"""
@ -111,7 +111,7 @@ class Task:
class TaskResult:
def __init__(self, task_id: int, compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]],
def __init__(self, task_id: Union[int, str], compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]],
pruner_generated_masks: Dict[str, Dict[str, Tensor]], score: Optional[float]) -> None:
"""
Parameters

View file

@ -82,12 +82,13 @@ class AMCTaskGenerator(TaskGenerator):
def generate_tasks(self, task_result: TaskResult) -> List[Task]:
# append experience & update agent policy
if task_result.task_id != 'origin':
if self.action is not None:
action, reward, observation, done = self.env.step(self.action, task_result.compact_model)
self.T.append([reward, self.observation, observation, self.action, done])
self.observation = observation.copy()
if done:
assert task_result.score is not None, 'task_result.score should not be None if the environment is done.'
final_reward = task_result.score - 1
# agent observe and update policy
for _, s_t, s_t1, a_t, d_t in self.T:

View file

@ -46,7 +46,9 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
def allocate_sparsity(self, new_config_list: List[Dict], model: Module, masks: Dict[str, Dict[str, Tensor]]):
self._iterative_pruner_reset(model, new_config_list, masks)
self.iterative_pruner.compress()
_, _, _, _, config_list = self.iterative_pruner.get_best_result()
best_result = self.iterative_pruner.get_best_result()
assert best_result is not None, 'Best result does not exist, the iterative pruner may not have started pruning.'
_, _, _, _, config_list = best_result
return config_list
@ -149,7 +151,7 @@ class AutoCompressPruner(IterativePruner):
def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False,
dummy_input: Optional[Tensor] = None, evaluator: Callable[[Module], float] = None):
dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None):
task_generator = AutoCompressTaskGenerator(total_iteration=total_iteration,
origin_model=model,
origin_config_list=config_list,

View file

@ -8,7 +8,7 @@ from typing import List, Dict, Tuple, Callable, Optional
from schema import And, Or, Optional as SchemaOptional, SchemaError
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module
from torch.optim import Optimizer
@ -77,10 +77,10 @@ INTERNAL_SCHEMA = {
class BasicPruner(Pruner):
def __init__(self, model: Module, config_list: List[Dict]):
self.data_collector: DataCollector = None
self.metrics_calculator: MetricsCalculator = None
self.sparsity_allocator: SparsityAllocator = None
def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]]):
self.data_collector: Optional[DataCollector] = None
self.metrics_calculator: Optional[MetricsCalculator] = None
self.sparsity_allocator: Optional[SparsityAllocator] = None
super().__init__(model, config_list)
@ -114,6 +114,8 @@ class BasicPruner(Pruner):
Tuple[Module, Dict]
Return the wrapped model and mask.
"""
assert self.bound_model is not None and self.config_list is not None, 'Model and/or config_list are not set in this pruner, please set them by reset() before compress().'
assert self.data_collector is not None and self.metrics_calculator is not None and self.sparsity_allocator is not None
data = self.data_collector.collect()
_logger.debug('Collected Data:\n%s', data)
metrics = self.metrics_calculator.calculate_metrics(data)
@ -553,8 +555,8 @@ class SlimPruner(BasicPruner):
def criterion_patch(self, criterion: Callable[[Tensor, Tensor], Tensor]) -> Callable[[Tensor, Tensor], Tensor]:
def patched_criterion(input_tensor: Tensor, target: Tensor):
sum_l1 = 0
for _, wrapper in self.get_modules_wrapper().items():
sum_l1 += torch.norm(wrapper.module.weight, p=1)
for wrapper in self.get_modules_wrapper().values():
sum_l1 += torch.norm(wrapper.module.weight, p=1) # type: ignore
return criterion(input_tensor, target) + self._scale * sum_l1
return patched_criterion
@ -654,11 +656,11 @@ class ActivationPruner(BasicPruner):
def _choose_activation(self, activation: str = 'relu') -> Callable:
if activation == 'relu':
return nn.functional.relu
return F.relu
elif activation == 'relu6':
return nn.functional.relu6
return F.relu6
else:
raise 'Unsupported activation {}'.format(activation)
raise Exception('Unsupported activation {}'.format(activation))
def _collector(self, buffer: List) -> Callable[[Module, Tensor, Tensor], None]:
assert len(buffer) == 0, 'Buffer passed to activation pruner collector is not empty.'
@ -684,7 +686,7 @@ class ActivationPruner(BasicPruner):
self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
1, collector_infos=[collector_info])
else:
self.data_collector.reset(collector_infos=[collector_info])
self.data_collector.reset(collector_infos=[collector_info]) # type: ignore
if self.metrics_calculator is None:
self.metrics_calculator = self._get_metrics_calculator()
if self.sparsity_allocator is None:
@ -999,13 +1001,13 @@ class TaylorFOWeightPruner(BasicPruner):
return (weight_tensor.detach() * grad.detach()).data.pow(2)
def reset_tools(self):
hook_targets = {name: wrapper.weight for name, wrapper in self.get_modules_wrapper().items()}
collector_info = HookCollectorInfo(hook_targets, 'tensor', self._collector)
hook_targets = {name: wrapper.weight for name, wrapper in self.get_modules_wrapper().items()} # type: ignore
collector_info = HookCollectorInfo(hook_targets, 'tensor', self._collector) # type: ignore
if self.data_collector is None:
self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
1, collector_infos=[collector_info])
else:
self.data_collector.reset(collector_infos=[collector_info])
self.data_collector.reset(collector_infos=[collector_info]) # type: ignore
if self.metrics_calculator is None:
self.metrics_calculator = MultiDataNormMetricsCalculator(p=1, dim=0)
if self.sparsity_allocator is None:
@ -1095,24 +1097,26 @@ class ADMMPruner(BasicPruner):
For detailed example please refer to :githublink:`examples/model_compress/pruning/admm_pruning_torch.py <examples/model_compress/pruning/admm_pruning_torch.py>`
"""
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]], trainer: Callable[[Module, Optimizer, Callable], None],
traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], iterations: int,
training_epochs: int, granularity: str = 'fine-grained'):
self.trainer = trainer
if isinstance(traced_optimizer, OptimizerConstructHelper):
self.optimizer_helper = traced_optimizer
else:
assert model is not None, 'Model is required if traced_optimizer is provided.'
self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
self.criterion = criterion
self.iterations = iterations
self.training_epochs = training_epochs
assert granularity in ['fine-grained', 'coarse-grained']
self.granularity = granularity
self.Z, self.U = {}, {}
super().__init__(model, config_list)
def reset(self, model: Optional[Module], config_list: Optional[List[Dict]]):
def reset(self, model: Module, config_list: List[Dict]):
super().reset(model, config_list)
self.Z = {name: wrapper.module.weight.data.clone().detach() for name, wrapper in self.get_modules_wrapper().items()}
self.Z = {name: wrapper.module.weight.data.clone().detach() for name, wrapper in self.get_modules_wrapper().items()} # type: ignore
self.U = {name: torch.zeros_like(z).to(z.device) for name, z in self.Z.items()}
def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
@ -1156,6 +1160,8 @@ class ADMMPruner(BasicPruner):
Tuple[Module, Dict]
Return the wrapped model and mask.
"""
assert self.bound_model is not None
assert self.data_collector is not None and self.metrics_calculator is not None and self.sparsity_allocator is not None
for i in range(self.iterations):
_logger.info('======= ADMM Iteration %d Start =======', i)
data = self.data_collector.collect()
@ -1169,11 +1175,10 @@ class ADMMPruner(BasicPruner):
self.Z[name] = self.Z[name].mul(mask['weight'])
self.U[name] = self.U[name] + data[name] - self.Z[name]
self.Z = None
self.U = None
self.Z, self.U = {}, {}
torch.cuda.empty_cache()
metrics = self.metrics_calculator.calculate_metrics(data)
metrics = self.metrics_calculator.calculate_metrics(data) # type: ignore
masks = self.sparsity_allocator.generate_sparsity(metrics)
self.load_masks(masks)
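The loop above is the classic ADMM splitting: the trainer updates W under a proximal penalty, Z is the projection of W + U onto the sparsity constraint (via the generated mask), and U accumulates the residual. A self-contained toy version of the Z and U updates, under assumed shapes and sparsity (not the pruner's API):

import torch

def project_topk(t: torch.Tensor, sparsity: float) -> torch.Tensor:
    # keep the (1 - sparsity) fraction of largest-magnitude entries, zero the rest
    k = int((1 - sparsity) * t.numel())
    mask = torch.zeros(t.numel())
    if k > 0:
        mask[t.abs().view(-1).topk(k).indices] = 1
    return t * mask.view_as(t)

W = torch.randn(4, 4)       # stands in for the trained weight (data[name] above)
Z = W.clone()               # auxiliary variable, initialized from the weight
U = torch.zeros_like(W)     # scaled dual variable
Z = project_topk(W + U, sparsity=0.5)  # mirrors self.Z[name].mul(mask['weight'])
U = U + W - Z                          # mirrors self.U[name] + data[name] - self.Z[name]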

View file

@ -2,7 +2,7 @@
# Licensed under the MIT license.
from copy import deepcopy
from typing import Dict, List, Tuple, Callable, Optional
from typing import Dict, List, Tuple, Callable, Optional, Union
import torch
from torch import Tensor
@ -36,8 +36,8 @@ class PruningScheduler(BasePruningScheduler):
reset_weight
If set True, the model weight will be reset to the original model weight at the end of each iteration step.
"""
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Callable[[Module], None] = None,
speedup: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None,
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Optional[Callable[[Module], None]] = None,
speedup: bool = False, dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None,
reset_weight: bool = False):
self.pruner = pruner
self.task_generator = task_generator
@ -155,5 +155,5 @@ class PruningScheduler(BasePruningScheduler):
torch.cuda.empty_cache()
return result
def get_best_result(self) -> Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]:
def get_best_result(self) -> Optional[Tuple[Union[int, str], Module, Dict[str, Dict[str, Tensor]], Optional[float], List[Dict]]]:
return self.task_generator.get_best_result()

View file

@ -2,7 +2,8 @@
# Licensed under the MIT license.
import logging
from typing import Dict, List, Callable, Optional
from pathlib import Path
from typing import Dict, List, Callable, Optional, Union
from torch import Tensor
from torch.nn import Module
@ -293,9 +294,9 @@ class SimulatedAnnealingPruner(IterativePruner):
Parameters
----------
model : Module
model : Optional[Module]
The origin unwrapped pytorch model to be pruned.
config_list : List[Dict]
config_list : Optional[List[Dict]]
The origin config list provided by the user.
evaluator : Callable[[Module], float]
Evaluate the pruned model and give a score.
@ -312,7 +313,7 @@ class SimulatedAnnealingPruner(IterativePruner):
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
pruning_params : Dict
If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
log_dir : str
log_dir : Union[str, Path]
The log directory used to save the results; you can find the best result under this folder.
keep_intermediate_result : bool
Whether to keep the intermediate results, including the intermediate model and masks, during each iteration.
@ -337,9 +338,9 @@ class SimulatedAnnealingPruner(IterativePruner):
For detailed example please refer to :githublink:`examples/model_compress/pruning/simulated_anealing_pruning_torch.py <examples/model_compress/pruning/simulated_anealing_pruning_torch.py>`
"""
def __init__(self, model: Module, config_list: List[Dict], evaluator: Callable[[Module], float], start_temperature: float = 100,
def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]], evaluator: Callable[[Module], float], start_temperature: float = 100,
stop_temperature: float = 20, cool_down_rate: float = 0.9, perturbation_magnitude: float = 0.35,
pruning_algorithm: str = 'level', pruning_params: Dict = {}, log_dir: str = '.', keep_intermediate_result: bool = False,
pruning_algorithm: str = 'level', pruning_params: Dict = {}, log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False,
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None):
task_generator = SimulatedAnnealingTaskGenerator(origin_model=model,
origin_config_list=config_list,
@ -350,7 +351,7 @@ class SimulatedAnnealingPruner(IterativePruner):
log_dir=log_dir,
keep_intermediate_result=keep_intermediate_result)
if 'traced_optimizer' in pruning_params:
pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer']) # type: ignore
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=False)
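For intuition, the schedule driven by start_temperature, stop_temperature, cool_down_rate and perturbation_magnitude is ordinary simulated annealing. A toy sketch of the acceptance loop, independent of the pruner (the objective here is a stand-in):

import math
import random

temperature, stop_temperature, cool_down_rate = 100.0, 20.0, 0.9
current_score = 0.0  # e.g. the evaluator output for the current sparsity allocation
while temperature > stop_temperature:
    candidate_score = current_score + random.uniform(-0.1, 0.1)  # score of a perturbed allocation
    delta = candidate_score - current_score
    # always accept improvements; accept worse candidates with probability exp(delta / T)
    if delta > 0 or random.random() < math.exp(delta / temperature):
        current_score = candidate_score
    temperature *= cool_down_rate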

View file

@ -7,7 +7,8 @@ from typing import Dict, List, Tuple, Callable
import torch
from torch import autograd, Tensor
from torch.nn import Module, Parameter
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.optim import Optimizer, Adam
from nni.algorithms.compression.v2.pytorch.base import PrunerModuleWrapper, LayerInfo
@ -41,15 +42,15 @@ class PrunerScoredModuleWrapper(PrunerModuleWrapper):
"""
def __init__(self, module: Module, module_name: str, config: Dict):
super().__init__(module, module_name, config)
self.weight_score = Parameter(torch.empty(self.weight.size()))
self.weight_score = Parameter(torch.empty(self.weight.size())) # type: ignore
torch.nn.init.constant_(self.weight_score, val=0.0)
def forward(self, *inputs):
# apply mask to weight, bias
# NOTE: I don't know why training gets slower and slower if `self.weight_mask` is used without `detach_()`
self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask.detach_()))
# NOTE: I don't know why training gets slower and slower if `self.weight_mask` is used without `detach()`
self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask.detach())) # type: ignore
if hasattr(self.module, 'bias') and self.module.bias is not None:
self.module.bias = torch.mul(self.bias, self.bias_mask)
self.module.bias = torch.mul(self.bias, self.bias_mask) # type: ignore
return self.module(*inputs)
@ -58,7 +59,7 @@ class _StraightThrough(autograd.Function):
Pass the gradient straight through to the score, so that score = initial_score + sum(-lr * grad(weight) * weight).
"""
@staticmethod
def forward(self, score, masks):
def forward(ctx, score, masks):
return masks
@staticmethod
@ -71,12 +72,13 @@ class WeightScoreTrainerBasedDataCollector(TrainerBasedDataCollector):
Collect all weight_score in wrappers as data used to calculate metrics.
"""
def collect(self) -> Dict[str, Tensor]:
assert self.compressor.bound_model is not None
for _ in range(self.training_epochs):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {}
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.weight_score.data
data[wrapper.name] = wrapper.weight_score.data # type: ignore
return data
@ -193,6 +195,7 @@ class MovementPruner(BasicPruner):
self.sparsity_allocator = NormalSparsityAllocator(self, continuous_mask=False)
# use Adam to update the weight_score
assert self.bound_model is not None
params = [{"params": [p for n, p in self.bound_model.named_parameters() if "weight_score" in n and p.requires_grad]}]
optimizer = Adam(params, 1e-2)
self.step_counter = 0
@ -205,10 +208,10 @@ class MovementPruner(BasicPruner):
if self.step_counter > self.warm_up_step:
self.cubic_schedule(self.step_counter)
data = {}
for _, wrapper in self.get_modules_wrapper().items():
for wrapper in self.get_modules_wrapper().values():
data[wrapper.name] = wrapper.weight_score.data
metrics = self.metrics_calculator.calculate_metrics(data)
masks = self.sparsity_allocator.generate_sparsity(metrics)
metrics = self.metrics_calculator.calculate_metrics(data) # type: ignore
masks = self.sparsity_allocator.generate_sparsity(metrics) # type: ignore
self.load_masks(masks)
if self.data_collector is None:
@ -232,15 +235,15 @@ class MovementPruner(BasicPruner):
wrapper = PrunerScoredModuleWrapper(layer.module, layer.name, config)
assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
# move newly registered buffers to the same device as the weight
wrapper.to(layer.module.weight.device)
wrapper.to(layer.module.weight.device) # type: ignore
return wrapper
def compress(self) -> Tuple[Module, Dict]:
# sparsity grows from 0
for _, wrapper in self.get_modules_wrapper().items():
for wrapper in self.get_modules_wrapper().values():
wrapper.config['total_sparsity'] = 0
result = super().compress()
# del weight_score
for _, wrapper in self.get_modules_wrapper().items():
for wrapper in self.get_modules_wrapper().values():
wrapper.weight_score = None
return result
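The `ctx` rename above follows the autograd.Function convention; `_StraightThrough` returns the hard mask in forward while routing the gradient to the score in backward, i.e. the straight-through estimator. A self-contained sketch of the same pattern:

import torch
from torch import autograd

class StraightThroughSketch(autograd.Function):
    @staticmethod
    def forward(ctx, score, masks):
        # forward uses the hard mask ...
        return masks

    @staticmethod
    def backward(ctx, grad_output):
        # ... but the gradient flows to the score; the mask receives none
        return grad_output, None

score = torch.zeros(3, requires_grad=True)
masks = torch.tensor([1.0, 0.0, 1.0])
StraightThroughSketch.apply(score, masks).sum().backward()
print(score.grad)  # tensor([1., 1., 1.])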

View file

@ -13,7 +13,7 @@ from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from nni.algorithms.compression.v2.pytorch.base import Compressor, LayerInfo, Task, TaskResult
from nni.algorithms.compression.v2.pytorch.base import Pruner, LayerInfo, Task, TaskResult
from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper
_logger = logging.getLogger(__name__)
@ -29,7 +29,7 @@ class DataCollector:
The compressor bound to this DataCollector.
"""
def __init__(self, compressor: Compressor):
def __init__(self, compressor: Pruner):
self.compressor = compressor
def reset(self):
@ -76,10 +76,10 @@ class TrainerBasedDataCollector(DataCollector):
This class includes some trainer-based util functions, e.g., patching the optimizer or criterion and adding hooks.
"""
def __init__(self, compressor: Compressor, trainer: Callable[[Module, Optimizer, Callable], None], optimizer_helper: OptimizerConstructHelper,
def __init__(self, compressor: Pruner, trainer: Callable[[Module, Optimizer, Callable], None], optimizer_helper: OptimizerConstructHelper,
criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int,
opt_before_tasks: List = [], opt_after_tasks: List = [],
collector_infos: List[HookCollectorInfo] = [], criterion_patch: Callable[[Callable], Callable] = None):
collector_infos: List[HookCollectorInfo] = [], criterion_patch: Optional[Callable[[Callable], Callable]] = None):
"""
Parameters
----------
@ -165,6 +165,7 @@ class TrainerBasedDataCollector(DataCollector):
def _reset_optimizer(self):
parameter_name_map = self.compressor.get_origin2wrapped_parameter_name_map()
assert self.compressor.bound_model is not None
self.optimizer = self.optimizer_helper.call(self.compressor.bound_model, parameter_name_map)
def _patch_optimizer(self):
@ -187,11 +188,11 @@ class TrainerBasedDataCollector(DataCollector):
self._hook_buffer[self._hook_id] = {}
if collector_info.hook_type == 'forward':
self._add_forward_hook(self._hook_id, collector_info.targets, collector_info.collector)
self._add_forward_hook(self._hook_id, collector_info.targets, collector_info.collector) # type: ignore
elif collector_info.hook_type == 'backward':
self._add_backward_hook(self._hook_id, collector_info.targets, collector_info.collector)
self._add_backward_hook(self._hook_id, collector_info.targets, collector_info.collector) # type: ignore
elif collector_info.hook_type == 'tensor':
self._add_tensor_hook(self._hook_id, collector_info.targets, collector_info.collector)
self._add_tensor_hook(self._hook_id, collector_info.targets, collector_info.collector) # type: ignore
else:
_logger.warning('Skip unsupported hook type: %s', collector_info.hook_type)
@ -210,7 +211,7 @@ class TrainerBasedDataCollector(DataCollector):
assert all(isinstance(layer_info, LayerInfo) for layer_info in layers)
for layer in layers:
self._hook_buffer[hook_id][layer.name] = []
handle = layer.module.register_backward_hook(collector(self._hook_buffer[hook_id][layer.name]))
handle = layer.module.register_backward_hook(collector(self._hook_buffer[hook_id][layer.name])) # type: ignore
self._hook_handles[hook_id][layer.name] = handle
def _add_tensor_hook(self, hook_id: int, tensors: Dict[str, Tensor],
@ -286,7 +287,7 @@ class MetricsCalculator:
self.block_sparse_size = [1] * len(self.dim)
if self.dim is not None:
assert all(i >= 0 for i in self.dim)
self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))
self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size)))) # type: ignore
def calculate_metrics(self, data: Dict) -> Dict[str, Tensor]:
"""
@ -334,7 +335,7 @@ class SparsityAllocator:
Inherit the mask already in the wrapper if set True.
"""
def __init__(self, pruner: Compressor, dim: Optional[Union[int, List[int]]] = None,
def __init__(self, pruner: Pruner, dim: Optional[Union[int, List[int]]] = None,
block_sparse_size: Optional[Union[int, List[int]]] = None, continuous_mask: bool = True):
self.pruner = pruner
self.dim = dim if not isinstance(dim, int) else [dim]
@ -345,7 +346,7 @@ class SparsityAllocator:
self.block_sparse_size = [1] * len(self.dim)
if self.dim is not None:
assert all(i >= 0 for i in self.dim)
self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))
self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size)))) # type: ignore
self.continuous_mask = continuous_mask
def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
@ -384,7 +385,7 @@ class SparsityAllocator:
weight_mask = weight_mask.expand(expand_size).reshape(reshape_size)
wrapper = self.pruner.get_modules_wrapper()[name]
weight_size = wrapper.weight.data.size()
weight_size = wrapper.weight.data.size() # type: ignore
if self.dim is None:
assert weight_mask.size() == weight_size
@ -401,7 +402,7 @@ class SparsityAllocator:
expand_mask = {'weight': weight_mask.expand(weight_size).clone()}
# NOTE: assume we only mask output, so the mask and bias have a one-to-one correspondence.
# If we support more kinds of masks, this place needs refactoring.
if wrapper.bias_mask is not None and weight_mask.size() == wrapper.bias_mask.size():
if wrapper.bias_mask is not None and weight_mask.size() == wrapper.bias_mask.size(): # type: ignore
expand_mask['bias'] = weight_mask.clone()
return expand_mask
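For reference, the expand/reshape logic above restores a reduced (e.g. per-output-channel) mask to the full weight shape; the bias check works because a channel mask has exactly one entry per bias element. A standalone example, assuming dim=[0] on a Conv2d weight:

import torch

weight_size = torch.Size([8, 4, 3, 3])  # out_channels, in_channels, kh, kw
channel_mask = torch.tensor([1., 0., 1., 1., 0., 1., 1., 1.])  # one entry per output channel
weight_mask = channel_mask.view(8, 1, 1, 1).expand(weight_size).clone()
expand_mask = {'weight': weight_mask}
# bias_mask has shape [8], matching channel_mask, so the same mask applies directly
expand_mask['bias'] = channel_mask.clone()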
@ -463,7 +464,7 @@ class TaskGenerator:
Whether to keep the intermediate results, including the intermediate model and masks, during each iteration.
"""
def __init__(self, origin_model: Optional[Module], origin_masks: Optional[Dict[str, Dict[str, Tensor]]] = {},
origin_config_list: Optional[List[Dict]] = [], log_dir: str = '.', keep_intermediate_result: bool = False):
origin_config_list: Optional[List[Dict]] = [], log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False):
self._log_dir = log_dir
self._keep_intermediate_result = keep_intermediate_result
@ -486,7 +487,7 @@ class TaskGenerator:
self._save_data('origin', model, masks, config_list)
self._task_id_candidate = 0
self._tasks: Dict[int, Task] = {}
self._tasks: Dict[Union[int, str], Task] = {}
self._pending_tasks: List[Task] = self.init_pending_tasks()
self._best_score = None
@ -560,7 +561,7 @@ class TaskGenerator:
self._dump_tasks_info()
return task
def get_best_result(self) -> Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]:
def get_best_result(self) -> Optional[Tuple[Union[int, str], Module, Dict[str, Dict[str, Tensor]], Optional[float], List[Dict]]]:
"""
Returns
-------

View file

@ -34,6 +34,7 @@ class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
"""
def collect(self) -> Dict[str, Tensor]:
assert self.compressor.bound_model is not None
for _ in range(self.training_epochs):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
@ -50,6 +51,7 @@ class SingleHookTrainerBasedDataCollector(TrainerBasedDataCollector):
"""
def collect(self) -> Dict[str, List[Tensor]]:
assert self.compressor.bound_model is not None
for _ in range(self.training_epochs):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)

View file

@ -70,7 +70,7 @@ class NormMetricsCalculator(MetricsCalculator):
if len(across_dim) == 0:
metrics[name] = tensor.abs()
else:
metrics[name] = tensor.norm(p=self.p, dim=across_dim)
metrics[name] = tensor.norm(p=self.p, dim=across_dim) # type: ignore
return metrics
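To make the `across_dim` reduction concrete: with dim=[0] kept on a Conv2d weight, across_dim becomes [1, 2, 3] and the metric is one norm per filter. A standalone example:

import torch

weight = torch.randn(8, 4, 3, 3)
# an empty across_dim would mean element-wise metrics (tensor.abs());
# reducing over dims [1, 2, 3] scores each output channel as a whole
metric = weight.norm(p=1, dim=[1, 2, 3])
print(metric.shape)  # torch.Size([8]) - one importance score per output channel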
@ -142,7 +142,7 @@ class DistMetricsCalculator(MetricsCalculator):
if len(across_dim) == 0:
dist_sum = torch.abs(reorder_tensor - other).sum()
else:
dist_sum = torch.norm((reorder_tensor - other), p=self.p, dim=across_dim).sum()
dist_sum = torch.norm((reorder_tensor - other), p=self.p, dim=across_dim).sum() # type: ignore
# NOTE: this place needs refactoring when layer-level pruning is supported.
tmp_metric = metric
for i in idx[:-1]:

View file

@ -141,7 +141,7 @@ class DDPG(nn.Module):
])
target_q_batch = to_tensor(reward_batch) + \
self.discount * to_tensor(terminal_batch.astype(np.float)) * next_q_values
self.discount * to_tensor(terminal_batch.astype(np.float32)) * next_q_values
# Critic update
self.critic.zero_grad()
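The np.float removal above only swaps a deprecated NumPy alias for np.float32; the Bellman target is unchanged. A toy evaluation of the same expression, assuming terminal_batch holds 1.0 for non-terminal transitions so the bootstrap term vanishes at episode ends:

import numpy as np
import torch

reward_batch = torch.tensor([1.0])
terminal_batch = np.array([1.0])   # assumed 1.0 = non-terminal, 0.0 = terminal
next_q_values = torch.tensor([2.5])
discount = 0.99
# target_q = r + gamma * (1 - done) * Q'(s', a')
target_q_batch = reward_batch + \
    discount * torch.from_numpy(terminal_batch.astype(np.float32)) * next_q_values
print(target_q_batch)  # tensor([3.4750])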

View file

@ -38,8 +38,8 @@ class AMCEnv:
assert target in ['flops', 'params']
self.target = target
self.origin_target, self.origin_params_num, self.origin_statistics = count_flops_params(model, dummy_input, verbose=False)
self.origin_statistics = {result['name']: result for result in self.origin_statistics}
self.origin_target, self.origin_params_num, origin_statistics = count_flops_params(model, dummy_input, verbose=False)
self.origin_statistics = {result['name']: result for result in origin_statistics}
self.under_pruning_target = sum([self.origin_statistics[name][self.target] for name in self.pruning_op_names])
self.excepted_pruning_target = self.total_sparsity * self.under_pruning_target
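A small worked instance of the bookkeeping above, with hypothetical flops numbers in place of count_flops_params output:

origin_statistics = {'conv1': {'flops': 100}, 'conv2': {'flops': 300}}  # hypothetical values
pruning_op_names = ['conv1', 'conv2']
total_sparsity = 0.5
target = 'flops'

under_pruning_target = sum(origin_statistics[name][target] for name in pruning_op_names)  # 400
excepted_pruning_target = total_sparsity * under_pruning_target                           # 200.0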

View file

@ -3,6 +3,7 @@
from __future__ import absolute_import
from collections import deque, namedtuple
from typing import Any, List
import warnings
import random
@ -31,7 +32,7 @@ def sample_batch_indexes(low, high, size):
'Not enough entries to sample without replacement. '
'Consider increasing your warm-up phase to avoid oversampling!')
batch_idxs = np.random.random_integers(low, high - 1, size=size)
assert len(batch_idxs) == size
assert len(batch_idxs) == size # type: ignore
return batch_idxs
@ -147,14 +148,14 @@ class SequentialMemory(Memory):
# Skip this transition because the environment was reset here. Select a new, random
# transition and use this instead. This may cause the batch to contain the same
# transition twice.
idx = sample_batch_indexes(1, self.nb_entries, size=1)[0]
idx = sample_batch_indexes(1, self.nb_entries, size=1)[0] # type: ignore
terminal0 = self.terminals[idx - 2] if idx >= 2 else False
assert 1 <= idx < self.nb_entries
# This code is slightly complicated by the fact that subsequent observations might be
# from different episodes. We ensure that an experience never spans multiple episodes.
# This is probably not that important in practice but it seems cleaner.
state0 = [self.observations[idx - 1]]
state0: List[Any] = [self.observations[idx - 1]]
for offset in range(0, self.window_length - 1):
current_idx = idx - 2 - offset
current_terminal = self.terminals[current_idx - 1] if current_idx - 1 > 0 else False

View file

@ -29,7 +29,7 @@ class NormalSparsityAllocator(SparsityAllocator):
# We assume the metric values are all positive right now.
metric = metrics[name]
if self.continuous_mask:
metric *= self._compress_mask(wrapper.weight_mask)
metric *= self._compress_mask(wrapper.weight_mask) # type: ignore
prune_num = int(sparsity_rate * metric.numel())
if prune_num == 0:
threshold = metric.min() - 1
@ -64,7 +64,7 @@ class BankSparsityAllocator(SparsityAllocator):
# We assume the metric values are all positive right now.
metric = metrics[name]
if self.continuous_mask:
metric *= self._compress_mask(wrapper.weight_mask)
metric *= self._compress_mask(wrapper.weight_mask) # type: ignore
n_dim = len(metric.shape)
assert n_dim >= len(self.balance_gran), 'The dimension of balance_gran should not exceed the dimension of metric'
# make up for balance_gran
@ -129,15 +129,15 @@ class GlobalSparsityAllocator(SparsityAllocator):
# We assume the metric values are all positive right now.
if self.continuous_mask:
metric = metric * self._compress_mask(wrapper.weight_mask)
metric = metric * self._compress_mask(wrapper.weight_mask) # type: ignore
layer_weight_num = wrapper.weight.data.numel()
layer_weight_num = wrapper.weight.data.numel() # type: ignore
total_weight_num += layer_weight_num
expend_times = int(layer_weight_num / metric.numel())
retention_ratio = 1 - max_sparsity_per_layer.get(name, 1)
retention_numel = math.ceil(retention_ratio * layer_weight_num)
removed_metric_num = math.ceil(retention_numel / (wrapper.weight_mask.numel() / metric.numel()))
removed_metric_num = math.ceil(retention_numel / (wrapper.weight_mask.numel() / metric.numel())) # type: ignore
stay_metric_num = metric.numel() - removed_metric_num
if stay_metric_num <= 0:
sub_thresholds[name] = metric.min().item() - 1
@ -182,7 +182,7 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
grouped_metric = {name: metrics[name] for name in names if name in metrics}
if self.continuous_mask:
for name, metric in grouped_metric.items():
metric *= self._compress_mask(self.pruner.get_modules_wrapper()[name].weight_mask)
metric *= self._compress_mask(self.pruner.get_modules_wrapper()[name].weight_mask) # type: ignore
if len(grouped_metric) > 0:
grouped_metrics[idx] = grouped_metric
for _, group_metric_dict in grouped_metrics.items():
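The allocators in this file share one primitive, visible in the NormalSparsityAllocator fragment near the top of this section: pick a threshold so that exactly prune_num entries fall at or below it, then keep what is strictly above. A standalone sketch of that per-layer case, assuming positive metrics as the comments note:

import torch

metric = torch.tensor([0.9, 0.1, 0.5, 0.7])
sparsity_rate = 0.5
prune_num = int(sparsity_rate * metric.numel())      # 2
if prune_num == 0:
    threshold = metric.min() - 1                     # nothing pruned, everything stays above
else:
    threshold = torch.topk(metric.view(-1), prune_num, largest=False).values.max()
mask = torch.gt(metric, threshold).float()           # tensor([1., 0., 0., 1.])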

View file

@ -4,7 +4,7 @@
from copy import deepcopy
import logging
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple, Union
import json_tricks
import numpy as np
@ -150,9 +150,9 @@ class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
class SimulatedAnnealingTaskGenerator(TaskGenerator):
def __init__(self, origin_model: Module, origin_config_list: List[Dict], origin_masks: Dict[str, Dict[str, Tensor]] = {},
def __init__(self, origin_model: Optional[Module], origin_config_list: Optional[List[Dict]], origin_masks: Dict[str, Dict[str, Tensor]] = {},
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermediate_result: bool = False):
perturbation_magnitude: float = 0.35, log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False):
"""
Parameters
----------
@ -196,9 +196,9 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
self.target_sparsity_list = config_list_canonical(model, config_list)
self._adjust_target_sparsity()
self._temp_config_list = None
self._current_sparsity_list = None
self._current_score = None
self._temp_config_list = []
self._current_sparsity_list = []
self._current_score = 0.
super().reset(model, config_list=config_list, masks=masks)
@ -248,7 +248,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
return self._sparsity_to_config_list(rescaled_sparsity, config), rescaled_sparsity
def _rescale_sparsity(self, random_sparsity: List, target_sparsity: float, op_names: List) -> List:
def _rescale_sparsity(self, random_sparsity: List, target_sparsity: float, op_names: List) -> Optional[List]:
assert len(random_sparsity) == len(op_names)
num_weights = sorted([self.weights_numel[op_name] for op_name in op_names])
@ -267,7 +267,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
scale = target_sparsity / (total_weights_pruned / total_weights)
# rescale the sparsity
sparsity = np.asarray(sparsity) * scale
sparsity = list(np.asarray(sparsity) * scale)
return sparsity
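The rescaling in `_rescale_sparsity` keeps the relative shape of the random allocation while scaling it so the weighted average sparsity hits the target. A worked instance with illustrative numbers:

import numpy as np

num_weights = [100, 300]          # weight counts per op, sorted as above
sparsity = [0.2, 0.4]             # random per-op sparsity, aligned with num_weights
target_sparsity = 0.5

total_weights = sum(num_weights)                                          # 400
total_weights_pruned = sum(n * s for n, s in zip(num_weights, sparsity))  # 140.0
scale = target_sparsity / (total_weights_pruned / total_weights)          # ~1.4286
sparsity = list(np.asarray(sparsity) * scale)                             # [~0.286, ~0.571]
# weighted average: (100 * 0.286 + 300 * 0.571) / 400 == 0.5, the target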
def _sparsity_to_config_list(self, sparsity: List, config: Dict) -> List[Dict]:
@ -285,7 +285,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
# decrease magnitude with current temperature
magnitude = self.current_temperature / self.start_temperature * self.perturbation_magnitude
for config, current_sparsity in zip(self.target_sparsity_list, self._current_sparsity_list):
if len(current_sparsity) == 0:
if not current_sparsity:
sub_temp_config_list = [deepcopy(config) for i in range(len(config['op_names']))]
for temp_config, op_name in zip(sub_temp_config_list, config['op_names']):
temp_config.update({'total_sparsity': 0, 'op_names': [op_name]})
@ -327,11 +327,12 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
def generate_tasks(self, task_result: TaskResult) -> List[Task]:
# initial/update temp config list
if self._temp_config_list is None:
if not self._temp_config_list:
self._init_temp_config_list()
else:
score = self._tasks[task_result.task_id].score
if self._current_sparsity_list is None:
assert score is not None, 'SimulatedAnnealingTaskGenerator requires that each score is not None.'
if not self._current_sparsity_list:
self._current_sparsity_list = deepcopy(self._temp_sparsity_list)
self._current_score = score
else:

View file

@ -19,7 +19,7 @@ class ConstructHelper:
def __init__(self, callable_obj: Callable, *args, **kwargs):
assert callable(callable_obj), '`callable_obj` must be a callable object.'
self.callable_obj = callable_obj
self.args = deepcopy(args)
self.args = deepcopy(list(args))
self.kwargs = deepcopy(kwargs)
def call(self):

View file

@ -149,14 +149,14 @@ def compute_sparsity_compact2origin(origin_model: Module, compact_model: Module,
continue
if 'op_names' in config and module_name not in config['op_names']:
continue
total_weight_num += module.weight.data.numel()
total_weight_num += module.weight.data.numel() # type: ignore
for module_name, module in compact_model.named_modules():
module_type = type(module).__name__
if 'op_types' in config and module_type not in config['op_types']:
continue
if 'op_names' in config and module_name not in config['op_names']:
continue
left_weight_num += module.weight.data.numel()
left_weight_num += module.weight.data.numel() # type: ignore
compact2origin_sparsity.append(deepcopy(config))
compact2origin_sparsity[-1]['total_sparsity'] = 1 - left_weight_num / total_weight_num
return compact2origin_sparsity
@ -179,7 +179,7 @@ def compute_sparsity_mask2compact(compact_model: Module, compact_model_masks: Di
continue
if 'op_names' in config and module_name not in config['op_names']:
continue
module_weight_num = module.weight.data.numel()
module_weight_num = module.weight.data.numel() # type: ignore
total_weight_num += module_weight_num
if module_name in compact_model_masks:
weight_mask = compact_model_masks[module_name]['weight']
@ -229,7 +229,7 @@ def compute_sparsity(origin_model: Module, compact_model: Module, compact_model_
return current2origin_sparsity, compact2origin_sparsity, mask2compact_sparsity
def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}) -> Dict:
def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}) -> Tuple[Dict[str, int], Dict[str, float]]:
"""
Count the number of weight elements for the layers in config_list.
If masks is not empty, the masked weight will not be counted.
@ -248,7 +248,7 @@ def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[
masked_rate[module_name] = 1 - (weight_mask.sum().item() / weight_mask.numel())
model_weights_numel[module_name] = round(weight_mask.sum().item())
else:
model_weights_numel[module_name] = module.weight.data.numel()
model_weights_numel[module_name] = module.weight.data.numel() # type: ignore
return model_weights_numel, masked_rate
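The helpers in this file report sparsity from several angles; the compact2origin computation near the top of this section reduces to a single ratio. A toy check:

total_weight_num = 1000  # weight elements matched by one config entry in the origin model
left_weight_num = 400    # the same count measured on the compact (speeded-up) model
total_sparsity = 1 - left_weight_num / total_weight_num  # 0.6, as in compact2origin_sparsity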

View file

View file

@ -1,6 +1,11 @@
{
"ignore": [
"nni/algorithms",
"nni/algorithms/compression/pytorch",
"nni/algorithms/compression/tensorflow",
"nni/algorithms/compression/v2/pytorch/base/pruner.py",
"nni/algorithms/feature_engineering",
"nni/algorithms/hpo",
"nni/algorithms/nas",
"nni/common/device.py",
"nni/common/graph_utils.py",
"nni/compression",