Mirror of https://github.com/microsoft/nni.git
[Compression] Add support for deepspeed (#5517)
Co-authored-by: xinzhang3 <xinzhang3@microsoft.com>
Parent: 739a6d6ece
Commit: 9053d651a6
@@ -23,4 +23,4 @@ graphviz
gym
sympy
tianshou >= 0.4.1
timm >= 0.5.4
@@ -18,4 +18,4 @@ timm >= 0.5.4
keras
tensorflow == 2.3
protobuf <= 3.20.1
@@ -158,3 +158,88 @@ Moreover, if you are utilizing a personalized optimizer or learning rate scheduler

    optimizer = nni.trace(torch.optim.Adam)(model.parameters(), lr=0.001)
    lr_scheduler = nni.trace(torch.optim.lr_scheduler.LambdaLR)(optimizer, lr_lambda=lambda epoch: 1 / epoch)

A complete example of using a trainer in DeepSpeed mode under the TransformersEvaluator can be found :githublink:`here <examples/model_compress/quantization/bert_quantization_with_ds.py>`.
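
For reference, a minimal sketch of that setup (the model, dataset, and config path below are illustrative placeholders, not part of the linked example):

.. code-block:: python

    # hand a DeepSpeed JSON config to the HuggingFace trainer via its training arguments
    training_args = TrainingArguments(output_dir='./output', deepspeed='./ds_config.json')
    trainer = nni.trace(Trainer)(model=model, args=training_args, train_dataset=train_dataset)
    evaluator = TransformersEvaluator(trainer)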

DeepspeedTorchEvaluator
-----------------------

:class:`DeepspeedTorchEvaluator <nni.contrib.compression.DeepspeedTorchEvaluator>` is an evaluator designed specifically for native PyTorch users who are utilizing DeepSpeed.

:class:`DeepspeedTorchEvaluator <nni.contrib.compression.DeepspeedTorchEvaluator>` has eight initialization parameters: ``training_func``, ``training_step``, ``deepspeed``, ``optimizer``, ``lr_scheduler``,
``resume_from_checkpoint_args``, ``dummy_input``, and ``evaluating_func``.

* ``training_func`` is the training loop to train the compressed model.
  It is a callable function with six input parameters: ``model``, ``optimizers``,
  ``training_step``, ``lr_schedulers``, ``max_steps``, and ``max_epochs``.
  Please make sure each input argument of the ``training_func`` is actually used,
  especially that ``max_steps`` and ``max_epochs`` correctly control the duration of training.
* ``training_step`` is a callable function whose first input argument should be ``batch``, and whose output should contain the loss.
  Three kinds of outputs are supported: a single loss tensor, a tuple whose first element is the loss, or a dict containing the key ``loss`` (see the sketch after this list).
* ``deepspeed`` is the DeepSpeed configuration, which contains the parameters needed by DeepSpeed, such as ``train_batch_size``, among others.
* ``optimizer`` is a single traced optimizer instance or a function that takes the model parameters as input and returns an optimizer instance.
  If you pass a single traced optimizer, please make sure the ``Optimizer`` class is wrapped with ``nni.trace`` before it is initialized.
* ``lr_scheduler`` is a single traced lr_scheduler instance or a function that takes the model parameters and the optimizer as input and returns an lr_scheduler instance.
  If you pass a single traced scheduler, please make sure the ``_LRScheduler`` class is wrapped with ``nni.trace`` before it is initialized.
* ``resume_from_checkpoint_args`` is used during DeepSpeed initialization to load models saved during training with DeepSpeed.
* ``dummy_input`` is used to trace the model, the same as ``example_inputs``
  in `torch.jit.trace <https://pytorch.org/docs/stable/generated/torch.jit.trace.html?highlight=torch%20jit%20trace#torch.jit.trace>`_.
* ``evaluating_func`` is a callable function to evaluate the compressed model performance. Its input is a compressed model and its output is a metric.
  The metric should be a float number or a dict with the key ``default``.
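
For illustration, a minimal sketch of the three accepted ``training_step`` return forms (the loss computation here is arbitrary):

.. code-block:: python

    def training_step(batch, model):
        loss = F.cross_entropy(model(batch[0]), batch[1])
        return loss                    # 1. a single loss tensor
        # return loss, outputs         # 2. a tuple whose first element is the loss
        # return {'loss': loss}        # 3. a dict containing the key ``loss``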

Please refer to :class:`DeepspeedTorchEvaluator <nni.contrib.compression.DeepspeedTorchEvaluator>` for more details.
Here is an example of how to initialize a :class:`DeepspeedTorchEvaluator <nni.contrib.compression.DeepspeedTorchEvaluator>`.

.. code-block:: python

    def training_step(batch, model, *args, **kwargs):
        output = model(batch[0])
        loss = F.cross_entropy(output, batch[1])
        return loss

    def training_func(model, optimizer, training_step, lr_scheduler, max_steps, max_epochs):
        # here model is an instance of DeepSpeedEngine
        assert max_steps is not None or max_epochs is not None
        total_steps = max_steps if max_steps else max_epochs * len(train_dataloader)
        total_epochs = total_steps // len(train_dataloader) + (0 if total_steps % len(train_dataloader) == 0 else 1)

        current_step = 0
        for _ in range(total_epochs):
            for batch in train_dataloader:
                loss = training_step(batch, model)
                model.backward(loss)
                model.step()

                # if we reach the total steps, exit from the training loop
                current_step = current_step + 1
                if current_step >= total_steps:
                    return

            # if you are using an epoch-wise scheduler, call it here
            lr_scheduler.step()

    ds_config = {
        "gradient_accumulation_steps": 1,
        "steps_per_print": 2000,
        "wall_clock_breakdown": False,
        "train_batch_size": 128,
        "train_micro_batch_size_per_gpu": 128,
        "zero_force_ds_cpu_optimizer": False,
        "zero_allow_untested_optimizer": True
    }

    optimizer = nni.trace(torch.optim.Adam)(model.parameters(), lr=0.001)
    lr_scheduler = nni.trace(torch.optim.lr_scheduler.LambdaLR)(optimizer, lr_lambda=lambda epoch: 1 / epoch)

    evaluator = DeepspeedTorchEvaluator(training_func, training_step, ds_config, optimizer, lr_scheduler)
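
    # if resuming from a DeepSpeed checkpoint, `resume_from_checkpoint_args` can also be passed;
    # the directory and tag below are hypothetical values for illustration
    resume_args = {'load_dir': './checkpoints', 'tag': 'global_step1000'}
    evaluator = DeepspeedTorchEvaluator(training_func, training_step, ds_config, optimizer,
                                        lr_scheduler, resume_from_checkpoint_args=resume_args)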

.. note::
    It is also worth noting that not all the arguments of :class:`DeepspeedTorchEvaluator <nni.contrib.compression.DeepspeedTorchEvaluator>` must be provided.
    Some compressors only require ``evaluating_func`` as they do not train the model, and some compressors only require ``training_func``.
    Please refer to each compressor's doc to check the required arguments.
    But it is fine to provide more arguments than the compressor needs.

A complete example can be found :githublink:`here <examples/model_compress/quantization/quantization_with_deepspeed.py>`.

@@ -21,3 +21,10 @@ TransformersEvaluator
---------------------

.. autoclass:: nni.contrib.compression.TransformersEvaluator

.. _new-deepspeed-torch-evaluator:

DeepspeedTorchEvaluator
-----------------------

.. autoclass:: nni.contrib.compression.DeepspeedTorchEvaluator
@@ -17,7 +17,7 @@
.. only:: html

  .. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
-    :alt: HPO Quickstart with PyTorch
+    :alt:

  :ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`

@@ -34,7 +34,7 @@
.. only:: html

  .. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
-    :alt: Port PyTorch Quickstart to NNI
+    :alt:

  :ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`

@@ -17,7 +17,7 @@ Tutorials
.. only:: html

  .. image:: /tutorials/images/thumb/sphx_glr_pruning_speedup_thumb.png
-    :alt: Speedup Model with Mask
+    :alt:

  :ref:`sphx_glr_tutorials_pruning_speedup.py`

@@ -34,7 +34,7 @@ Tutorials
.. only:: html

  .. image:: /tutorials/images/thumb/sphx_glr_pruning_quick_start_thumb.png
-    :alt: Pruning Quickstart
+    :alt:

  :ref:`sphx_glr_tutorials_pruning_quick_start.py`

@@ -51,7 +51,7 @@ Tutorials
.. only:: html

  .. image:: /tutorials/images/thumb/sphx_glr_nasbench_as_dataset_thumb.png
-    :alt: Use NAS Benchmarks as Datasets
+    :alt:

  :ref:`sphx_glr_tutorials_nasbench_as_dataset.py`

@@ -68,7 +68,7 @@ Tutorials
.. only:: html

  .. image:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_thumb.png
-    :alt: Quantization Quickstart
+    :alt:

  :ref:`sphx_glr_tutorials_quantization_quick_start.py`

@@ -85,7 +85,7 @@ Tutorials
.. only:: html

  .. image:: /tutorials/images/thumb/sphx_glr_quantization_speedup_thumb.png
-    :alt: Speed Up Quantized Model with TensorRT
+    :alt:

  :ref:`sphx_glr_tutorials_quantization_speedup.py`

@@ -102,7 +102,7 @@ Tutorials
.. only:: html

  .. image:: /tutorials/images/thumb/sphx_glr_hello_nas_thumb.png
-    :alt: Hello, NAS!
+    :alt:

  :ref:`sphx_glr_tutorials_hello_nas.py`

@@ -119,7 +119,7 @@ Tutorials
.. only:: html

  .. image:: /tutorials/images/thumb/sphx_glr_quantization_bert_glue_thumb.png
-    :alt: Quantize BERT on Task GLUE
+    :alt:

  :ref:`sphx_glr_tutorials_quantization_bert_glue.py`

@@ -136,7 +136,7 @@ Tutorials
.. only:: html

  .. image:: /tutorials/images/thumb/sphx_glr_darts_thumb.png
-    :alt: Searching in DARTS search space
+    :alt:

  :ref:`sphx_glr_tutorials_darts.py`

@@ -148,12 +148,12 @@ Tutorials

.. raw:: html

-    <div class="sphx-glr-thumbcontainer" tooltip="This is a new tutorial on pruning transformer in nni v3.0 (`old tutorial <https://nni.readthedo...">
+    <div class="sphx-glr-thumbcontainer" tooltip="This is a new tutorial on pruning transformer in nni v3.0 (old tutorial). The main difference b...">

.. only:: html

  .. image:: /tutorials/images/thumb/sphx_glr_new_pruning_bert_glue_thumb.png
-    :alt: Pruning Bert on Task MNLI
+    :alt:

  :ref:`sphx_glr_tutorials_new_pruning_bert_glue.py`

@@ -196,7 +196,7 @@ Tutorials
.. only:: html

  .. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png
-    :alt: HPO Quickstart with PyTorch
+    :alt:

  :ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py`

@@ -213,7 +213,7 @@ Tutorials
.. only:: html

  .. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png
-    :alt: Port PyTorch Quickstart to NNI
+    :alt:

  :ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py`

@@ -242,7 +242,7 @@ Tutorials
.. only:: html

  .. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png
-    :alt: HPO Quickstart with TensorFlow
+    :alt:

  :ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py`

@@ -259,7 +259,7 @@ Tutorials
.. only:: html

  .. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png
-    :alt: Port TensorFlow Quickstart to NNI
+    :alt:

  :ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py`

@@ -278,6 +278,7 @@ Tutorials
   :hidden:
+   :includehidden:

   /tutorials/hpo_quickstart_pytorch/index.rst
   /tutorials/hpo_quickstart_tensorflow/index.rst

@@ -0,0 +1,38 @@
{
    "fp16": {
        "enabled": false,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "gradient_accumulation_steps": 1,
    "steps_per_print": 2000,
    "wall_clock_breakdown": false,
    "train_batch_size": 128,
    "train_micro_batch_size_per_gpu": 128,
    "zero_force_ds_cpu_optimizer": false,
    "zero_allow_untested_optimizer": true
}
@@ -0,0 +1,198 @@
from pathlib import Path
import argparse
import sys

import numpy as np

import torch
from torch.utils.data import ConcatDataset

import nni

from datasets import load_dataset, load_metric
from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, EvalPrediction
from transformers.trainer import Trainer
from transformers.training_args import TrainingArguments


task_name = 'mnli'
finetune_lr = 4e-5
quant_lr = 1e-5
quant_method = 'lsq'
dev_mode = True

if dev_mode:
    quant_max_epochs = 1
    finetune_max_epochs = 1
else:
    quant_max_epochs = 10
    finetune_max_epochs = 10

# %%
# Load the pre-trained model from transformers

def build_model(pretrained_model_name_or_path: str, task_name: str):
    is_regression = task_name == 'stsb'
    num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
    model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
    return model

# %%
# Create datasets for the specific GLUE task

def prepare_datasets(task_name: str, tokenizer: BertTokenizerFast, cache_dir: str):
    task_to_keys = {
        'cola': ('sentence', None),
        'mnli': ('premise', 'hypothesis'),
        'mrpc': ('sentence1', 'sentence2'),
        'qnli': ('question', 'sentence'),
        'qqp': ('question1', 'question2'),
        'rte': ('sentence1', 'sentence2'),
        'sst2': ('sentence', None),
        'stsb': ('sentence1', 'sentence2'),
        'wnli': ('sentence1', 'sentence2'),
    }
    sentence1_key, sentence2_key = task_to_keys[task_name]

    # used to preprocess the raw data
    def preprocess_function(examples):
        # Tokenize the texts
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=False, max_length=128, truncation=True)

        if 'label' in examples:
            # In all cases, rename the column to labels because the model will expect that.
            result['labels'] = examples['label']
        return result

    raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
    for key in list(raw_datasets.keys()):
        if 'test' in key:
            raw_datasets.pop(key)

    processed_datasets = raw_datasets.map(preprocess_function, batched=True,
                                          remove_columns=raw_datasets['train'].column_names)

    train_dataset = processed_datasets['train']
    if task_name == 'mnli':
        validation_datasets = {
            'validation_matched': processed_datasets['validation_matched'],
            'validation_mismatched': processed_datasets['validation_mismatched']
        }
    else:
        validation_datasets = {
            'validation': processed_datasets['validation']
        }

    return train_dataset, validation_datasets


def prepare_traced_trainer(model, load_best_model_at_end=False, is_quant=False):
    is_regression = task_name == 'stsb'
    metric = load_metric('glue', task_name)

    def compute_metrics(p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
        result = metric.compute(predictions=preds, references=p.label_ids)
        result['default'] = result.get('f1', result.get('accuracy', 0.))
        return result

    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    train_dataset, validation_datasets = prepare_datasets(task_name, tokenizer, '')
    merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()])  # type: ignore
    data_collator = DataCollatorWithPadding(tokenizer)
    training_args = TrainingArguments(output_dir='./output/trainer',
                                      do_train=True,
                                      do_eval=True,
                                      fp16=False,
                                      learning_rate=3e-5,
                                      evaluation_strategy='steps',
                                      per_device_train_batch_size=128,
                                      per_device_eval_batch_size=128,
                                      num_train_epochs=finetune_max_epochs,
                                      dataloader_num_workers=12,
                                      save_strategy='steps',
                                      save_total_limit=1,
                                      metric_for_best_model='default',
                                      greater_is_better=True,
                                      seed=1024,
                                      eval_steps=100,
                                      deepspeed="./bert_ds_config.json",
                                      load_best_model_at_end=load_best_model_at_end,)
    # if is_quant:
    #     training_args.learning_rate = quant_lr
    # else:
    #     training_args.learning_rate = finetune_lr
    print(f"=== learning_rate:{training_args.learning_rate} ====")
    trainer = nni.trace(Trainer)(model=model,
                                 args=training_args,
                                 data_collator=data_collator,
                                 train_dataset=train_dataset,
                                 eval_dataset=merged_validation_dataset,
                                 tokenizer=tokenizer,
                                 compute_metrics=compute_metrics,
                                 )

    return trainer


def build_finetuning_model(state_dict_path: str, is_quant=False):
    model = build_model('bert-base-uncased', task_name)
    if Path(state_dict_path).exists():
        model.load_state_dict(torch.load(state_dict_path))
        print("==== load the finetuned model directly ====")
    else:
        trainer = prepare_traced_trainer(model, True, is_quant)
        trainer.train()
        torch.save(model.state_dict(), state_dict_path)
    return model


import nni
from nni.contrib.compression.quantization import QATQuantizer, LsqQuantizer, PtqQuantizer
from nni.contrib.compression.utils import TransformersEvaluator


def fake_quantize():
    config_list = [{
        'op_types': ['Linear'],
        'op_names_re': ['bert.encoder.layer.{}'.format(i) for i in range(12)],
        'target_names': ['weight', '_output_'],
        'quant_dtype': 'int8',
        'quant_scheme': 'symmetric',
        'granularity': 'default',
    }]

    # create a finetuned model
    Path('./output/bert_finetuned/').mkdir(parents=True, exist_ok=True)
    model: torch.nn.Module = build_finetuning_model(f'./output/{task_name}.bin', is_quant=False)  # type: ignore
    traced_trainer = prepare_traced_trainer(model, is_quant=False)
    evaluator = TransformersEvaluator(traced_trainer)
    if quant_method == 'lsq':
        quantizer = LsqQuantizer(model, config_list, evaluator)
        model, calibration_config = quantizer.compress(max_steps=None, max_epochs=quant_max_epochs)
    elif quant_method == 'qat':
        quantizer = QATQuantizer(model, config_list, evaluator, 10)
        model, calibration_config = quantizer.compress(max_steps=None, max_epochs=quant_max_epochs)
    elif quant_method == 'ptq':
        quantizer = PtqQuantizer(model, config_list, evaluator)
        model, calibration_config = quantizer.compress(max_steps=1, max_epochs=None)
    else:
        raise ValueError(f"quantization method {quant_method} is not supported")
    print(calibration_config)
    # evaluate the performance of the fake-quantized model
    quantizer.evaluator.bind_model(model, quantizer._get_param_names_map())


def evaluate():
    model = build_finetuning_model(f'./output/{task_name}.bin', is_quant=False)
    trainer = prepare_traced_trainer(model, is_quant=False)
    metrics = trainer.evaluate()
    print(f"Evaluate metrics={metrics}")


fake_quantize()
# evaluate()
@@ -0,0 +1,47 @@
{
    "fp16": {
        "enabled": false,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 3e-5,
            "betas": [0.99, 0.999],
            "eps": 1e-8,
            "weight_decay": 0
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "gradient_accumulation_steps": 1,
    "steps_per_print": 2000,
    "wall_clock_breakdown": false,
    "train_batch_size": 128,
    "train_micro_batch_size_per_gpu": 128,
    "zero_force_ds_cpu_optimizer": false,
    "zero_allow_untested_optimizer": true
}
@@ -0,0 +1,150 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
from typing import Callable, Union, List, Dict, Tuple

import torch
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch import Tensor

from torchvision import datasets, transforms
from deepspeed import DeepSpeedEngine
from nni.contrib.compression.quantization import LsqQuantizer
from nni.contrib.compression.utils import DeepspeedTorchEvaluator
from nni.common.types import SCHEDULER


torch.manual_seed(0)
device = 'cuda'
_TRAINING_STEP = Callable[..., Union[Tensor, Tuple[Tensor], Dict[str, Tensor]]]

datasets.MNIST(root='data/mnist', train=True, download=True)
datasets.MNIST(root='data/mnist', train=False, download=True)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_train = datasets.MNIST(root='data/mnist', train=True, transform=transform)
train_dataloader = DataLoader(mnist_train, batch_size=64)
mnist_test = datasets.MNIST(root='data/mnist', train=False, transform=transform)
test_dataloader = DataLoader(mnist_test, batch_size=1000)


class Mnist(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)
        self.relu1 = torch.nn.ReLU6()
        self.relu2 = torch.nn.ReLU6()
        self.relu3 = torch.nn.ReLU6()
        self.max_pool1 = torch.nn.MaxPool2d(2, 2)
        self.max_pool2 = torch.nn.MaxPool2d(2, 2)
        self.batchnorm1 = torch.nn.BatchNorm2d(20)

    def forward(self, x):
        x = self.relu1(self.batchnorm1(self.conv1(x)))
        x = self.max_pool1(x)
        x = self.relu2(self.conv2(x))
        x = self.max_pool2(x)
        x = x.view(-1, 4 * 4 * 50)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def training_step(batch, model) -> Tensor:
    x, y = batch[0].to(device), batch[1].to(device)
    logits = model(x)
    loss: torch.Tensor = F.nll_loss(logits, y)
    return loss


def training_model(model: DeepSpeedEngine, optimizer: Union[Optimizer, List[Optimizer]],
                   training_step: _TRAINING_STEP, scheduler: Union[None, SCHEDULER, List[SCHEDULER]] = None,
                   max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None):
    assert isinstance(model, DeepSpeedEngine)
    model.train()
    max_epochs = max_epochs or (40 if max_steps is None else 100)
    current_steps = 0
    best_acc = 0.0

    # training
    for epoch in range(max_epochs):
        print(f'Epoch {epoch} start!')
        model.train()
        for batch in train_dataloader:
            if isinstance(optimizer, Optimizer):
                optimizer.zero_grad()
            elif isinstance(optimizer, List) and all(isinstance(_, Optimizer) for _ in optimizer):
                for opt in optimizer:
                    opt.zero_grad()
            loss = training_step(batch, model)
            assert isinstance(loss, torch.Tensor)
            model.backward(loss)
            model.step()
            current_steps += 1
            if max_steps and current_steps == max_steps:
                return

        acc = evaluating_model(model)
        best_acc = max(acc, best_acc)
        print(f"epoch={epoch}\tacc={acc}\tbest_acc={best_acc}")


def evaluating_model(model: torch.nn.Module):
    model.eval()
    # testing
    correct = 0
    with torch.no_grad():
        for x, y in test_dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = torch.argmax(logits, dim=1)
            correct += preds.eq(y.view_as(preds)).sum().item()
    print(f'Accuracy: {100 * correct / len(mnist_test)}%\n')

    return correct / len(mnist_test)


def main():
    model = Mnist().to(device)
    configure_list = [{
        'target_names': ['_input_', 'weight'],
        'op_names': ['conv2'],
        'quant_dtype': 'int8',
        'quant_scheme': 'affine',
        'granularity': 'default',
    }, {
        'op_names': ['relu1', 'relu2'],
        'target_names': ['_output_'],
        'quant_dtype': 'int8',
        'quant_scheme': 'affine',
        'granularity': 'default',
    }, {
        'op_names': ['max_pool2'],
        'target_names': ['_output_'],
        'quant_dtype': 'int8',
        'quant_scheme': 'affine',
        'granularity': 'default',
    }, {
        'target_names': ['_input_', 'weight', '_output_'],
        'op_names': ['conv1'],
        'quant_dtype': 'int8',
        'quant_scheme': 'affine',
        'granularity': 'default',
        'fuse_names': [("conv1", "batchnorm1")]
    }]

    evaluator = DeepspeedTorchEvaluator(training_model, training_step, "./ds_config.json")
    quantizer = LsqQuantizer(model, configure_list, evaluator)
    model, calibration_config = quantizer.compress(None, 4)
    acc = evaluating_model(model)


if __name__ == '__main__':
    main()
@@ -334,7 +334,7 @@ def build_channel_dependency(graph_module: torch.fx.GraphModule,
            # To determine if this cat operation will introduce channel
            # dependency, we need the specific input parameters of the cat
            # operation.
-           cat_dim = node.kwargs.get('dim', None) or node.args[1]
+           cat_dim = node.kwargs.get('dim', None) if node.kwargs.get('dim', None) is not None else node.args[1]
            if cat_dim != 1:
                d_set = set(find_adjacent_layers(node, graph_module, target_types, 'parent'))

@@ -1,4 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

-from .utils.evaluator import LightningEvaluator, TorchEvaluator, TransformersEvaluator
+from .utils.evaluator import (LightningEvaluator,
+                              TorchEvaluator,
+                              TransformersEvaluator,
+                              DeepspeedTorchEvaluator)
@@ -98,7 +98,7 @@ class LsqQuantizer(Quantizer):
        # init_target = target.data.detach().abs().mean() * 2 / (target_space.qmax ** 0.5)
        init_target = torch.tensor([0.01]).to(target.device)
        if not target_space._scaler:
-           target_space.scale.data = init_target  # type: ignore
+           target_space.scale.data = init_target.view(1)  # type: ignore
            target_space.zero_point = torch.tensor(0.0).to(target.device)
        else:
            new_target = init_target.expand(target.shape).to(target.device)
@@ -117,7 +117,16 @@ class LsqQuantizer(Quantizer):
        for target_name, _ in ts.items():
            if hasattr(wrapper, f"{target_name}_scale"):
                delattr(wrapper, f"{target_name}_scale")
-           param = torch.nn.Parameter()
+           # for deepspeed
+           try:
+               device = next(wrapper.parameters()).device
+           except StopIteration:
+               try:
+                   device = next(wrapper.buffers()).device
+               except StopIteration:
+                   # NOTE: this will have risk in model parallel
+                   device = next(self.bound_model.parameters()).device
+           param = torch.nn.Parameter(torch.Tensor([0.01]).to(device))
            wrapper.register_parameter(f"{target_name}_scale", param)

    def patch_optimizer_param_group(self):
@@ -3,7 +3,11 @@

from .check_ddp import check_ddp_model, reset_ddp_model
from .dependency import auto_set_denpendency_group_ids
-from .evaluator import Evaluator, LightningEvaluator, TorchEvaluator, TransformersEvaluator
+from .evaluator import (Evaluator,
+                        LightningEvaluator,
+                        TorchEvaluator,
+                        TransformersEvaluator,
+                        DeepspeedTorchEvaluator)
from .scaling import Scaling
from .attr import (
    get_nested_attr,
@@ -22,13 +22,46 @@ except ImportError:
else:
    LIGHTNING_INSTALLED = True

try:
    import deepspeed
except ImportError:
    DEEPSPEED_INSTALLED = False
else:
    DEEPSPEED_INSTALLED = True

try:
    from transformers.trainer import Trainer as HFTrainer
    from transformers import TrainerCallback, TrainerControl, TrainerState
    from transformers import TrainingArguments
except ImportError:
    TRANSFORMERS_INSTALLED = False

    class PatchCallback:
        def on_train_begin(self, *args, **kwargs):
            raise RuntimeError("Don't use the fake PatchCallback, please install transformers")
else:
    TRANSFORMERS_INSTALLED = True

    class PatchCallback(TrainerCallback):  # type: ignore
        def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
            pass

try:
    from accelerate.utils.deepspeed import HfDeepSpeedConfig as DeepSpeedConfig  # type: ignore
except ImportError:
    ACCELERATE_INSTALLED = False

    class DeepSpeedConfig:
        def __init__(self, *args, **kwargs):
            raise RuntimeError("Don't use the fake DeepSpeedConfig, please install accelerate")

        def get_value(self, key: str):
            raise RuntimeError("Don't use the fake DeepSpeedConfig, please install accelerate")

        def del_config_sub_tree(self, key: str):
            raise RuntimeError("Don't use the fake DeepSpeedConfig, please install accelerate")

        def is_zero3(self):
            raise RuntimeError("Don't use the fake DeepSpeedConfig, please install accelerate")
else:
    ACCELERATE_INSTALLED = True

import nni
from nni.common import is_traceable
from nni.common.types import SCHEDULER
@@ -1019,6 +1052,8 @@ class TransformersEvaluator(Evaluator):
        self._ori_trainer_attr['optimizer.step'] = self.trainer.optimizer.step

    def patch_optim_param_group(self, module_name_param_dict: Dict[str, List[Tensor]]):
+       if self.trainer.args.deepspeed:
+           return
        assert isinstance(self.model, Module)
        assert module_name_param_dict is not None
        self._optimizer_add_param_group(self.model, module_name_param_dict, self.trainer.optimizer)
@@ -1032,6 +1067,7 @@ class TransformersEvaluator(Evaluator):
            self.trainer.optimizer = None
            self._param_names_map = None
            self._ori_trainer_attr.pop('compute_loss', None)
+           self.trainer.remove_callback(PatchCallback)
            self.trainer = None  # type: ignore
            self.model = None
        else:
@@ -1053,19 +1089,27 @@ class TransformersEvaluator(Evaluator):
        self.trainer.compute_loss = self._ori_trainer_attr['compute_loss']

    def patch_optimizer_step(self, before_step_tasks: List[Callable], after_step_tasks: List[Callable]):
-       assert self.trainer.optimizer is not None
-       old_step = self.trainer.optimizer.step
+       def custom_on_train_begin(_, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+           optimizer = self.trainer.deepspeed if hasattr(self.trainer, "deepspeed") else self.trainer.callback_handler.optimizer
+           assert optimizer is not None
+           old_step = optimizer.step

-       def patched_step(_, *args, **kwargs):
-           for task in before_step_tasks:
-               task()
-           # call origin optimizer step method
-           output = old_step(*args, **kwargs)
-           for task in after_step_tasks:
-               task()
-           return output
+           def patched_step(_, *args, **kwargs):
+               for task in before_step_tasks:
+                   task()
+               # call origin optimizer step method
+               output = old_step(*args, **kwargs)
+               for task in after_step_tasks:
+                   task()
+               return output

-       self.trainer.optimizer.step = types.MethodType(patched_step, self.trainer.optimizer)
+           optimizer.step = types.MethodType(patched_step, optimizer)
+
+       PatchCallback.on_train_begin = types.MethodType(custom_on_train_begin, PatchCallback)
+
+       # Add Callback into the callback_handler
+       self.trainer.add_callback(PatchCallback)

    def revert_optimizer_step(self):
        assert self.trainer.optimizer is not None
@@ -1100,3 +1144,502 @@ class TransformersEvaluator(Evaluator):

    def get_dummy_input(self) -> Any:
        return self.dummy_input


class DeepspeedTorchEvaluator(Evaluator):
    """
    The DeepspeedTorchEvaluator is an evaluator designed specifically for native PyTorch users who are utilizing DeepSpeed.

    Parameters
    ----------
    training_func
        The training function is used to train the model; note that this is an entire optimization training loop.
        The training function has three required parameters, ``model``, ``optimizers`` and ``training_step``,
        and three optional parameters, ``lr_schedulers``, ``max_steps``, ``max_epochs``.

        Let's explain these six parameters NNI passes in, but in most cases, users don't need to care about these.
        Users only need to treat these six parameters as the original parameters during the training process.

        * The ``model`` is a wrapped model from the original model; it has a similar structure to the model to be pruned,
          so it can share the training function with the original model.
        * ``optimizers`` are re-initialized from the ``optimizers`` passed to the evaluator and the wrapped model's parameters.
        * ``training_step`` is also based on the ``training_step`` passed to the evaluator;
          it might be modified by the compressor during model compression.
        * If users use ``lr_schedulers`` in the ``training_func``, NNI will re-initialize the ``lr_schedulers`` with the re-initialized
          optimizers.
        * ``max_steps`` is the NNI training duration limitation. It is for the pruner (or quantizer) to control the number of training steps.
          The user-implemented ``training_func`` should respect ``max_steps`` by stopping the training loop after ``max_steps`` is reached.
          The pruner may pass ``None`` to ``max_steps`` when it only controls ``max_epochs``.
        * ``max_epochs`` is similar to ``max_steps``; the only difference is that it controls the number of training epochs.
          The user-implemented ``training_func`` should respect ``max_epochs`` by stopping the training loop
          after ``max_epochs`` is reached. The pruner may pass ``None`` to ``max_epochs`` when it only controls ``max_steps``.

        Note that when the pruner passes ``None`` to both ``max_steps`` and ``max_epochs``,
        it treats ``training_func`` as a function for model fine-tuning.
        Users should assign proper values to ``max_steps`` and ``max_epochs``.

        .. code-block:: python

            def training_func(model: DeepSpeedEngine, optimizers: torch.optim.Optimizer,
                              training_step: Callable[[Any, Any], torch.Tensor],
                              lr_schedulers: _LRScheduler | None = None, max_steps: int | None = None,
                              max_epochs: int | None = None, *args, **kwargs):
                ...
                total_epochs = max_epochs if max_epochs else 20
                total_steps = max_steps if max_steps else 1000000
                current_steps = 0
                ...
                for epoch in range(total_epochs):
                    ...
                    model.backward(loss)
                    model.step()
                    if current_steps >= total_steps:
                        return

        Note that ``optimizers`` and ``lr_schedulers`` passed to the ``training_func`` have the same type as the ``optimizers``
        and ``lr_schedulers`` passed to the evaluator: a single ``torch.optim.Optimizer`` / ``torch.optim._LRScheduler`` instance or
        a list of them.

    training_step
        A callable function whose first input argument should be ``batch``, and whose output should contain the loss.
        Three kinds of outputs are supported: a single loss tensor, a tuple whose first element is the loss, or a dict containing the key ``loss``.

        .. code-block:: python

            def training_step(batch, model, ...):
                inputs, labels = batch
                output = model(inputs)
                ...
                loss = loss_func(output, labels)
                return loss

    deepspeed
        Str | dict. The DeepSpeed configuration, which contains the parameters needed by DeepSpeed, such as ``train_batch_size``, among others.
    optimizer
        Optional. A single traced optimizer instance or a function that takes the model parameters
        as input and returns an optimizer instance. NNI may modify the ``torch.optim.Optimizer`` member function ``step``
        and/or optimize compressed models, so NNI needs to have the ability to re-initialize the optimizer. ``nni.trace`` can
        record the initialization parameters of a function/class, which can then be used by NNI to re-initialize the
        optimizer for a new but structurally similar model. E.g., ``traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())``.
    lr_scheduler
        Optional. A single traced lr_scheduler instance or a function that takes the model parameters and the optimizer as input
        and returns an lr_scheduler instance. For the same reason as with ``optimizers``, NNI needs the traced lr_scheduler
        to re-initialize it.
        E.g., ``traced_lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)``.
    resume_from_checkpoint_args
        Dict | None. Used in the deepspeed_init process to load models saved during training with DeepSpeed.
        Let's explain the seven elements in ``resume_from_checkpoint_args``.

        * ``load_dir``: The directory to load the checkpoint from.
        * ``tag``: Checkpoint tag used as a unique identifier for the checkpoint; if not provided, will attempt to load the tag in the 'latest' file.
        * ``load_module_strict``: Optional. Boolean to strictly enforce that the keys in the state_dict of the module and the checkpoint match.
        * ``load_optimizer_states``: Optional. Boolean to load the training optimizer states from the checkpoint.
        * ``load_lr_scheduler_states``: Optional. Boolean to add the learning rate scheduler states from the checkpoint.
        * ``load_module_only``: Optional. Boolean to load only the model weights from the checkpoint.
        * ``custom_load_fn``: Optional. Custom model load function.
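        For illustration, a minimal sketch of such a dict (the directory and tag below are hypothetical values, not defaults):

        .. code-block:: python

            resume_from_checkpoint_args = {
                'load_dir': './checkpoints',   # hypothetical directory holding a DeepSpeed checkpoint
                'tag': 'global_step1000',      # hypothetical tag; omit it to fall back to the 'latest' file
                'load_optimizer_states': True,
                'load_lr_scheduler_states': True,
            }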
    dummy_input
        Optional. The dummy_input is used to trace the graph; it's the same as ``example_inputs`` in
        `torch.jit.trace <https://pytorch.org/docs/stable/generated/torch.jit.trace.html?highlight=torch%20jit%20trace#torch.jit.trace>`_.
    evaluating_func
        Optional. A function whose input is the model and which returns the evaluation metric.
        This is the function used to evaluate the compressed model performance.
        The input is a model and the output is a ``float`` metric or a ``dict``
        (the ``dict`` should contain the key ``default`` with a ``float`` value).
        NNI will take the float number as the model score, and assumes a higher score means better performance.
        If you want to provide additional information, please put it into a dict
        and NNI will take the value of the key ``default`` as the evaluation metric.

    Notes
    -----
    It is also worth noting that not all the arguments of ``DeepspeedTorchEvaluator`` must be provided.
    Some pruners (or quantizers) only require ``evaluating_func`` as they do not train the model,
    and some pruners (or quantizers) only require ``training_func``.
    Please refer to each pruner's (or quantizer's) doc to check the required arguments.
    But it is fine to provide more arguments than the pruner (or quantizer) needs.
    """
    def __init__(self, training_func: _TRAINING_FUNC, training_step: _TRAINING_STEP, deepspeed: str | Dict,
                 optimizer: Optimizer | Callable[[List[Tensor]], Optimizer] | None = None,
                 lr_scheduler: SCHEDULER | Callable[[Optimizer], SCHEDULER] | None = None,
                 resume_from_checkpoint_args: Dict | None = None, dummy_input: Any | None = None,
                 evaluating_func: _EVALUATING_FUNC | None = None):
        assert ACCELERATE_INSTALLED, "accelerate is not installed"
        assert DEEPSPEED_INSTALLED, "deepspeed is not installed"
        self.training_func = training_func
        self._ori_training_step = training_step
        self._training_step = self._ori_training_step
        self.dummy_input = dummy_input
        self.evaluating_func = evaluating_func
        self.resume_from_checkpoint_args = resume_from_checkpoint_args
        self._ori_optimizer_step = None

        self.model: Module | None = None
        self.optimizer: Optimizer | Callable[[List[Tensor]], Optimizer] | None = None
        self.lr_scheduler: SCHEDULER | Callable[[Optimizer], SCHEDULER] | None = None
        self._ori_optimizer_step: Callable | None = None
        self._param_names_map: Dict[str, str] | None = None
        self.deepspeed_engine = None

        # will del self._tmp_optimizer and self._tmp_lr_scheduler in `_init_optimizer_helpers`
        self._tmp_optimizer: Optimizer | Callable[[List[Tensor]], Optimizer] | None = optimizer
        self._tmp_lr_scheduler: SCHEDULER | Callable[[Optimizer], SCHEDULER] | None = lr_scheduler
        self._initialization_complete = False

        self.deepspeed_config: DeepSpeedConfig | None = self.process_deepspeed(deepspeed)

    def process_deepspeed(self, config_file_or_dict: str | Dict) -> DeepSpeedConfig:
        if config_file_or_dict is None:
            raise ValueError('deepspeed_config should not be None')
        assert isinstance(config_file_or_dict, (Dict, str)), \
            f"Only two types: Dict and str are supported for config_file_or_dict, but got {type(config_file_or_dict)}"
        return DeepSpeedConfig(config_file_or_dict)

    def check_optim_sched(self) -> None:
        assert self._tmp_optimizer is None or isinstance(self._tmp_optimizer, Optimizer) or callable(self._tmp_optimizer)
        assert self._tmp_lr_scheduler is None or isinstance(self._tmp_lr_scheduler, SCHEDULER) or callable(self._tmp_lr_scheduler)
        # check the validity of the optimizer
        if isinstance(self._tmp_optimizer, Optimizer):
            assert is_traceable(self._tmp_optimizer)
        # check the validity of the scheduler
        if isinstance(self._tmp_lr_scheduler, SCHEDULER):
            assert is_traceable(self._tmp_lr_scheduler)

        # there are 9 cases:
        # case 1: opt = None, sche = None, depends on the optimizer configuration in deepspeed_config
        # case 2: opt = Callable, sche = None, ok
        # case 3: opt = Optim, sche = None, ok
        # case 4: opt = None, sche = Callable, depends on the optimizer configuration in deepspeed_config
        # case 5: opt = Callable, sche = Callable, ok
        # case 6: opt = Optim, sche = Callable, ok
        # case 7: opt = None, sche = Scheduler, X
        # case 8: opt = Callable, sche = Scheduler, X
        # case 9: opt = Optim, sche = Scheduler, ok
        assert hasattr(self, "deepspeed_config") and self.deepspeed_config is not None
        if self._tmp_optimizer is not None and self.deepspeed_config.get_value('optimizer') is not None:
            raise ValueError("Please provide the optimizer during the evaluator's initialization or in the " +
                             "deepspeed_config, but don't provide both at the same time.")

        # case 1: the optimizer is None and the config is None
        if self._tmp_optimizer is None and self.deepspeed_config.get_value('optimizer') is None:
            raise ValueError("The optimizer and the optimizer configuration in the deepspeed config " +
                             "can't be None at the same time, please provide one")
        # cases 7 and 8: the optimizer is Callable or None, but the scheduler is a SCHEDULER instance
        if not isinstance(self._tmp_optimizer, Optimizer) and isinstance(self._tmp_lr_scheduler, SCHEDULER):
            raise ValueError("A non-instance optimizer and instance scheduler pair is not supported")
    def _init_optimizer_helpers(self, pure_model: Module):
        assert self._initialization_complete is False, 'Evaluator initialization is already complete.'
        # check the validity of the optimizer and scheduler
        self.check_optim_sched()
        if isinstance(self._tmp_optimizer, Optimizer):
            self._optimizer_helper = OptimizerConstructHelper.from_trace(pure_model, self._tmp_optimizer)
        else:
            self.optimizer = self._tmp_optimizer
        if isinstance(self._tmp_lr_scheduler, SCHEDULER):
            self._lr_scheduler_helper = LRSchedulerConstructHelper.from_trace(self._tmp_lr_scheduler)
        else:
            self.lr_scheduler = self._tmp_lr_scheduler

        delattr(self, '_tmp_optimizer')
        delattr(self, '_tmp_lr_scheduler')
        self._initialization_complete = True

    def _rewrap_if_ddp_model(self, model):
        errmsg = "model is None, no need to rewrap the model as a DistributedDataParallel model"
        assert model is not None, errmsg
        is_ddp_model, _ = check_ddp_model(model)

        if is_ddp_model:
            raise RuntimeError("DeepSpeed provides the DDP logic, so your model should not be wrapped with DistributedDataParallel")

        return model

    def bind_model(self, model: Module, param_names_map: Dict[str, str] | None = None):
        err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before binding a model.'
        assert self._initialization_complete is True, err_msg
        assert isinstance(model, Module)
        if self.model is not None:
            _logger.warning('Already bound a model, will unbind it before binding a new model.')
            self.unbind_model()

        is_ddp_model, _ = check_ddp_model(model)
        assert not is_ddp_model, \
            "DeepSpeed will automatically initialize the distributed environment during its initialization"
        self.model = model
        self._param_names_map = param_names_map
        # initialize optimizers & lr_schedulers for the bound model here
        if hasattr(self, '_optimizer_helper'):
            self.optimizer = self._optimizer_helper.call(model, param_names_map)
        if hasattr(self, '_lr_scheduler_helper'):
            self.lr_scheduler = self._lr_scheduler_helper.call(self.optimizer)  # type: ignore
    def deepspeed_init(self, inference=False):
        assert self.model is not None
        assert self.deepspeed_config is not None
        # whether to check the validity of the params
        deepspeed_config: DeepSpeedConfig = deepcopy(self.deepspeed_config)
        config: Dict = deepspeed_config.config  # type: ignore
        if inference:
            # only Z3 makes sense for the inference
            if not deepspeed_config.is_zero3():
                raise ValueError("ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config")

            # in case the training config is re-used for inference
            deepspeed_config.del_config_sub_tree("optimizer")
            deepspeed_config.del_config_sub_tree("lr_scheduler")
            optimizer, lr_scheduler = None, None
            model_parameters = None
        else:
            model_parameters = list(filter(lambda p: p.requires_grad, self.model.parameters()))
            optimizer, lr_scheduler = self.optimizer, self.lr_scheduler

        # deepspeed init
        kwargs = {
            "model": self.model,
            "model_parameters": model_parameters,
            "config_params": config,
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
        }
        deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
        # load the deepspeed checkpoint
        if self.resume_from_checkpoint_args is not None:
            # it's possible that the user is trying to resume from model_path, which doesn't necessarily
            # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
            # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
            # path contains what looks like a deepspeed checkpoint
            resume_from_checkpoint = self.resume_from_checkpoint_args.get("load_dir", None)
            assert resume_from_checkpoint is not None
            tag = self.resume_from_checkpoint_args.get('tag', "global_step")
            load_module_strict = self.resume_from_checkpoint_args.get('load_module_strict', True)
            load_optimizer_states = self.resume_from_checkpoint_args.get('load_optimizer_states', True)
            load_lr_scheduler_states = self.resume_from_checkpoint_args.get('load_lr_scheduler_states', True)
            load_module_only = self.resume_from_checkpoint_args.get('load_module_only', False)
            custom_load_fn = self.resume_from_checkpoint_args.get("custom_load_fn", None)
            # copied from transformers
            import glob
            # TODO: add loading the model from a tag
            deepspeed_checkpoint_dirs = sorted(glob.glob(f"{resume_from_checkpoint}/{tag}*"))

            if len(deepspeed_checkpoint_dirs) > 0:
                # logger.info(f"Attempting to resume from {self.resume_from_checkpoint}")
                # this magically updates self.optimizer and self.lr_scheduler
                # load_path, _ = deepspeed_engine.load_checkpoint(
                #     self.resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
                # )
                load_path, _ = self.load_checkpoint(resume_from_checkpoint,
                                                    tag,
                                                    load_module_strict=load_module_strict,
                                                    load_optimizer_states=load_optimizer_states,
                                                    load_lr_scheduler_states=load_lr_scheduler_states,
                                                    load_module_only=load_module_only,
                                                    custom_load_fn=custom_load_fn)
                if load_path is None:
                    raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}")
            else:
                raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
        # record the original deepspeed step function
        self._ori_optimizer_step = deepspeed_engine.step

        return deepspeed_engine, optimizer, lr_scheduler
    def patch_optim_param_group(self, module_name_param_dict: Dict[str, List[Tensor]]):
        if not isinstance(self.optimizer, Optimizer):
            return
        # used for adding a param_group without a deepspeed config
        assert isinstance(self.model, Module)
        assert module_name_param_dict is not None
        self._optimizer_add_param_group(self.model, module_name_param_dict, [self.optimizer])  # type: ignore

    def unbind_model(self):
        if self.model:
            self.revert_loss()
            self.revert_optimizer_step()
            self.remove_all_hooks()
            self.lr_scheduler = None
            self.optimizer = None
            self._param_names_map = None
            self.model = None
            # TODO: check if unbinding the deepspeed params is needed
            self.deepspeed_engine = None
            self.deepspeed_config = None
        else:
            _logger.warning('Did not bind any model, no need to unbind model.')

    def patch_loss(self, patch: Callable[[Tensor, Any], Tensor]):
        old_training_step = self._training_step

        def patched_training_step(*args, **kwargs):
            out = old_training_step(*args, **kwargs)
            # we assume in training_step, ``batch`` is the first argument
            batch = args[0] if len(args) > 0 else kwargs['batch']
            if isinstance(out, Tensor):
                out = patch(out, batch)
            elif isinstance(out, Sequence) and not isinstance(out, str):
                assert isinstance(out[0], Tensor)
                new_loss = patch(out[0], batch)
                out = (new_loss,) + tuple(out[1:])
            elif isinstance(out, MutableMapping):
                assert 'loss' in out and isinstance(out['loss'], Tensor)
                out['loss'] = patch(out['loss'], batch)
            return out

        self._training_step: _TRAINING_STEP = patched_training_step

    def revert_loss(self):
        self._training_step = self._ori_training_step

    def patch_optimizer_step(self, before_step_tasks: List[Callable], after_step_tasks: List[Callable]):
        self.is_patch_optim_for_ds = True
        self.before_step_tasks = before_step_tasks
        self.after_step_tasks = after_step_tasks

    def revert_optimizer_step(self):
        assert self.deepspeed_engine is not None
        if self._ori_optimizer_step is not None:
            self.deepspeed_engine.step = self._ori_optimizer_step

    def patch_engine_step(self, before_step_tasks: List[Callable], after_step_tasks: List[Callable]):
        assert self.deepspeed_engine is not None
        old_step = self.deepspeed_engine.step

        def patched_step(_, *args, **kwargs):
            for task in before_step_tasks:
                task()
            # call the original optimizer step method
            output = old_step(*args, **kwargs)
            for task in after_step_tasks:
                task()
            return output

        self.deepspeed_engine.step = types.MethodType(patched_step, self.deepspeed_engine)

    def train(self, max_steps: int | None = None, max_epochs: int | None = None):
        # deepspeed init
        deepspeed_engine, optimizer, lr_scheduler = self.deepspeed_init(inference=False)
        self.deepspeed_engine = deepspeed_engine
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.model = deepspeed_engine.module

        assert self.deepspeed_engine is not None
        assert self.optimizer is not None
        assert self._training_step is not None

        if hasattr(self, 'is_patch_optim_for_ds') and self.is_patch_optim_for_ds:
            self.patch_engine_step(self.before_step_tasks, self.after_step_tasks)

        self.training_func(self.deepspeed_engine, self.optimizer, self._training_step, self.lr_scheduler, max_steps, max_epochs)

    def finetune(self):
        self.train()

    def evaluate(self) -> float | None | Tuple[float, Dict[str, Any]] | Tuple[None, Dict[str, Any]]:
        # assert self.model is not None
        if self.evaluating_func is None:
            warn_msg = f'Did not pass evaluating_func to {self.__class__.__name__}, will return None for calling evaluate()'
            _logger.warning(warn_msg)
            return None

        if self.deepspeed_engine is None:
            deepspeed_engine, _, _ = self.deepspeed_init(inference=True)
            self.deepspeed_engine = deepspeed_engine
            self.model = self.deepspeed_engine.module

        assert self.deepspeed_engine is not None

        metric = self.evaluating_func(self.deepspeed_engine)
        if isinstance(metric, dict):
            nni_used_metric = metric.get('default', None)
            if nni_used_metric is None:
                warn_msg = 'Evaluation function returns a dict metric without the key `default`, ' + \
                           'will return None as the model evaluation metric value.'
                _logger.warning(warn_msg)
            return nni_used_metric, metric
        else:
            return metric

    def get_dummy_input(self) -> Any:
        return self.dummy_input
    def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True):
        """
        Save a training checkpoint.

        Parameters
        ----------
        save_dir
            Required. Directory for saving the checkpoint.
        tag
            Optional. Checkpoint tag used as a unique identifier for the checkpoint; the global step is
            used if not provided. The tag name must be the same across all ranks.
        client_state
            Optional. State dictionary used for saving required training states in the client code.
        save_latest
            Optional. Save a file 'latest' pointing to the latest saved checkpoint.

        Notes
        -----
        Important: all processes must call this method and not just the process with rank 0. It is
        because each process needs to save its master weights and scheduler+optimizer states. This
        method will hang waiting to synchronize with other processes if it's called just for the
        process with rank 0.
        """

        # copied from deepspeed
        assert self.deepspeed_engine is not None
        return self.deepspeed_engine.save_checkpoint(save_dir, tag, client_state=client_state,
                                                     save_latest=save_latest)

    def load_checkpoint(self,
                        load_dir,
                        tag=None,
                        load_module_strict=True,
                        load_optimizer_states=True,
                        load_lr_scheduler_states=True,
                        load_module_only=False,
                        custom_load_fn=None):
        """
        Load a training checkpoint.

        Parameters
        ----------
        load_dir
            Required. Directory to load the checkpoint from.
        tag
            Checkpoint tag used as a unique identifier for the checkpoint; if not provided, will attempt to load the tag in the 'latest' file.
        load_module_strict
            Optional. Boolean to strictly enforce that the keys in the state_dict of the module and the checkpoint match.
        load_optimizer_states
            Optional. Boolean to load the training optimizer states from the checkpoint. Ex. ADAM's momentum and variance.
        load_lr_scheduler_states
            Optional. Boolean to add the learning rate scheduler states from the checkpoint.
        load_module_only
            Optional. Boolean to load only the model weights from the checkpoint. Ex. warmstarting.
        custom_load_fn
            Optional. Custom model load function.

        Returns
        -------
        load_path
            Path of the loaded checkpoint. None if loading the checkpoint failed.
        client_state
            State dictionary used for loading required training states in the client code.

        Notes
        -----
        Important: under ZeRO3, one cannot load a checkpoint with ``engine.load_checkpoint()`` right
        after ``engine.save_checkpoint()``. It is because ``engine.module`` is partitioned, and
        ``load_checkpoint()`` wants a pristine model. If insisting to do so, please reinitialize the engine
        before ``load_checkpoint()``.
        """

        # copied from deepspeed
        assert self.deepspeed_engine is not None
        return self.deepspeed_engine.load_checkpoint(load_dir,
                                                     tag=tag,
                                                     load_module_strict=load_module_strict,
                                                     load_optimizer_states=load_optimizer_states,
                                                     load_lr_scheduler_states=load_lr_scheduler_states,
                                                     load_module_only=load_module_only,
                                                     custom_load_fn=custom_load_fn)

@@ -33,14 +33,6 @@ steps:
    displayName: (Ubuntu) Downgrade swig
    condition: and(succeeded(), contains('${{ parameters.platform }}', 'ubuntu'))

-  - script: |
-      set -e
-      brew install 'swig@3'
-      rm /usr/local/bin/swig
-      ln -s '/usr/local/opt/swig@3/bin/swig' /usr/local/bin/swig
-    displayName: (macOS) Downgrade swig
-    condition: and(succeeded(), contains('${{ parameters.platform }}', 'macos'))
-
  - script: |
      set -e
      azcopy copy 'https://nni.blob.core.windows.net/cache/dependencies-${{ parameters.platform }}.zip?$(sas_cache)' dependencies.zip

@@ -74,4 +74,4 @@ sudo systemctl disable apt-daily-upgrade.service

# Deprovision
sudo /usr/sbin/waagent -force -deprovision
sudo HISTSIZE=0 sync