зеркало из https://github.com/microsoft/nni.git
[Model Compression] Add more Task Generator (#4178)
This commit is contained in:
Родитель
7a50c96d04
Коммит
a16e570ddb
|
@ -15,7 +15,8 @@ from .tools import TaskGenerator
|
|||
|
||||
class PruningScheduler(BasePruningScheduler):
|
||||
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Callable[[Module], None] = None,
|
||||
speed_up: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None):
|
||||
speed_up: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None,
|
||||
reset_weight: bool = False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
@ -33,6 +34,8 @@ class PruningScheduler(BasePruningScheduler):
|
|||
evaluator
|
||||
Evaluate the pruned model and give a score.
|
||||
If evaluator is None, the best result refers to the latest result.
|
||||
reset_weight
|
||||
If set True, the model weight will reset to the origin model weight at the end of each iteration step.
|
||||
"""
|
||||
self.pruner = pruner
|
||||
self.task_generator = task_generator
|
||||
|
@ -40,6 +43,7 @@ class PruningScheduler(BasePruningScheduler):
|
|||
self.speed_up = speed_up
|
||||
self.dummy_input = dummy_input
|
||||
self.evaluator = evaluator
|
||||
self.reset_weight = reset_weight
|
||||
|
||||
def generate_task(self) -> Optional[Task]:
|
||||
return self.task_generator.next()
|
||||
|
@ -47,12 +51,15 @@ class PruningScheduler(BasePruningScheduler):
|
|||
def record_task_result(self, task_result: TaskResult):
|
||||
self.task_generator.receive_task_result(task_result)
|
||||
|
||||
def pruning_one_step(self, task: Task) -> TaskResult:
|
||||
def pruning_one_step_normal(self, task: Task) -> TaskResult:
|
||||
"""
|
||||
generate masks -> speed up -> finetune -> evaluate
|
||||
"""
|
||||
model, masks, config_list = task.load_data()
|
||||
|
||||
# pruning model
|
||||
self.pruner.reset(model, config_list)
|
||||
self.pruner.load_masks(masks)
|
||||
|
||||
# pruning model
|
||||
compact_model, pruner_generated_masks = self.pruner.compress()
|
||||
compact_model_masks = deepcopy(pruner_generated_masks)
|
||||
|
||||
|
@ -75,12 +82,71 @@ class PruningScheduler(BasePruningScheduler):
|
|||
self.pruner._unwrap_model()
|
||||
|
||||
# evaluate
|
||||
score = self.evaluator(compact_model) if self.evaluator is not None else None
|
||||
if self.evaluator is not None:
|
||||
if self.speed_up:
|
||||
score = self.evaluator(compact_model)
|
||||
else:
|
||||
self.pruner._wrap_model()
|
||||
score = self.evaluator(compact_model)
|
||||
self.pruner._unwrap_model()
|
||||
else:
|
||||
score = None
|
||||
|
||||
# clear model references
|
||||
self.pruner.clear_model_references()
|
||||
|
||||
return TaskResult(task.task_id, compact_model, compact_model_masks, pruner_generated_masks, score)
|
||||
|
||||
def pruning_one_step_reset_weight(self, task: Task) -> TaskResult:
|
||||
"""
|
||||
finetune -> generate masks -> reset weight -> speed up -> evaluate
|
||||
"""
|
||||
model, masks, config_list = task.load_data()
|
||||
checkpoint = deepcopy(model.state_dict())
|
||||
self.pruner.reset(model, config_list)
|
||||
self.pruner.load_masks(masks)
|
||||
|
||||
# finetune
|
||||
if self.finetuner is not None:
|
||||
self.finetuner(model)
|
||||
|
||||
# pruning model
|
||||
compact_model, pruner_generated_masks = self.pruner.compress()
|
||||
compact_model_masks = deepcopy(pruner_generated_masks)
|
||||
|
||||
# show the pruning effect
|
||||
self.pruner.show_pruned_weights()
|
||||
self.pruner._unwrap_model()
|
||||
|
||||
# reset model weight
|
||||
compact_model.load_state_dict(checkpoint)
|
||||
|
||||
# speed up
|
||||
if self.speed_up:
|
||||
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
|
||||
compact_model_masks = {}
|
||||
|
||||
# evaluate
|
||||
if self.evaluator is not None:
|
||||
if self.speed_up:
|
||||
score = self.evaluator(compact_model)
|
||||
else:
|
||||
self.pruner._wrap_model()
|
||||
score = self.evaluator(compact_model)
|
||||
self.pruner._unwrap_model()
|
||||
else:
|
||||
score = None
|
||||
|
||||
# clear model references
|
||||
self.pruner.clear_model_references()
|
||||
|
||||
return TaskResult(task.task_id, compact_model, compact_model_masks, pruner_generated_masks, score)
|
||||
|
||||
def pruning_one_step(self, task: Task) -> TaskResult:
|
||||
if self.reset_weight:
|
||||
return self.pruning_one_step_reset_weight(task)
|
||||
else:
|
||||
return self.pruning_one_step_normal(task)
|
||||
|
||||
def get_best_result(self) -> Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]:
|
||||
return self.task_generator.get_best_result()
|
||||
|
|
|
@ -24,5 +24,7 @@ from .sparsity_allocator import (
|
|||
)
|
||||
from .task_generator import (
|
||||
AGPTaskGenerator,
|
||||
LinearTaskGenerator
|
||||
LinearTaskGenerator,
|
||||
LotteryTicketTaskGenerator,
|
||||
SimulatedAnnealingTaskGenerator
|
||||
)
|
||||
|
|
|
@ -4,15 +4,20 @@
|
|||
from copy import deepcopy
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Tuple
|
||||
import json_tricks
|
||||
|
||||
import numpy as np
|
||||
from torch import Tensor
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
||||
from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
|
||||
from nni.algorithms.compression.v2.pytorch.utils.pruning import config_list_canonical, compute_sparsity
|
||||
from nni.algorithms.compression.v2.pytorch.utils.pruning import (
|
||||
config_list_canonical,
|
||||
compute_sparsity,
|
||||
get_model_weights_numel
|
||||
)
|
||||
from .base import TaskGenerator
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
@ -21,6 +26,23 @@ _logger = logging.getLogger(__name__)
|
|||
class FunctionBasedTaskGenerator(TaskGenerator):
|
||||
def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
|
||||
origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermidiate_result: bool = False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
total_iteration
|
||||
The total iteration number.
|
||||
origin_model
|
||||
The origin unwrapped pytorch model to be pruned.
|
||||
origin_config_list
|
||||
The origin config list provided by the user. Note that this config_list is directly config the origin model.
|
||||
This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
|
||||
origin_masks
|
||||
The pre masks on the origin model. This mask maybe user-defined or maybe generate by previous pruning.
|
||||
log_dir
|
||||
The log directory use to saving the task generator log.
|
||||
keep_intermidiate_result
|
||||
If keeping the intermediate result, including intermediate model and masks during each iteration.
|
||||
"""
|
||||
self.current_iteration = 0
|
||||
self.target_sparsity = config_list_canonical(origin_model, origin_config_list)
|
||||
self.total_iteration = total_iteration
|
||||
|
@ -54,7 +76,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
|
|||
self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity
|
||||
|
||||
# if reach the total_iteration, no more task will be generated
|
||||
if self.current_iteration >= self.total_iteration:
|
||||
if self.current_iteration > self.total_iteration:
|
||||
return []
|
||||
|
||||
task_id = self._task_id_candidate
|
||||
|
@ -77,9 +99,9 @@ class FunctionBasedTaskGenerator(TaskGenerator):
|
|||
|
||||
|
||||
class AGPTaskGenerator(FunctionBasedTaskGenerator):
|
||||
def generate_config_list(self, target_sparsity: List[Dict], iteration: int, model_based_sparsity: List[Dict]) -> List[Dict]:
|
||||
def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
|
||||
config_list = []
|
||||
for target, mo in zip(target_sparsity, model_based_sparsity):
|
||||
for target, mo in zip(target_sparsity, compact2origin_sparsity):
|
||||
ori_sparsity = (1 - (1 - iteration / self.total_iteration) ** 3) * target['total_sparsity']
|
||||
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
|
||||
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
|
||||
|
@ -89,12 +111,223 @@ class AGPTaskGenerator(FunctionBasedTaskGenerator):
|
|||
|
||||
|
||||
class LinearTaskGenerator(FunctionBasedTaskGenerator):
|
||||
def generate_config_list(self, target_sparsity: List[Dict], iteration: int, model_based_sparsity: List[Dict]) -> List[Dict]:
|
||||
def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
|
||||
config_list = []
|
||||
for target, mo in zip(target_sparsity, model_based_sparsity):
|
||||
for target, mo in zip(target_sparsity, compact2origin_sparsity):
|
||||
ori_sparsity = iteration / self.total_iteration * target['total_sparsity']
|
||||
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
|
||||
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
|
||||
config_list.append(deepcopy(target))
|
||||
config_list[-1]['total_sparsity'] = sparsity
|
||||
return config_list
|
||||
|
||||
|
||||
class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
|
||||
def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
|
||||
origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermidiate_result: bool = False):
|
||||
super().__init__(total_iteration, origin_model, origin_config_list, origin_masks=origin_masks, log_dir=log_dir,
|
||||
keep_intermidiate_result=keep_intermidiate_result)
|
||||
self.current_iteration = 1
|
||||
|
||||
def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
|
||||
config_list = []
|
||||
for target, mo in zip(target_sparsity, compact2origin_sparsity):
|
||||
# NOTE: The ori_sparsity calculation formula in compression v1 is as follow, it is different from the paper.
|
||||
# But the formula in paper will cause numerical problems, so keep the formula in compression v1.
|
||||
ori_sparsity = 1 - (1 - target['total_sparsity']) ** (iteration / self.total_iteration)
|
||||
# The following is the formula in paper.
|
||||
# ori_sparsity = (target['total_sparsity'] * 100) ** (iteration / self.total_iteration) / 100
|
||||
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
|
||||
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
|
||||
config_list.append(deepcopy(target))
|
||||
config_list[-1]['total_sparsity'] = sparsity
|
||||
return config_list
|
||||
|
||||
|
||||
class SimulatedAnnealingTaskGenerator(TaskGenerator):
|
||||
def __init__(self, origin_model: Module, origin_config_list: List[Dict], origin_masks: Dict[str, Dict[str, Tensor]] = {},
|
||||
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
|
||||
perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermidiate_result: bool = False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
origin_model
|
||||
The origin unwrapped pytorch model to be pruned.
|
||||
origin_config_list
|
||||
The origin config list provided by the user. Note that this config_list is directly config the origin model.
|
||||
This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
|
||||
origin_masks
|
||||
The pre masks on the origin model. This mask maybe user-defined or maybe generate by previous pruning.
|
||||
start_temperature
|
||||
Start temperature of the simulated annealing process.
|
||||
stop_temperature
|
||||
Stop temperature of the simulated annealing process.
|
||||
cool_down_rate
|
||||
Cool down rate of the temperature.
|
||||
perturbation_magnitude
|
||||
Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
|
||||
log_dir
|
||||
The log directory use to saving the task generator log.
|
||||
keep_intermidiate_result
|
||||
If keeping the intermediate result, including intermediate model and masks during each iteration.
|
||||
"""
|
||||
self.start_temperature = start_temperature
|
||||
self.current_temperature = start_temperature
|
||||
self.stop_temperature = stop_temperature
|
||||
self.cool_down_rate = cool_down_rate
|
||||
self.perturbation_magnitude = perturbation_magnitude
|
||||
|
||||
self.weights_numel, self.masked_rate = get_model_weights_numel(origin_model, origin_config_list, origin_masks)
|
||||
self.target_sparsity_list = config_list_canonical(origin_model, origin_config_list)
|
||||
self._adjust_target_sparsity()
|
||||
|
||||
self._temp_config_list = None
|
||||
self._current_sparsity_list = None
|
||||
self._current_score = None
|
||||
|
||||
super().__init__(origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
|
||||
log_dir=log_dir, keep_intermidiate_result=keep_intermidiate_result)
|
||||
|
||||
def _adjust_target_sparsity(self):
|
||||
"""
|
||||
If origin_masks is not empty, then re-scale the target sparsity.
|
||||
"""
|
||||
if len(self.masked_rate) > 0:
|
||||
for config in self.target_sparsity_list:
|
||||
sparsity, op_names = config['total_sparsity'], config['op_names']
|
||||
remaining_weight_numel = 0
|
||||
pruned_weight_numel = 0
|
||||
for name in op_names:
|
||||
remaining_weight_numel += self.weights_numel[name]
|
||||
if name in self.masked_rate:
|
||||
pruned_weight_numel += 1 / (1 / self.masked_rate[name] - 1) * self.weights_numel[name]
|
||||
config['total_sparsity'] = max(0, sparsity - pruned_weight_numel / (pruned_weight_numel + remaining_weight_numel))
|
||||
|
||||
def _init_temp_config_list(self):
|
||||
self._temp_config_list = []
|
||||
self._temp_sparsity_list = []
|
||||
for config in self.target_sparsity_list:
|
||||
sparsity_config, sparsity = self._init_config_sparsity(config)
|
||||
self._temp_config_list.extend(sparsity_config)
|
||||
self._temp_sparsity_list.append(sparsity)
|
||||
|
||||
def _init_config_sparsity(self, config: Dict) -> Tuple[List[Dict], List]:
|
||||
assert 'total_sparsity' in config, 'Sparsity must be set in config: {}'.format(config)
|
||||
target_sparsity = config['total_sparsity']
|
||||
op_names = config['op_names']
|
||||
|
||||
if target_sparsity == 0:
|
||||
return [], []
|
||||
|
||||
while True:
|
||||
random_sparsity = sorted(np.random.uniform(0, 1, len(op_names)))
|
||||
rescaled_sparsity = self._rescale_sparsity(random_sparsity, target_sparsity, op_names)
|
||||
if rescaled_sparsity is not None and rescaled_sparsity[0] >= 0 and rescaled_sparsity[-1] < 1:
|
||||
break
|
||||
|
||||
return self._sparsity_to_config_list(rescaled_sparsity, config), rescaled_sparsity
|
||||
|
||||
def _rescale_sparsity(self, random_sparsity: List, target_sparsity: float, op_names: List) -> List:
|
||||
assert len(random_sparsity) == len(op_names)
|
||||
|
||||
num_weights = sorted([self.weights_numel[op_name] for op_name in op_names])
|
||||
sparsity = sorted(random_sparsity)
|
||||
|
||||
total_weights = 0
|
||||
total_weights_pruned = 0
|
||||
|
||||
# calculate the scale
|
||||
for idx, num_weight in enumerate(num_weights):
|
||||
total_weights += num_weight
|
||||
total_weights_pruned += int(num_weight * sparsity[idx])
|
||||
if total_weights_pruned == 0:
|
||||
return None
|
||||
|
||||
scale = target_sparsity / (total_weights_pruned / total_weights)
|
||||
|
||||
# rescale the sparsity
|
||||
sparsity = np.asarray(sparsity) * scale
|
||||
return sparsity
|
||||
|
||||
def _sparsity_to_config_list(self, sparsity: List, config: Dict) -> List[Dict]:
|
||||
sparsity = sorted(sparsity)
|
||||
op_names = [k for k, _ in sorted(self.weights_numel.items(), key=lambda item: item[1]) if k in config['op_names']]
|
||||
assert len(sparsity) == len(op_names)
|
||||
return [{'total_sparsity': sparsity, 'op_names': [op_name]} for sparsity, op_name in zip(sparsity, op_names)]
|
||||
|
||||
def _update_with_perturbations(self):
|
||||
self._temp_config_list = []
|
||||
self._temp_sparsity_list = []
|
||||
# decrease magnitude with current temperature
|
||||
magnitude = self.current_temperature / self.start_temperature * self.perturbation_magnitude
|
||||
for config, current_sparsity in zip(self.target_sparsity_list, self._current_sparsity_list):
|
||||
if len(current_sparsity) == 0:
|
||||
self._temp_sparsity_list.append([])
|
||||
continue
|
||||
while True:
|
||||
perturbation = np.random.uniform(-magnitude, magnitude, len(current_sparsity))
|
||||
temp_sparsity = np.clip(0, current_sparsity + perturbation, None)
|
||||
temp_sparsity = self._rescale_sparsity(temp_sparsity, config['total_sparsity'], config['op_names'])
|
||||
if temp_sparsity is not None and temp_sparsity[0] >= 0 and temp_sparsity[-1] < 1:
|
||||
self._temp_config_list.extend(self._sparsity_to_config_list(temp_sparsity, config))
|
||||
self._temp_sparsity_list.append(temp_sparsity)
|
||||
break
|
||||
|
||||
def _recover_real_sparsity(self, config_list: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
If the origin masks is not None, then the sparsity in new generated config_list need to be rescaled.
|
||||
"""
|
||||
for config in config_list:
|
||||
assert len(config['op_names']) == 1
|
||||
op_name = config['op_names'][0]
|
||||
if op_name in self.masked_rate:
|
||||
config['total_sparsity'] = self.masked_rate[op_name] + config['total_sparsity'] * (1 - self.masked_rate[op_name])
|
||||
return config_list
|
||||
|
||||
def init_pending_tasks(self) -> List[Task]:
|
||||
origin_model = torch.load(self._origin_model_path)
|
||||
origin_masks = torch.load(self._origin_masks_path)
|
||||
|
||||
self.temp_model_path = Path(self._intermidiate_result_dir, 'origin_compact_model.pth')
|
||||
self.temp_masks_path = Path(self._intermidiate_result_dir, 'origin_compact_model_masks.pth')
|
||||
torch.save(origin_model, self.temp_model_path)
|
||||
torch.save(origin_masks, self.temp_masks_path)
|
||||
|
||||
task_result = TaskResult('origin', origin_model, origin_masks, origin_masks, None)
|
||||
|
||||
return self.generate_tasks(task_result)
|
||||
|
||||
def generate_tasks(self, task_result: TaskResult) -> List[Task]:
|
||||
# initial/update temp config list
|
||||
if self._temp_config_list is None:
|
||||
self._init_temp_config_list()
|
||||
else:
|
||||
score = self._tasks[task_result.task_id].score
|
||||
if self._current_sparsity_list is None:
|
||||
self._current_sparsity_list = deepcopy(self._temp_sparsity_list)
|
||||
self._current_score = score
|
||||
else:
|
||||
delta_E = np.abs(score - self._current_score)
|
||||
probability = np.exp(-1 * delta_E / self.current_temperature)
|
||||
if self._current_score < score or np.random.uniform(0, 1) < probability:
|
||||
self._current_score = score
|
||||
self._current_sparsity_list = deepcopy(self._temp_sparsity_list)
|
||||
self.current_temperature *= self.cool_down_rate
|
||||
if self.current_temperature < self.stop_temperature:
|
||||
return []
|
||||
self._update_with_perturbations()
|
||||
|
||||
task_id = self._task_id_candidate
|
||||
new_config_list = self._recover_real_sparsity(deepcopy(self._temp_config_list))
|
||||
config_list_path = Path(self._intermidiate_result_dir, '{}_config_list.json'.format(task_id))
|
||||
|
||||
with Path(config_list_path).open('w') as f:
|
||||
json_tricks.dump(new_config_list, f, indent=4)
|
||||
|
||||
task = Task(task_id, self.temp_model_path, self.temp_masks_path, config_list_path)
|
||||
|
||||
self._tasks[task_id] = task
|
||||
|
||||
self._task_id_candidate += 1
|
||||
|
||||
return [task]
|
||||
|
|
Загрузка…
Ссылка в новой задаче