[Model Compression] Add more Task Generator (#4178)

J-shang 2021-09-24 17:03:03 +08:00 committed by GitHub
Parent 7a50c96d04
Commit a16e570ddb
No key found matching this signature
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 314 additions and 13 deletions

View File

@@ -15,7 +15,8 @@ from .tools import TaskGenerator
class PruningScheduler(BasePruningScheduler):
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Callable[[Module], None] = None,
speed_up: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None):
speed_up: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None,
reset_weight: bool = False):
"""
Parameters
----------
@@ -33,6 +34,8 @@ class PruningScheduler(BasePruningScheduler):
evaluator
Evaluate the pruned model and give a score.
If evaluator is None, the best result refers to the latest result.
reset_weight
If set to True, the model weights will be reset to the origin model's weights at the end of each iteration step.
"""
self.pruner = pruner
self.task_generator = task_generator
@@ -40,6 +43,7 @@ class PruningScheduler(BasePruningScheduler):
self.speed_up = speed_up
self.dummy_input = dummy_input
self.evaluator = evaluator
self.reset_weight = reset_weight
def generate_task(self) -> Optional[Task]:
return self.task_generator.next()
@@ -47,12 +51,15 @@ class PruningScheduler(BasePruningScheduler):
def record_task_result(self, task_result: TaskResult):
self.task_generator.receive_task_result(task_result)
def pruning_one_step(self, task: Task) -> TaskResult:
def pruning_one_step_normal(self, task: Task) -> TaskResult:
"""
generate masks -> speed up -> finetune -> evaluate
"""
model, masks, config_list = task.load_data()
# pruning model
self.pruner.reset(model, config_list)
self.pruner.load_masks(masks)
# pruning model
compact_model, pruner_generated_masks = self.pruner.compress()
compact_model_masks = deepcopy(pruner_generated_masks)
@@ -75,12 +82,71 @@ class PruningScheduler(BasePruningScheduler):
self.pruner._unwrap_model()
# evaluate
score = self.evaluator(compact_model) if self.evaluator is not None else None
if self.evaluator is not None:
if self.speed_up:
score = self.evaluator(compact_model)
else:
self.pruner._wrap_model()
score = self.evaluator(compact_model)
self.pruner._unwrap_model()
else:
score = None
# clear model references
self.pruner.clear_model_references()
return TaskResult(task.task_id, compact_model, compact_model_masks, pruner_generated_masks, score)
def pruning_one_step_reset_weight(self, task: Task) -> TaskResult:
"""
finetune -> generate masks -> reset weight -> speed up -> evaluate
"""
model, masks, config_list = task.load_data()
checkpoint = deepcopy(model.state_dict())
self.pruner.reset(model, config_list)
self.pruner.load_masks(masks)
# finetune
if self.finetuner is not None:
self.finetuner(model)
# pruning model
compact_model, pruner_generated_masks = self.pruner.compress()
compact_model_masks = deepcopy(pruner_generated_masks)
# show the pruning effect
self.pruner.show_pruned_weights()
self.pruner._unwrap_model()
# reset model weight
compact_model.load_state_dict(checkpoint)
# speed up
if self.speed_up:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
# evaluate
if self.evaluator is not None:
if self.speed_up:
score = self.evaluator(compact_model)
else:
self.pruner._wrap_model()
score = self.evaluator(compact_model)
self.pruner._unwrap_model()
else:
score = None
# clear model references
self.pruner.clear_model_references()
return TaskResult(task.task_id, compact_model, compact_model_masks, pruner_generated_masks, score)
def pruning_one_step(self, task: Task) -> TaskResult:
if self.reset_weight:
return self.pruning_one_step_reset_weight(task)
else:
return self.pruning_one_step_normal(task)
def get_best_result(self) -> Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]:
return self.task_generator.get_best_result()
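Below is a minimal usage sketch (not part of this commit) showing how the new reset_weight option plugs into the scheduler. The import paths, the L1NormPruner choice, and the finetune/evaluate callbacks are assumptions for illustration only.

import torch
from torchvision.models import resnet18
# assumed export paths, based on the modules touched in this diff
from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner, PruningScheduler
from nni.algorithms.compression.v2.pytorch.pruning.tools import AGPTaskGenerator

model = resnet18()
config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5}]

def finetune(model: torch.nn.Module):
    pass  # placeholder: run a few training epochs here

def evaluate(model: torch.nn.Module) -> float:
    return 0.0  # placeholder: return a validation metric

task_generator = AGPTaskGenerator(10, model, config_list, log_dir='./log')
pruner = L1NormPruner(model, config_list)
scheduler = PruningScheduler(pruner, task_generator, finetuner=finetune, speed_up=True,
                             dummy_input=torch.rand(8, 3, 224, 224), evaluator=evaluate,
                             reset_weight=True)  # rewind the weights to the origin model each iteration
scheduler.compress()  # assumed: the base scheduler drives generate_task -> pruning_one_step -> record_task_result
_, best_model, best_masks, best_score, best_config_list = scheduler.get_best_result()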

View File

@@ -24,5 +24,7 @@ from .sparsity_allocator import (
)
from .task_generator import (
AGPTaskGenerator,
LinearTaskGenerator
LinearTaskGenerator,
LotteryTicketTaskGenerator,
SimulatedAnnealingTaskGenerator
)

View File

@@ -4,15 +4,20 @@
from copy import deepcopy
import logging
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Tuple
import json_tricks
import numpy as np
from torch import Tensor
import torch
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
from nni.algorithms.compression.v2.pytorch.utils.pruning import config_list_canonical, compute_sparsity
from nni.algorithms.compression.v2.pytorch.utils.pruning import (
config_list_canonical,
compute_sparsity,
get_model_weights_numel
)
from .base import TaskGenerator
_logger = logging.getLogger(__name__)
@@ -21,6 +26,23 @@ _logger = logging.getLogger(__name__)
class FunctionBasedTaskGenerator(TaskGenerator):
def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermidiate_result: bool = False):
"""
Parameters
----------
total_iteration
The total iteration number.
origin_model
The origin unwrapped PyTorch model to be pruned.
origin_config_list
The origin config list provided by the user. Note that this config_list directly configures the origin model.
This means the sparsity provided by origin_masks should also be recorded in origin_config_list.
origin_masks
The pre-masks on the origin model. These masks may be user-defined or generated by a previous pruning.
log_dir
The log directory used to save the task generator log.
keep_intermidiate_result
If True, keep the intermediate results, including the intermediate model and masks of each iteration.
"""
self.current_iteration = 0
self.target_sparsity = config_list_canonical(origin_model, origin_config_list)
self.total_iteration = total_iteration
@@ -54,7 +76,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity
# if the total_iteration is reached, no more tasks will be generated
if self.current_iteration >= self.total_iteration:
if self.current_iteration > self.total_iteration:
return []
task_id = self._task_id_candidate
@@ -77,9 +99,9 @@ class FunctionBasedTaskGenerator(TaskGenerator):
class AGPTaskGenerator(FunctionBasedTaskGenerator):
def generate_config_list(self, target_sparsity: List[Dict], iteration: int, model_based_sparsity: List[Dict]) -> List[Dict]:
def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
config_list = []
for target, mo in zip(target_sparsity, model_based_sparsity):
for target, mo in zip(target_sparsity, compact2origin_sparsity):
ori_sparsity = (1 - (1 - iteration / self.total_iteration) ** 3) * target['total_sparsity']
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
@@ -89,12 +111,223 @@ class AGPTaskGenerator(FunctionBasedTaskGenerator):
class LinearTaskGenerator(FunctionBasedTaskGenerator):
def generate_config_list(self, target_sparsity: List[Dict], iteration: int, model_based_sparsity: List[Dict]) -> List[Dict]:
def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
config_list = []
for target, mo in zip(target_sparsity, model_based_sparsity):
for target, mo in zip(target_sparsity, compact2origin_sparsity):
ori_sparsity = iteration / self.total_iteration * target['total_sparsity']
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity
return config_list
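A quick numeric check (not part of this commit) of the AGP and linear schedules above, together with the compact-to-origin rescaling, for a single target sparsity of 0.8 over 5 iterations, assuming each iteration exactly reaches the scheduled sparsity:

total_iteration, target = 5, 0.8
reached = 0.0  # sparsity of the compact model w.r.t. the origin model (mo['total_sparsity'])
for it in range(1, total_iteration + 1):
    agp = (1 - (1 - it / total_iteration) ** 3) * target  # AGP schedule, w.r.t. the origin model
    linear = it / total_iteration * target                 # linear schedule, for comparison
    step = max(0.0, (agp - reached) / (1 - reached))       # extra sparsity asked from the compact model
    reached = agp
    print(it, round(agp, 3), round(linear, 3), round(step, 3))
# AGP already removes about 39% of the origin weights at iteration 1 and converges to 0.8, while the
# sparsity handed to the pruner each step is always expressed relative to the already-pruned model.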
class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermidiate_result: bool = False):
super().__init__(total_iteration, origin_model, origin_config_list, origin_masks=origin_masks, log_dir=log_dir,
keep_intermidiate_result=keep_intermidiate_result)
self.current_iteration = 1
def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
config_list = []
for target, mo in zip(target_sparsity, compact2origin_sparsity):
# NOTE: The ori_sparsity calculation formula below comes from compression v1 and differs from the paper.
# The formula in the paper causes numerical problems, so the compression v1 formula is kept here.
ori_sparsity = 1 - (1 - target['total_sparsity']) ** (iteration / self.total_iteration)
# The following is the formula in paper.
# ori_sparsity = (target['total_sparsity'] * 100) ** (iteration / self.total_iteration) / 100
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity
return config_list
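The lottery ticket schedule above instead keeps the remaining-weight ratio geometric. With the same settings (a quick check, not part of this commit):

total_iteration, target = 5, 0.8
for it in range(1, total_iteration + 1):
    s = 1 - (1 - target) ** (it / total_iteration)  # formula kept from compression v1
    print(it, round(s, 3), round(1 - s, 3))          # the kept ratio shrinks by (1 - 0.8) ** (1 / 5) ~= 0.725 each step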
class SimulatedAnnealingTaskGenerator(TaskGenerator):
def __init__(self, origin_model: Module, origin_config_list: List[Dict], origin_masks: Dict[str, Dict[str, Tensor]] = {},
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermidiate_result: bool = False):
"""
Parameters
----------
origin_model
The origin unwrapped PyTorch model to be pruned.
origin_config_list
The origin config list provided by the user. Note that this config_list directly configures the origin model.
This means the sparsity provided by origin_masks should also be recorded in origin_config_list.
origin_masks
The pre-masks on the origin model. These masks may be user-defined or generated by a previous pruning.
start_temperature
Start temperature of the simulated annealing process.
stop_temperature
Stop temperature of the simulated annealing process.
cool_down_rate
Cool down rate of the temperature.
perturbation_magnitude
Initial perturbation magnitude applied to the sparsities. The magnitude decreases with the current temperature.
log_dir
The log directory used to save the task generator log.
keep_intermidiate_result
If True, keep the intermediate results, including the intermediate model and masks of each iteration.
"""
self.start_temperature = start_temperature
self.current_temperature = start_temperature
self.stop_temperature = stop_temperature
self.cool_down_rate = cool_down_rate
self.perturbation_magnitude = perturbation_magnitude
self.weights_numel, self.masked_rate = get_model_weights_numel(origin_model, origin_config_list, origin_masks)
self.target_sparsity_list = config_list_canonical(origin_model, origin_config_list)
self._adjust_target_sparsity()
self._temp_config_list = None
self._current_sparsity_list = None
self._current_score = None
super().__init__(origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
log_dir=log_dir, keep_intermidiate_result=keep_intermidiate_result)
def _adjust_target_sparsity(self):
"""
If origin_masks is not empty, then re-scale the target sparsity.
"""
if len(self.masked_rate) > 0:
for config in self.target_sparsity_list:
sparsity, op_names = config['total_sparsity'], config['op_names']
remaining_weight_numel = 0
pruned_weight_numel = 0
for name in op_names:
remaining_weight_numel += self.weights_numel[name]
if name in self.masked_rate:
pruned_weight_numel += 1 / (1 / self.masked_rate[name] - 1) * self.weights_numel[name]
config['total_sparsity'] = max(0, sparsity - pruned_weight_numel / (pruned_weight_numel + remaining_weight_numel))
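To illustrate the adjustment above (the exact meaning of weights_numel and masked_rate is an assumption, since get_model_weights_numel is not shown in this diff): suppose one op originally had 10,000 weights, 2,500 of which are already masked, and the user asks for 0.6 total sparsity.

remaining = 7500    # assumed: weights_numel reports the still-unmasked element count of the op
masked_rate = 0.25  # assumed: fraction of the op's weights already masked
pruned = 1 / (1 / masked_rate - 1) * remaining          # 0.25 / 0.75 * 7500 = 2500 already-masked weights
adjusted = max(0, 0.6 - pruned / (pruned + remaining))  # 0.6 - 0.25 = 0.35 additional sparsity to target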
def _init_temp_config_list(self):
self._temp_config_list = []
self._temp_sparsity_list = []
for config in self.target_sparsity_list:
sparsity_config, sparsity = self._init_config_sparsity(config)
self._temp_config_list.extend(sparsity_config)
self._temp_sparsity_list.append(sparsity)
def _init_config_sparsity(self, config: Dict) -> Tuple[List[Dict], List]:
assert 'total_sparsity' in config, 'Sparsity must be set in config: {}'.format(config)
target_sparsity = config['total_sparsity']
op_names = config['op_names']
if target_sparsity == 0:
return [], []
while True:
random_sparsity = sorted(np.random.uniform(0, 1, len(op_names)))
rescaled_sparsity = self._rescale_sparsity(random_sparsity, target_sparsity, op_names)
if rescaled_sparsity is not None and rescaled_sparsity[0] >= 0 and rescaled_sparsity[-1] < 1:
break
return self._sparsity_to_config_list(rescaled_sparsity, config), rescaled_sparsity
def _rescale_sparsity(self, random_sparsity: List, target_sparsity: float, op_names: List) -> List:
assert len(random_sparsity) == len(op_names)
num_weights = sorted([self.weights_numel[op_name] for op_name in op_names])
sparsity = sorted(random_sparsity)
total_weights = 0
total_weights_pruned = 0
# calculate the scale
for idx, num_weight in enumerate(num_weights):
total_weights += num_weight
total_weights_pruned += int(num_weight * sparsity[idx])
if total_weights_pruned == 0:
return None
scale = target_sparsity / (total_weights_pruned / total_weights)
# rescale the sparsity
sparsity = np.asarray(sparsity) * scale
return sparsity
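A small example (not part of this commit) of the rescaling above, for two ops with 1,000 and 4,000 weights and an overall target sparsity of 0.65:

num_weights = [1000, 4000]  # per-op weight counts, sorted ascending
proposal = [0.2, 0.6]       # sorted random sparsities, as drawn by np.random.uniform
pruned = sum(int(n * s) for n, s in zip(num_weights, proposal))  # 200 + 2400 = 2600
scale = 0.65 / (pruned / sum(num_weights))                       # 0.65 / 0.52 = 1.25
rescaled = [s * scale for s in proposal]                         # [0.25, 0.75] -> overall sparsity 0.65
# if the largest rescaled value reached 1.0, the caller would reject the proposal and sample again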
def _sparsity_to_config_list(self, sparsity: List, config: Dict) -> List[Dict]:
sparsity = sorted(sparsity)
op_names = [k for k, _ in sorted(self.weights_numel.items(), key=lambda item: item[1]) if k in config['op_names']]
assert len(sparsity) == len(op_names)
return [{'total_sparsity': sparsity, 'op_names': [op_name]} for sparsity, op_name in zip(sparsity, op_names)]
def _update_with_perturbations(self):
self._temp_config_list = []
self._temp_sparsity_list = []
# decrease magnitude with current temperature
magnitude = self.current_temperature / self.start_temperature * self.perturbation_magnitude
for config, current_sparsity in zip(self.target_sparsity_list, self._current_sparsity_list):
if len(current_sparsity) == 0:
self._temp_sparsity_list.append([])
continue
while True:
perturbation = np.random.uniform(-magnitude, magnitude, len(current_sparsity))
temp_sparsity = np.clip(current_sparsity + perturbation, 0, None)
temp_sparsity = self._rescale_sparsity(temp_sparsity, config['total_sparsity'], config['op_names'])
if temp_sparsity is not None and temp_sparsity[0] >= 0 and temp_sparsity[-1] < 1:
self._temp_config_list.extend(self._sparsity_to_config_list(temp_sparsity, config))
self._temp_sparsity_list.append(temp_sparsity)
break
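The perturbation radius shrinks with the temperature, so later iterations only explore a narrow neighborhood of the current best allocation. With the default settings (not part of this commit):

start_temperature, perturbation_magnitude, cool_down_rate = 100, 0.35, 0.9
temperature = start_temperature
while temperature >= 20:  # default stop_temperature
    magnitude = temperature / start_temperature * perturbation_magnitude
    print(round(temperature, 1), round(magnitude, 3))  # 0.35 at T=100, ~0.072 just before stopping
    temperature *= cool_down_rate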
def _recover_real_sparsity(self, config_list: List[Dict]) -> List[Dict]:
"""
If the origin masks are not empty, the sparsity in the newly generated config_list needs to be rescaled.
"""
for config in config_list:
assert len(config['op_names']) == 1
op_name = config['op_names'][0]
if op_name in self.masked_rate:
config['total_sparsity'] = self.masked_rate[op_name] + config['total_sparsity'] * (1 - self.masked_rate[op_name])
return config_list
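The recovery above maps a per-op sparsity computed on the still-unmasked weights back to a sparsity over all of the op's weights; a one-line check (not part of this commit, interpretation assumed):

masked_rate, op_sparsity = 0.25, 0.4                            # illustrative numbers
total_sparsity = masked_rate + op_sparsity * (1 - masked_rate)  # 0.25 + 0.4 * 0.75 = 0.55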
def init_pending_tasks(self) -> List[Task]:
origin_model = torch.load(self._origin_model_path)
origin_masks = torch.load(self._origin_masks_path)
self.temp_model_path = Path(self._intermidiate_result_dir, 'origin_compact_model.pth')
self.temp_masks_path = Path(self._intermidiate_result_dir, 'origin_compact_model_masks.pth')
torch.save(origin_model, self.temp_model_path)
torch.save(origin_masks, self.temp_masks_path)
task_result = TaskResult('origin', origin_model, origin_masks, origin_masks, None)
return self.generate_tasks(task_result)
def generate_tasks(self, task_result: TaskResult) -> List[Task]:
# initialize or update the temp config list
if self._temp_config_list is None:
self._init_temp_config_list()
else:
score = self._tasks[task_result.task_id].score
if self._current_sparsity_list is None:
self._current_sparsity_list = deepcopy(self._temp_sparsity_list)
self._current_score = score
else:
delta_E = np.abs(score - self._current_score)
probability = np.exp(-1 * delta_E / self.current_temperature)
if self._current_score < score or np.random.uniform(0, 1) < probability:
self._current_score = score
self._current_sparsity_list = deepcopy(self._temp_sparsity_list)
self.current_temperature *= self.cool_down_rate
if self.current_temperature < self.stop_temperature:
return []
self._update_with_perturbations()
task_id = self._task_id_candidate
new_config_list = self._recover_real_sparsity(deepcopy(self._temp_config_list))
config_list_path = Path(self._intermidiate_result_dir, '{}_config_list.json'.format(task_id))
with Path(config_list_path).open('w') as f:
json_tricks.dump(new_config_list, f, indent=4)
task = Task(task_id, self.temp_model_path, self.temp_masks_path, config_list_path)
self._tasks[task_id] = task
self._task_id_candidate += 1
return [task]
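For reference (not part of this commit), the acceptance step in generate_tasks behaves like a standard simulated annealing update: a better score is always kept, and a worse one is kept with a probability that decays as the temperature cools. With illustrative numbers:

import numpy as np

current_score, new_score = 0.70, 0.66  # e.g. validation accuracies returned by the evaluator
for temperature in (100, 50, 25):
    delta_E = np.abs(new_score - current_score)       # 0.04
    probability = np.exp(-1 * delta_E / temperature)  # ~0.9996, ~0.9992, ~0.9984
    accept = current_score < new_score or np.random.uniform(0, 1) < probability
    print(temperature, round(float(probability), 4), accept)
# for score gaps this small the probability stays close to 1 in the default temperature range, so early
# iterations accept most perturbations and the search narrows mainly through the magnitude decay above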