diff --git a/dependencies/recommended_gpu_legacy.txt b/dependencies/recommended_gpu_legacy.txt
index f33ace927..a927242f0 100644
--- a/dependencies/recommended_gpu_legacy.txt
+++ b/dependencies/recommended_gpu_legacy.txt
@@ -23,4 +23,4 @@ graphviz
 gym
 sympy
 tianshou >= 0.4.1
-timm >= 0.5.4
+timm >= 0.5.4
\ No newline at end of file
diff --git a/dependencies/recommended_legacy.txt b/dependencies/recommended_legacy.txt
index 49128ac8d..8d8e1ee04 100644
--- a/dependencies/recommended_legacy.txt
+++ b/dependencies/recommended_legacy.txt
@@ -18,4 +18,4 @@ timm >= 0.5.4
 keras
 tensorflow == 2.3
-protobuf <= 3.20.1
+protobuf <= 3.20.1
\ No newline at end of file
diff --git a/docs/source/compression/evaluator.rst b/docs/source/compression/evaluator.rst
index 05623b352..d5785dcdd 100644
--- a/docs/source/compression/evaluator.rst
+++ b/docs/source/compression/evaluator.rst
@@ -158,3 +158,88 @@ Moreover, if you are utilizing a personalized optimizer or learning rate schedul
     optimizer = nni.trace(torch.optim.Adam)(model.parameters(), lr=0.001)
     lr_scheduler = nni.trace(torch.optim.lr_scheduler.LambdaLR)(optimizer, lr_lambda=lambda epoch: 1 / epoch)
+
+
+A complete example of using a trainer in DeepSpeed mode with the TransformersEvaluator can be found :githublink:`here `.
+
+
+DeepspeedTorchEvaluator
+-----------------------
+
+:class:`DeepspeedTorchEvaluator ` is an evaluator designed specifically for native PyTorch users who are utilizing DeepSpeed.
+
+:class:`DeepspeedTorchEvaluator ` has eight initialization parameters: ``training_func``, ``training_step``, ``deepspeed``, ``optimizer``, ``lr_scheduler``,
+``resume_from_checkpoint_args``, ``dummy_input`` and ``evaluating_func``.
+
+* ``training_func`` is the training loop to train the compressed model.
+  It is a callable function with six input parameters ``model``, ``optimizers``,
+  ``training_step``, ``lr_schedulers``, ``max_steps``, ``max_epochs``.
+  Please make sure each input argument of ``training_func`` is actually used;
+  in particular, ``max_steps`` and ``max_epochs`` should correctly control the duration of training.
+* ``training_step`` is a callable function whose first input argument should be ``batch`` and whose output should contain the loss.
+  Three kinds of outputs are supported: a single loss, a tuple whose first element is the loss, or a dict containing the key ``loss``.
+* ``deepspeed`` is the DeepSpeed configuration, which contains the parameters needed by DeepSpeed, such as ``train_batch_size``, among others.
+* ``optimizer`` is a single traced optimizer instance or a function that takes the model parameters as input and returns an optimizer instance.
+  If a single traced optimizer is passed, please make sure the ``Optimizer`` class is wrapped by ``nni.trace`` before it is initialized.
+* ``lr_scheduler`` is a single traced lr_scheduler instance or a function that takes the model parameters and the optimizer as input and returns an lr_scheduler instance.
+  If a single traced scheduler is passed, please make sure the ``_LRScheduler`` class is wrapped by ``nni.trace`` before it is initialized.
+* ``resume_from_checkpoint_args`` is used during the DeepSpeed initialization process to load models saved during training with DeepSpeed.
+* ``dummy_input`` is used to trace the model, the same as ``example_inputs``
+  in `torch.jit.trace `_.
+* ``evaluating_func`` is a callable function to evaluate the compressed model performance. Its input is the compressed model and its output is a metric.
+  The metric should be a float number or a dict with the key ``default``.
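+
+For reference, here is a minimal sketch of the three accepted ``training_step`` return formats
+(the model call and loss computation are illustrative, and ``F`` is assumed to be ``torch.nn.functional``):
+
+.. code-block:: python
+
+    def training_step_single_loss(batch, model):
+        logits = model(batch[0])
+        return F.cross_entropy(logits, batch[1])  # a single loss tensor
+
+    def training_step_tuple(batch, model):
+        logits = model(batch[0])
+        loss = F.cross_entropy(logits, batch[1])
+        return loss, logits  # a tuple whose first element is the loss
+
+    def training_step_dict(batch, model):
+        logits = model(batch[0])
+        loss = F.cross_entropy(logits, batch[1])
+        return {'loss': loss, 'logits': logits}  # a dict containing the key 'loss'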
+
+Please refer to :class:`DeepspeedTorchEvaluator ` for more details.
+Here is an example of how to initialize a :class:`DeepspeedTorchEvaluator `.
+
+.. code-block:: python
+
+    def training_step(batch, model, *args, **kwargs):
+        output = model(batch[0])
+        loss = F.cross_entropy(output, batch[1])
+        return loss
+
+    def training_func(model, optimizer, training_step, lr_scheduler, max_steps, max_epochs):
+        # here model is an instance of DeepSpeedEngine
+        assert max_steps is not None or max_epochs is not None
+        total_steps = max_steps if max_steps else max_epochs * len(train_dataloader)
+        total_epochs = total_steps // len(train_dataloader) + (0 if total_steps % len(train_dataloader) == 0 else 1)
+
+        current_step = 0
+        for _ in range(total_epochs):
+            for batch in train_dataloader:
+                loss = training_step(batch, model)
+                model.backward(loss)
+                model.step()
+
+                # if the total number of steps is reached, exit from the training loop
+                current_step = current_step + 1
+                if current_step >= total_steps:
+                    return
+
+            # if you are using an epoch-wise scheduler, call it here
+            lr_scheduler.step()
+
+    ds_config = {
+        "gradient_accumulation_steps": 1,
+        "steps_per_print": 2000,
+        "wall_clock_breakdown": False,
+        "train_batch_size": 128,
+        "train_micro_batch_size_per_gpu": 128,
+        "zero_force_ds_cpu_optimizer": False,
+        "zero_allow_untested_optimizer": True
+    }
+
+    optimizer = nni.trace(torch.optim.Adam)(model.parameters(), lr=0.001)
+    lr_scheduler = nni.trace(torch.optim.lr_scheduler.LambdaLR)(optimizer, lr_lambda=lambda epoch: 1 / epoch)
+
+    evaluator = DeepspeedTorchEvaluator(training_func, training_step, ds_config, optimizer, lr_scheduler)
+
+.. note::
+    It is also worth noting that not all the arguments of :class:`DeepspeedTorchEvaluator ` must be provided.
+    Some compressors only require ``evaluating_func`` as they do not train the model, and some compressors only require ``training_func``.
+    Please refer to each compressor's doc to check the required arguments.
+    But it is fine to provide more arguments than the compressor needs.
+
+
+A complete example can be found :githublink:`here `.
\ No newline at end of file
diff --git a/docs/source/reference/compression/evaluator.rst b/docs/source/reference/compression/evaluator.rst
index efc011562..d7401ba95 100644
--- a/docs/source/reference/compression/evaluator.rst
+++ b/docs/source/reference/compression/evaluator.rst
@@ -21,3 +21,10 @@ TransformersEvaluator
 ---------------------
 
 .. autoclass:: nni.contrib.compression.TransformersEvaluator
+
+.. _new-deepspeed-torch-evaluator:
+
+DeepspeedTorchEvaluator
+-----------------------
+
+.. autoclass:: nni.contrib.compression.DeepspeedTorchEvaluator
diff --git a/docs/source/tutorials/hpo_quickstart_pytorch/index.rst b/docs/source/tutorials/hpo_quickstart_pytorch/index.rst
index ccd068fd7..dc70977af 100644
--- a/docs/source/tutorials/hpo_quickstart_pytorch/index.rst
+++ b/docs/source/tutorials/hpo_quickstart_pytorch/index.rst
@@ -17,7 +17,7 @@ .. only:: html .. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png - :alt: HPO Quickstart with PyTorch + :alt: :ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py` @@ -34,7 +34,7 @@ .. only:: html ..
image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png - :alt: Port PyTorch Quickstart to NNI + :alt: :ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py` diff --git a/docs/source/tutorials/index.rst b/docs/source/tutorials/index.rst index f78ec6016..10ea3eb74 100644 --- a/docs/source/tutorials/index.rst +++ b/docs/source/tutorials/index.rst @@ -17,7 +17,7 @@ Tutorials .. only:: html .. image:: /tutorials/images/thumb/sphx_glr_pruning_speedup_thumb.png - :alt: Speedup Model with Mask + :alt: :ref:`sphx_glr_tutorials_pruning_speedup.py` @@ -34,7 +34,7 @@ Tutorials .. only:: html .. image:: /tutorials/images/thumb/sphx_glr_pruning_quick_start_thumb.png - :alt: Pruning Quickstart + :alt: :ref:`sphx_glr_tutorials_pruning_quick_start.py` @@ -51,7 +51,7 @@ Tutorials .. only:: html .. image:: /tutorials/images/thumb/sphx_glr_nasbench_as_dataset_thumb.png - :alt: Use NAS Benchmarks as Datasets + :alt: :ref:`sphx_glr_tutorials_nasbench_as_dataset.py` @@ -68,7 +68,7 @@ Tutorials .. only:: html .. image:: /tutorials/images/thumb/sphx_glr_quantization_quick_start_thumb.png - :alt: Quantization Quickstart + :alt: :ref:`sphx_glr_tutorials_quantization_quick_start.py` @@ -85,7 +85,7 @@ Tutorials .. only:: html .. image:: /tutorials/images/thumb/sphx_glr_quantization_speedup_thumb.png - :alt: Speed Up Quantized Model with TensorRT + :alt: :ref:`sphx_glr_tutorials_quantization_speedup.py` @@ -102,7 +102,7 @@ Tutorials .. only:: html .. image:: /tutorials/images/thumb/sphx_glr_hello_nas_thumb.png - :alt: Hello, NAS! + :alt: :ref:`sphx_glr_tutorials_hello_nas.py` @@ -119,7 +119,7 @@ Tutorials .. only:: html .. image:: /tutorials/images/thumb/sphx_glr_quantization_bert_glue_thumb.png - :alt: Quantize BERT on Task GLUE + :alt: :ref:`sphx_glr_tutorials_quantization_bert_glue.py` @@ -136,7 +136,7 @@ Tutorials .. only:: html .. image:: /tutorials/images/thumb/sphx_glr_darts_thumb.png - :alt: Searching in DARTS search space + :alt: :ref:`sphx_glr_tutorials_darts.py` @@ -148,12 +148,12 @@ Tutorials .. raw:: html -
+
.. only:: html .. image:: /tutorials/images/thumb/sphx_glr_new_pruning_bert_glue_thumb.png - :alt: Pruning Bert on Task MNLI + :alt: :ref:`sphx_glr_tutorials_new_pruning_bert_glue.py` @@ -196,7 +196,7 @@ Tutorials .. only:: html .. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_main_thumb.png - :alt: HPO Quickstart with PyTorch + :alt: :ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_main.py` @@ -213,7 +213,7 @@ Tutorials .. only:: html .. image:: /tutorials/hpo_quickstart_pytorch/images/thumb/sphx_glr_model_thumb.png - :alt: Port PyTorch Quickstart to NNI + :alt: :ref:`sphx_glr_tutorials_hpo_quickstart_pytorch_model.py` @@ -242,7 +242,7 @@ Tutorials .. only:: html .. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_main_thumb.png - :alt: HPO Quickstart with TensorFlow + :alt: :ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_main.py` @@ -259,7 +259,7 @@ Tutorials .. only:: html .. image:: /tutorials/hpo_quickstart_tensorflow/images/thumb/sphx_glr_model_thumb.png - :alt: Port TensorFlow Quickstart to NNI + :alt: :ref:`sphx_glr_tutorials_hpo_quickstart_tensorflow_model.py` @@ -278,6 +278,7 @@ Tutorials :hidden: :includehidden: + /tutorials/hpo_quickstart_pytorch/index.rst /tutorials/hpo_quickstart_tensorflow/index.rst diff --git a/examples/model_compress/quantization/bert_ds_config.json b/examples/model_compress/quantization/bert_ds_config.json new file mode 100644 index 000000000..180c938b6 --- /dev/null +++ b/examples/model_compress/quantization/bert_ds_config.json @@ -0,0 +1,38 @@ +{ + "fp16": { + "enabled": false, + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "offload_param": { + "device": "cpu", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + + "gradient_accumulation_steps": 1, + "steps_per_print": 2000, + "wall_clock_breakdown": false, + "train_batch_size": 128, + "train_micro_batch_size_per_gpu": 128, + "zero_force_ds_cpu_optimizer": false, + "zero_allow_untested_optimizer": true +} \ No newline at end of file diff --git a/examples/model_compress/quantization/bert_quantization_with_ds.py b/examples/model_compress/quantization/bert_quantization_with_ds.py new file mode 100644 index 000000000..c1dbd84f1 --- /dev/null +++ b/examples/model_compress/quantization/bert_quantization_with_ds.py @@ -0,0 +1,198 @@ +from pathlib import Path +import argparse +import sys + +import numpy as np + +import torch +from torch.utils.data import ConcatDataset + +import nni + +from datasets import load_dataset, load_metric +from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, EvalPrediction +from transformers.trainer import Trainer +from transformers.training_args import TrainingArguments + + +task_name = 'mnli' +finetune_lr = 4e-5 +quant_lr = 1e-5 +quant_method = 'lsq' +dev_mode = True + +if dev_mode: + quant_max_epochs = 1 + finetune_max_epochs = 1 +else: + quant_max_epochs = 10 + finetune_max_epochs = 10 + +# %% +# Load the pre-trained model from the transformers + +def build_model(pretrained_model_name_or_path: str, task_name: 
str): + is_regression = task_name == 'stsb' + num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2) + model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels) + return model + +# %% +# Create datasets on the specific task GLUE + +def prepare_datasets(task_name: str, tokenizer: BertTokenizerFast, cache_dir: str): + task_to_keys = { + 'cola': ('sentence', None), + 'mnli': ('premise', 'hypothesis'), + 'mrpc': ('sentence1', 'sentence2'), + 'qnli': ('question', 'sentence'), + 'qqp': ('question1', 'question2'), + 'rte': ('sentence1', 'sentence2'), + 'sst2': ('sentence', None), + 'stsb': ('sentence1', 'sentence2'), + 'wnli': ('sentence1', 'sentence2'), + } + sentence1_key, sentence2_key = task_to_keys[task_name] + + # used to preprocess the raw data + def preprocess_function(examples): + # Tokenize the texts + args = ( + (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) + ) + result = tokenizer(*args, padding=False, max_length=128, truncation=True) + + if 'label' in examples: + # In all cases, rename the column to labels because the model will expect that. + result['labels'] = examples['label'] + return result + + raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir) + for key in list(raw_datasets.keys()): + if 'test' in key: + raw_datasets.pop(key) + + processed_datasets = raw_datasets.map(preprocess_function, batched=True, + remove_columns=raw_datasets['train'].column_names) + + train_dataset = processed_datasets['train'] + if task_name == 'mnli': + validation_datasets = { + 'validation_matched': processed_datasets['validation_matched'], + 'validation_mismatched': processed_datasets['validation_mismatched'] + } + else: + validation_datasets = { + 'validation': processed_datasets['validation'] + } + + return train_dataset, validation_datasets + + +def prepare_traced_trainer(model, load_best_model_at_end=False, is_quant=False): + is_regression = task_name == 'stsb' + metric = load_metric('glue', task_name) + + def compute_metrics(p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) + result = metric.compute(predictions=preds, references=p.label_ids) + result['default'] = result.get('f1', result.get('accuracy', 0.)) + return result + + tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') + train_dataset, validation_datasets = prepare_datasets(task_name, tokenizer, '') + merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()]) # type: ignore + data_collator = DataCollatorWithPadding(tokenizer) + training_args = TrainingArguments(output_dir='./output/trainer', + do_train=True, + do_eval=True, + fp16=False, + learning_rate=3e-5, + evaluation_strategy='steps', + per_device_train_batch_size=128, #128, + per_device_eval_batch_size=128, #128, + num_train_epochs=finetune_max_epochs, + dataloader_num_workers=12, + save_strategy='steps', + save_total_limit=1, + metric_for_best_model='default', + greater_is_better=True, + seed=1024, + eval_steps=100, + deepspeed="./bert_ds_config.json", + load_best_model_at_end=load_best_model_at_end,) + # if is_quant: + # training_args.learning_rate = quant_lr + # else: + # training_args.learning_rate = finetune_lr + print(f"=== learning_rate:{training_args.learning_rate} ====") + trainer = nni.trace(Trainer)(model=model, + args=training_args, + data_collator=data_collator, 
+ train_dataset=train_dataset, + eval_dataset=merged_validation_dataset, + tokenizer=tokenizer, + compute_metrics=compute_metrics, + ) + + return trainer + + +def build_finetuning_model(state_dict_path: str, is_quant=False): + model = build_model('bert-base-uncased', task_name) + if Path(state_dict_path).exists(): + model.load_state_dict(torch.load(state_dict_path)) + print(f"==== load finetune model directly ====") + else: + trainer = prepare_traced_trainer(model, True, is_quant) + trainer.train() + torch.save(model.state_dict(), state_dict_path) + return model + + +import nni +from nni.contrib.compression.quantization import QATQuantizer, LsqQuantizer, PtqQuantizer +from nni.contrib.compression.utils import TransformersEvaluator + + +def fake_quantize(): + config_list = [{ + 'op_types': ['Linear'], + 'op_names_re': ['bert.encoder.layer.{}'.format(i) for i in range(12)], + 'target_names': ['weight', '_output_'], + 'quant_dtype': 'int8', + 'quant_scheme': 'symmetric', + 'granularity': 'default', + }] + + # create a finetune model + Path('./output/bert_finetuned/').mkdir(parents=True, exist_ok=True) + model: torch.nn.Module = build_finetuning_model(f'./output/{task_name}.bin', is_quant=False) # type: ignore + traced_trainer = prepare_traced_trainer(model, is_quant=False) + evaluator = TransformersEvaluator(traced_trainer) + if quant_method == 'lsq': + quantizer = LsqQuantizer(model, config_list, evaluator) + model, calibration_config = quantizer.compress(max_steps=None, max_epochs=quant_max_epochs) + elif quant_method == 'qat': + quantizer = QATQuantizer(model, config_list, evaluator, 10) + model, calibration_config = quantizer.compress(max_steps=None, max_epochs=quant_max_epochs) + elif quant_method == 'ptq': + quantizer = PtqQuantizer(model, config_list, evaluator) + model, calibration_config = quantizer.compress(max_steps=1, max_epochs=None) + else: + raise ValueError(f"quantization method {quant_method} is not supported") + print(calibration_config) + # evaluate the performance of the fake quantize model + quantizer.evaluator.bind_model(model, quantizer._get_param_names_map()) + + +def evaluate(): + model = build_finetuning_model(f'./output/{task_name}.bin', is_quant=False) + trainer = prepare_traced_trainer(model, is_quant=False) + metrics = trainer.evaluate() + print(f"Evaluate metrics={metrics}") + + +fake_quantize() +# evaluate() diff --git a/examples/model_compress/quantization/ds_config.json b/examples/model_compress/quantization/ds_config.json new file mode 100644 index 000000000..57c00aa6f --- /dev/null +++ b/examples/model_compress/quantization/ds_config.json @@ -0,0 +1,47 @@ +{ + "fp16": { + "enabled": false, + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 3e-5, + "betas": [0.99, 0.999], + "eps": 1e-8, + "weight_decay": 0 + } + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "offload_param": { + "device": "cpu", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + + "gradient_accumulation_steps": 1, + "steps_per_print": 2000, + "wall_clock_breakdown": false, + "train_batch_size": 
128, + "train_micro_batch_size_per_gpu": 128, + "zero_force_ds_cpu_optimizer": false, + "zero_allow_untested_optimizer": true +} \ No newline at end of file diff --git a/examples/model_compress/quantization/quantization_with_deepspeed.py b/examples/model_compress/quantization/quantization_with_deepspeed.py new file mode 100644 index 000000000..6a0001021 --- /dev/null +++ b/examples/model_compress/quantization/quantization_with_deepspeed.py @@ -0,0 +1,150 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import sys +from typing import Callable, Union, List, Dict, Tuple + +import torch +import torch.nn.functional as F +from torch.optim import Optimizer +from torch.utils.data import DataLoader +from torch import Tensor + +from torchvision import datasets, transforms +from deepspeed import DeepSpeedEngine +from nni.contrib.compression.quantization import LsqQuantizer +from nni.contrib.compression.utils import DeepspeedTorchEvaluator +from nni.common.types import SCHEDULER + + +torch.manual_seed(0) +device = 'cuda' +_TRAINING_STEP = Callable[..., Union[Tensor, Tuple[Tensor], Dict[str, Tensor]]] + +datasets.MNIST(root='data/mnist', train=True, download=True) +datasets.MNIST(root='data/mnist', train=False, download=True) +transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) +mnist_train = datasets.MNIST(root='data/mnist', train=True, transform=transform) +train_dataloader = DataLoader(mnist_train, batch_size=64) +mnist_test = datasets.MNIST(root='data/mnist', train=False, transform=transform) +test_dataloader = DataLoader(mnist_test, batch_size=1000) + + +class Mnist(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 20, 5, 1) + self.conv2 = torch.nn.Conv2d(20, 50, 5, 1) + self.fc1 = torch.nn.Linear(4 * 4 * 50, 500) + self.fc2 = torch.nn.Linear(500, 10) + self.relu1 = torch.nn.ReLU6() + self.relu2 = torch.nn.ReLU6() + self.relu3 = torch.nn.ReLU6() + self.max_pool1 = torch.nn.MaxPool2d(2, 2) + self.max_pool2 = torch.nn.MaxPool2d(2, 2) + self.batchnorm1 = torch.nn.BatchNorm2d(20) + + def forward(self, x): + x = self.relu1(self.batchnorm1(self.conv1(x))) + x = self.max_pool1(x) + x = self.relu2(self.conv2(x)) + x = self.max_pool2(x) + x = x.view(-1, 4 * 4 * 50) + x = self.relu3(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + +def training_step(batch, model) -> Tensor: + x, y = batch[0].to(device), batch[1].to(device) + logits = model(x) + loss: torch.Tensor = F.nll_loss(logits, y) + return loss + + +def training_model(model: DeepSpeedEngine, optimizer: Union[Optimizer, List[Optimizer]], \ + training_step: _TRAINING_STEP, scheduler: Union[None, SCHEDULER, List[SCHEDULER]] = None, + max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None): + assert isinstance(model, DeepSpeedEngine) + model.train() + max_epochs = max_epochs or (40 if max_steps is None else 100) + current_steps = 0 + best_acc = 0.0 + + # training + for epoch in range(max_epochs): + print(f'Epoch {epoch} start!') + model.train() + for batch in train_dataloader: + if isinstance(optimizer, Optimizer): + optimizer.zero_grad() + elif isinstance(optimizer, List) and all(isinstance(_, Optimizer) for _ in optimizer): + for opt in optimizer: + opt.zero_grad() + loss = training_step(batch, model) + assert isinstance(loss, torch.Tensor) + model.backward(loss) + model.step() + current_steps += 1 + if max_steps and current_steps == max_steps: + return + + acc = evaluating_model(model) + 
best_acc = max(acc, best_acc) + print(f"epoch={epoch}\tacc={acc}\tbest_acc={best_acc}") + + +def evaluating_model(model: torch.nn.Module): + model.eval() + # testing + correct = 0 + with torch.no_grad(): + for x, y in test_dataloader: + x, y = x.to(device), y.to(device) + logits = model(x) + preds = torch.argmax(logits, dim=1) + correct += preds.eq(y.view_as(preds)).sum().item() + print(f'Accuracy: {100 * correct / len(mnist_test)}%)\n') + + return correct / len(mnist_test) + + +def main(): + model = Mnist().to(device) + configure_list = [{ + 'target_names':['_input_', 'weight'], + 'op_names': ['conv2'], + 'quant_dtype': 'int8', + 'quant_scheme': 'affine', + 'granularity': 'default', + },{ + 'op_names': ['relu1', 'relu2'], + 'target_names': ['_output_'], + 'quant_dtype': 'int8', + 'quant_scheme': 'affine', + 'granularity': 'default', + },{ + 'op_names': ['max_pool2'], + 'target_names': ['_output_'], + 'quant_dtype': 'int8', + 'quant_scheme': 'affine', + 'granularity': 'default', + }, + { + 'target_names':['_input_', 'weight', '_output_'], + 'op_names': ['conv1'], + 'quant_dtype': 'int8', + 'quant_scheme': 'affine', + 'granularity': 'default', + 'fuse_names': [("conv1", "batchnorm1")] + }] + + evaluator = DeepspeedTorchEvaluator(training_model, training_step, "./ds_config.json") #, lrs) + quantizer = LsqQuantizer(model, configure_list, evaluator) + model, calibration_config = quantizer.compress(None, 4) + acc = evaluating_model(model) + + + +if __name__ == '__main__': + main() diff --git a/nni/compression/pytorch/speedup/v2/dependency.py b/nni/compression/pytorch/speedup/v2/dependency.py index 3249e6858..68d2fcd05 100644 --- a/nni/compression/pytorch/speedup/v2/dependency.py +++ b/nni/compression/pytorch/speedup/v2/dependency.py @@ -334,7 +334,7 @@ def build_channel_dependency(graph_module: torch.fx.GraphModule, # To determine if this cat operation will introduce channel # dependency, we need the specific input parameters of the cat # operation. - cat_dim = node.kwargs.get('dim', None) or node.args[1] + cat_dim = node.kwargs.get('dim', None) if node.kwargs.get('dim', None) is not None else node.args[1] if cat_dim != 1: d_set = set(find_adjacent_layers(node, graph_module, target_types, 'parent')) diff --git a/nni/contrib/compression/__init__.py b/nni/contrib/compression/__init__.py index f198ad9e0..d68564679 100644 --- a/nni/contrib/compression/__init__.py +++ b/nni/contrib/compression/__init__.py @@ -1,4 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from .utils.evaluator import LightningEvaluator, TorchEvaluator, TransformersEvaluator +from .utils.evaluator import (LightningEvaluator, + TorchEvaluator, + TransformersEvaluator, + DeepspeedTorchEvaluator) diff --git a/nni/contrib/compression/quantization/lsq_quantizer.py b/nni/contrib/compression/quantization/lsq_quantizer.py index 7f11407be..e6b50dffa 100644 --- a/nni/contrib/compression/quantization/lsq_quantizer.py +++ b/nni/contrib/compression/quantization/lsq_quantizer.py @@ -98,7 +98,7 @@ class LsqQuantizer(Quantizer): # init_target = target.data.detach().abs().mean() * 2 / (target_space.qmax ** 0.5) init_target = torch.tensor([0.01]).to(target.device) if not target_space._scaler: - target_space.scale.data = init_target # type: ignore + target_space.scale.data = init_target.view(1) # type: ignore target_space.zero_point = torch.tensor(0.0).to(target.device) else: new_target = init_target.expand(target.shape).to(target.device) @@ -117,7 +117,16 @@ class LsqQuantizer(Quantizer): for target_name, _ in ts.items(): if hasattr(wrapper, f"{target_name}_scale"): delattr(wrapper, f"{target_name}_scale") - param = torch.nn.Parameter() + # for deepspeed + try: + device = next(wrapper.parameters()).device + except StopIteration: + try: + device = next(wrapper.buffers()).device + except StopIteration: + # NOTE: this will have risk in model parallel + device = next(self.bound_model.parameters()).device + param = torch.nn.Parameter(torch.Tensor([0.01]).to(device)) wrapper.register_parameter(f"{target_name}_scale", param) def patch_optimizer_param_group(self): diff --git a/nni/contrib/compression/utils/__init__.py b/nni/contrib/compression/utils/__init__.py index 33e9869e9..1908c6543 100644 --- a/nni/contrib/compression/utils/__init__.py +++ b/nni/contrib/compression/utils/__init__.py @@ -3,7 +3,11 @@ from .check_ddp import check_ddp_model, reset_ddp_model from .dependency import auto_set_denpendency_group_ids -from .evaluator import Evaluator, LightningEvaluator, TorchEvaluator, TransformersEvaluator +from .evaluator import (Evaluator, + LightningEvaluator, + TorchEvaluator, + TransformersEvaluator, + DeepspeedTorchEvaluator) from .scaling import Scaling from .attr import ( get_nested_attr, diff --git a/nni/contrib/compression/utils/evaluator.py b/nni/contrib/compression/utils/evaluator.py index 0ba7bd645..0e8822d59 100644 --- a/nni/contrib/compression/utils/evaluator.py +++ b/nni/contrib/compression/utils/evaluator.py @@ -22,13 +22,46 @@ except ImportError: else: LIGHTNING_INSTALLED = True +try: + import deepspeed +except ImportError: + DEEPSPEED_INSTALLED = False +else: + DEEPSPEED_INSTALLED = True + try: from transformers.trainer import Trainer as HFTrainer + from transformers import TrainerCallback, TrainerControl, TrainerState + from transformers import TrainingArguments except ImportError: TRANSFORMERS_INSTALLED = False + + class PatchCallback: + def on_train_begin(self, *args, **kwargs): + raise RuntimeError("Don't use the fake PatchCallback, please install transformers") else: TRANSFORMERS_INSTALLED = True + class PatchCallback(TrainerCallback): # type: ignore + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + pass + +try: + from accelerate.utils.deepspeed import HfDeepSpeedConfig as DeepSpeedConfig # type: ignore +except ImportError: + ACCELERATE_INSTALLED = False + class DeepSpeedConfig: + def __init__(self, *args, **kwargs): + raise RuntimeError("Don't use the fake DeepSpeedConfig, please install accelerate") + def 
get_value(self, key: str): + raise RuntimeError("Don't use the fake DeepSpeedConfig, please install accelerate") + def del_config_sub_tree(self, key: str): + raise RuntimeError("Don't use the fake DeepSpeedConfig, please install accelerate") + def is_zero3(self): + raise RuntimeError("Don't use the fake DeepSpeedConfig, please install accelerate") +else: + ACCELERATE_INSTALLED = True + import nni from nni.common import is_traceable from nni.common.types import SCHEDULER @@ -1019,6 +1052,8 @@ class TransformersEvaluator(Evaluator): self._ori_trainer_attr['optimizer.step'] = self.trainer.optimizer.step def patch_optim_param_group(self, module_name_param_dict: Dict[str, List[Tensor]]): + if self.trainer.args.deepspeed: + return assert isinstance(self.model, Module) assert module_name_param_dict is not None self._optimizer_add_param_group(self.model, module_name_param_dict, self.trainer.optimizer) @@ -1032,6 +1067,7 @@ class TransformersEvaluator(Evaluator): self.trainer.optimizer = None self._param_names_map = None self._ori_trainer_attr.pop('compute_loss', None) + self.trainer.remove_callback(PatchCallback) self.trainer = None # type: ignore self.model = None else: @@ -1053,19 +1089,27 @@ class TransformersEvaluator(Evaluator): self.trainer.compute_loss = self._ori_trainer_attr['compute_loss'] def patch_optimizer_step(self, before_step_tasks: List[Callable], after_step_tasks: List[Callable]): - assert self.trainer.optimizer is not None - old_step = self.trainer.optimizer.step + def custom_on_train_begin(_, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + optimizer = self.trainer.deepspeed if hasattr(self.trainer, "deepspeed") else self.trainer.callback_handler.optimizer - def patched_step(_, *args, **kwargs): - for task in before_step_tasks: - task() - # call origin optimizer step method - output = old_step(*args, **kwargs) - for task in after_step_tasks: - task() - return output + assert optimizer is not None + old_step = optimizer.step - self.trainer.optimizer.step = types.MethodType(patched_step, self.trainer.optimizer) + def patched_step(_, *args, **kwargs): + for task in before_step_tasks: + task() + # call origin optimizer step method + output = old_step(*args, **kwargs) + for task in after_step_tasks: + task() + return output + + optimizer.step = types.MethodType(patched_step, optimizer) + + PatchCallback.on_train_begin = types.MethodType(custom_on_train_begin, PatchCallback) + + # Add Callback into the callback_handler + self.trainer.add_callback(PatchCallback) def revert_optimizer_step(self): assert self.trainer.optimizer is not None @@ -1100,3 +1144,502 @@ class TransformersEvaluator(Evaluator): def get_dummy_input(self) -> Any: return self.dummy_input + + +class DeepspeedTorchEvaluator(Evaluator): + """ + The DeepseedTorchEvaluator is an evaluator designed specifically for native PyTorch users who are utilizing DeepSpeed. + + Parameters + ---------- + training_func + The training function is used to train the model, note that this a entire optimization training loop. + Training function has three required parameters, ``model``, ``optimizers`` and ``training_step``, + and three optional parameters, ``lr_schedulers``, ``max_steps``, ``max_epochs``. + + Let's explain these six parameters NNI passed in, but in most cases, users don't need to care about these. + Users only need to treat these six parameters as the original parameters during the training process. 
+ + * The ``model`` is a wrapped model from the original model, it has a similar structure to the model to be pruned, + so it can share training function with the original model. + * ``optimizers`` are re-initialized from the ``optimizers`` passed to the evaluator and the wrapped model's parameters. + * ``training_step`` also based on the ``training_step`` passed to the evaluator, + it might be modified by the compressor during model compression. + * If users use ``lr_schedulers`` in the ``training_func``, NNI will re-initialize the ``lr_schedulers`` with the re-initialized + optimizers. + * ``max_steps`` is the NNI training duration limitation. It is for pruner (or quantizer) to control the number of training steps. + The user implemented ``training_func`` should respect ``max_steps`` by stopping the training loop after ``max_steps`` is reached. + Pruner may pass ``None`` to ``max_steps`` when it only controls ``max_epochs``. + * ``max_epochs`` is similar to the ``max_steps``, the only different is that it controls the number of training epochs. + The user implemented ``training_func`` should respect ``max_epochs`` by stopping the training loop + after ``max_epochs`` is reached. Pruner may pass ``None`` to ``max_epochs`` when it only controls ``max_steps``. + + Note that when the pruner passes ``None`` to both ``max_steps`` and ``max_epochs``, + it treats ``training_func`` as a function of model fine-tuning. + Users should assign proper values to ``max_steps`` and ``max_epochs``. + + .. code-block:: python + + def training_func(model: DeepSpeedEngine, optimizers: torch.optim.Optimizer, + training_step: Callable[[Any, Any], torch.Tensor], + lr_schedulers: _LRScheduler | None = None, max_steps: int | None = None, + max_epochs: int | None = None, *args, **kwargs): + ... + total_epochs = max_epochs if max_epochs else 20 + total_steps = max_steps if max_steps else 1000000 + current_steps = 0 + ... + for epoch in range(total_epochs): + ... + model.backward(loss) + model.step() + if current_steps >= total_steps: + return + + Note that ``optimizers`` and ``lr_schedulers`` passed to the ``training_func`` have the same type as the ``optimizers`` + and ``lr_schedulers`` passed to evaluator, a single ``torch.optim.Optimzier``/ ``torch.optim._LRScheduler`` instance or + a list of them. + + training_step + A callable function, the first argument of inputs should be ``batch``, and the outputs should contain loss. + Three kinds of outputs are supported: single loss, tuple with the first element is loss, a dict contains a key ``loss``. + + .. code-block:: python + + def training_step(batch, model, ...): + inputs, labels = batch + output = model(inputs) + ... + loss = loss_func(output, labels) + return loss + deepspeed + Str | dict. The deepspeed configuration which Contains the parameters needed in DeepSpeed, such as train_batch_size, among others. + optimzier + Optional. A single traced optimizer instance or a function that takes the model parameters + as input and returns an optimizer instance. NNI may modify the ``torch.optim.Optimizer`` member function ``step`` + and/or optimize compressed models, so NNI needs to have the ability to re-initialize the optimizer. ``nni.trace`` can + record the initialization parameters of a function/class, which can then be used by NNI to re-initialize the + optimizer for a new but structurally similar model. E.g. ``traced_optimizer = nni.trace(torch.nn.Adam)(model.parameters())``. + lr_schedulers + Optional. 
A single traced lr_scheduler instance or a function that takes the model parameters and the optimizer as input + and returns an lr_scheduler instance. For the same reason with ``optimizers``, NNI needs the traced lr_scheduler + to re-initialize it. + E.g. ``traced_lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)``. + resume_from_checkpoint_args + Dict | None. Used in the deepspeed_init process to load models saved during training with DeepSpeed. + Let's explain these seven elements in the resume_from_checkpoint_args. + + * ``load_dir``: The directory to load the checkpoint from. + * ``tag`` : Checkpoint tag used as a unique identifier for checkpoint, if not provided will attempt to load tag in 'latest' file + * ``load_module_strict``: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match. + * ``load_optimizer_states``: Optional. Boolean to load the training optimizer states from Checkpoint. + * ``load_lr_scheduler_states``: Optional. Boolean to add the learning rate scheduler states from Checkpoint. + * ``load_module_only``: Optional. Boolean to load only the model weights from the checkpoint. + * ``custom_load_fn``: Optional. Custom model load function. + + dummy_input + Optional. The dummy_input is used to trace the graph, it's same with ``example_inputs`` in + `torch.jit.trace `_. + evaluating_func + Optional. A function that input is model and return the evaluation metric. + This is the function used to evaluate the compressed model performance. + The input is a model and the output is a ``float`` metric or a ``dict`` + (``dict`` should contains key ``default`` with a ``float`` value). + NNI will take the float number as the model score, and assume the higher score means the better performance. + If you want to provide additional information, please put it into a dict + and NNI will take the value of key ``default`` as evaluation metric. + + Notes + ----- + It is also worth to note that not all the arguments of ``DeepspeedTorchEvaluator`` must be provided. + Some pruners (or quantizers) only require ``evaluating_func`` as they do not train the model, + some pruners (or quantizers) only require ``training_func``. + Please refer to each pruner's (or quantizer's) doc to check the required arguments. + But, it is fine to provide more arguments than the pruner's (or quantizer's) need. 
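+
+    Examples
+    --------
+    A minimal construction sketch. The DeepSpeed config path and the checkpoint directory below are
+    hypothetical, and the config file is assumed to define its own ``optimizer`` section, so no optimizer
+    instance is passed here.
+
+    .. code-block:: python
+
+        evaluator = DeepspeedTorchEvaluator(training_func, training_step, deepspeed='./ds_config.json',
+                                            resume_from_checkpoint_args={'load_dir': './checkpoints',
+                                                                         'tag': 'global_step1000'})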
+ """ + + def __init__(self, training_func: _TRAINING_FUNC, training_step: _TRAINING_STEP, deepspeed: str | Dict, + optimizer: Optimizer | Callable[[List[Tensor]], Optimizer] | None = None, + lr_scheduler: SCHEDULER | Callable[[Optimizer], SCHEDULER] | None = None, + resume_from_checkpoint_args: Dict | None = None, dummy_input: Any | None = None, + evaluating_func: _EVALUATING_FUNC | None = None): + assert ACCELERATE_INSTALLED, "accelerate is not installed" + assert DEEPSPEED_INSTALLED, "deepspeed is not installed" + self.training_func = training_func + self._ori_training_step = training_step + self._training_step = self._ori_training_step + self.dummy_input = dummy_input + self.evaluating_func = evaluating_func + self.resume_from_checkpoint_args = resume_from_checkpoint_args + self._ori_optimizer_step = None + + self.model: Module | None = None + self.optimizer: Optimizer | Callable[[List[Tensor]], Optimizer] | None = None + self.lr_scheduler: SCHEDULER | Callable[[Optimizer], SCHEDULER] | None = None + self._ori_optimizer_step: Callable | None = None + self._param_names_map: Dict[str, str] | None = None + self.deepspeed_engine = None + + # will del self._tmp_optimizer and self._tmp_lr_scheduler in `_init_optimizer_helpers` + self._tmp_optimizer: Optimizer | Callable[[List[Tensor]], Optimizer] | None = optimizer + self._tmp_lr_scheduler: SCHEDULER | Callable[[Optimizer], SCHEDULER] | None = lr_scheduler + self._initialization_complete = False + + self.deepspeed_config: DeepSpeedConfig | None = self.process_deepspeed(deepspeed) + + def process_deepspeed(self, config_file_or_dict: str | Dict) -> DeepSpeedConfig: + if config_file_or_dict is None: + raise ValueError('deepspeed_config should not be None') + assert isinstance(config_file_or_dict, (Dict, str)), \ + f"Only two types: Dict and str are supported for config_file_or_dict, but got {type(config_file_or_dict)}" + return DeepSpeedConfig(config_file_or_dict) + + def check_optim_sched(self) -> None: + assert self._tmp_optimizer is None or isinstance(self._tmp_optimizer, Optimizer) or callable(self._tmp_optimizer) + assert self._tmp_lr_scheduler is None or isinstance(self._tmp_lr_scheduler, SCHEDULER) or callable(self._tmp_lr_scheduler) + # check the validation of optimizer + if isinstance(self._tmp_optimizer, Optimizer): + assert is_traceable(self._tmp_optimizer) + # check the validation of scheduler + if isinstance(self._tmp_lr_scheduler, SCHEDULER): + assert is_traceable(self._tmp_lr_scheduler) + + # there are 9 cases: + # case 1: opt = None, sche = None, depends on the optimizer configuration in deepspeed_config + # case 2: opt = Callback, sche = None, ok + # case 3: opt = Optim, sche = None, ok + # case 4: opt = None, sche = Callback, depends on the optimizer configuration in deepspeed_config + # case 5: opt = Callback, sche = Callback, ok + # case 6: opt = Optim, sche = Callback, ok + # case 7: opt = None, sche = Scheduler, X + # case 8: opt = Callback, sche = Scheduler, X + # case 9: opt = Optim, sche = Scheduler, ok + assert hasattr(self, "deepspeed_config") and self.deepspeed_config is not None + if self._tmp_optimizer is not None and self.deepspeed_config.get_value('optimizer') is not None: + raise ValueError("Please provide the optimizer during the evaluator's initialization or in the" + + "deepspeed_config, but don\'t provide both at the same time.") + + # case 1: optimizer is None and config is None + if self._tmp_optimizer is None and self.deepspeed_config.get_value('optimizer') is None: + raise ValueError("Optimizer and 
optimizer configuration in deepspeed config" + + "can\'t be None at the same time, please provide one") + # case 2: optimizer is Callable or None, but scheduler is _SCHEUDLER + if not isinstance(self._tmp_optimizer, Optimizer) and isinstance(self._tmp_lr_scheduler, SCHEDULER): + raise ValueError("Don't support for non-instance optimizer and instance scheduler pair") + + def _init_optimizer_helpers(self, pure_model: Module): + assert self._initialization_complete is False, 'Evaluator initialization is already complete.' + # check the validation of optimizer and scheduler + self.check_optim_sched() + if isinstance(self._tmp_optimizer, Optimizer): + self._optimizer_helper = OptimizerConstructHelper.from_trace(pure_model, self._tmp_optimizer) + else: + self.optimizer = self._tmp_optimizer + if isinstance(self._tmp_lr_scheduler, SCHEDULER): + self._lr_scheduler_helper = LRSchedulerConstructHelper.from_trace(self._tmp_lr_scheduler) + else: + self.lr_scheduler = self._tmp_lr_scheduler + + delattr(self, '_tmp_optimizer') + delattr(self, '_tmp_lr_scheduler') + self._initialization_complete = True + + def _rewrap_if_ddp_model(self, model): + errmsg = "model is None, no need to rewrap model to DistributedDatapallel model" + assert model is not None, errmsg + is_ddp_model, _ = check_ddp_model(model) + + if is_ddp_model: + raise RuntimeError("DeepSpeed will provide DDP logic so that your model should not be wrapped with DistributedParallel") + + return model + + def bind_model(self, model: Module, param_names_map: Dict[str, str] | None = None): + err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.' + assert self._initialization_complete is True, err_msg + assert isinstance(model, Module) + if self.model is not None: + _logger.warning('Already bound a model, will unbind it before bind a new model.') + self.unbind_model() + + is_ddp_model, _ = check_ddp_model(model) + assert not is_ddp_model, \ + "DeepSpeed will automatically initialize the distributed environment during its initialize" + self.model = model + self._param_names_map = param_names_map + # initialize optimizers & lr_schedulers for the bound model here + if hasattr(self, '_optimizer_helper'): + self.optimizer = self._optimizer_helper.call(model, param_names_map) + if hasattr(self, '_lr_scheduler_helper'): + self.lr_scheduler = self._lr_scheduler_helper.call(self.optimizer) # type: ignore + + def deepspeed_init(self, inference=False): + assert self.model is not None + assert self.deepspeed_config is not None + # whether to check the validation of params + deepspeed_config: DeepSpeedConfig = deepcopy(self.deepspeed_config) + config: Dict = deepspeed_config.config # type: ignore + if inference: + # only Z3 makes sense for the inference + if not deepspeed_config.is_zero3(): + raise ValueError("ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config") + + # in case the training config is re-used for inference + deepspeed_config.del_config_sub_tree("optimizer") + deepspeed_config.del_config_sub_tree("lr_scheduler") + optimizer, lr_scheduler = None, None + model_parameters = None + else: + model_parameters = list(filter(lambda p: p.requires_grad, self.model.parameters())) + optimizer, lr_scheduler = self.optimizer, self.lr_scheduler + + # deepspeed init + kwargs = { + "model": self.model, + "model_parameters": model_parameters, + "config_params": config, + "optimizer": optimizer, + "lr_scheduler": lr_scheduler, + } + deepspeed_engine, optimizer, _, lr_scheduler = 
deepspeed.initialize(**kwargs) + # load deepspeed checkpoint + if self.resume_from_checkpoint_args is not None: + # it's possible that the user is trying to resume from model_path, which doesn't necessarily + # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's + # a resume from a checkpoint and not just a local pretrained weight. So we check here if the + # path contains what looks like a deepspeed checkpoint + resume_from_checkpoint = self.resume_from_checkpoint_args.get("load_dir", None) + assert resume_from_checkpoint is not None + tag = self.resume_from_checkpoint_args.get('tag', "global_step") + load_module_strict = self.resume_from_checkpoint_args.get('load_module_strict', True) + load_optimizer_states = self.resume_from_checkpoint_args.get('load_optimizer_states', True) + load_lr_scheduler_states = self.resume_from_checkpoint_args.get('load_lr_scheduler_states', True) + load_module_only = self.resume_from_checkpoint_args.get('load_module_only', False) + custom_load_fn = self.resume_from_checkpoint_args.get("custom_load_fn", None) + # copyed from transformers + import glob + # TODO to add load model from tag + deepspeed_checkpoint_dirs = sorted(glob.glob(f"{resume_from_checkpoint}/{tag}*")) + + if len(deepspeed_checkpoint_dirs) > 0: + # logger.info(f"Attempting to resume from {self.resume_from_checkpoint}") + # this magically updates self.optimizer and self.lr_scheduler + # load_path, _ = deepspeed_engine.load_checkpoint( + # self.resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True + # ) + load_path, _ = self.load_checkpoint(resume_from_checkpoint, + tag, + load_module_strict=load_module_strict, + load_optimizer_states=load_optimizer_states, + load_lr_scheduler_states=load_lr_scheduler_states, + load_module_only=load_module_only, + custom_load_fn=custom_load_fn) + if load_path is None: + raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}") + else: + raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") + # record the original deepspeed step function + self._ori_optimizer_step = deepspeed_engine.step + + return deepspeed_engine, optimizer, lr_scheduler + + def patch_optim_param_group(self, module_name_param_dict: Dict[str, List[Tensor]]): + if not isinstance(self.optimizer, Optimizer): + return + # used for adding param_group without deepspeed config + assert isinstance(self.model, Module) + assert module_name_param_dict is not None + self._optimizer_add_param_group(self.model, module_name_param_dict, [self.optimizer]) # type: ignore + + def unbind_model(self): + if self.model: + self.revert_loss() + self.revert_optimizer_step() + self.remove_all_hooks() + self.lr_scheduler = None + self.optimizer = None + self._param_names_map = None + self.model = None + # TODO to check if unibind deepspeed params is needed + self.deepspeed_engine = None + self.deepspeed_config = None + else: + _logger.warning('Did not bind any model, no need to unbind model.') + + def patch_loss(self, patch: Callable[[Tensor, Any], Tensor]): + old_training_step = self._training_step + + def patched_training_step(*args, **kwargs): + out = old_training_step(*args, **kwargs) + # we assume in training_step, ``batch`` is the first argument + batch = args[0] if len(args) > 0 else kwargs['batch'] + if isinstance(out, Tensor): + out = patch(out, batch) + elif isinstance(out, Sequence) and not isinstance(out, str): + assert isinstance(out[0], Tensor) + new_loss = patch(out[0], batch) + out = 
(new_loss,) + tuple(out[1:]) + elif isinstance(out, MutableMapping): + assert 'loss' in out and isinstance(out['loss'], Tensor) + out['loss'] = patch(out['loss'], batch) + return out + + self._training_step: _TRAINING_STEP = patched_training_step + + def revert_loss(self): + self._training_step = self._ori_training_step + + def patch_optimizer_step(self, before_step_tasks: List[Callable], after_step_tasks: List[Callable]): + self.is_patch_optim_for_ds = True + self.before_step_tasks = before_step_tasks + self.after_step_tasks = after_step_tasks + + def revert_optimizer_step(self): + assert self.deepspeed_engine is not None + if self._ori_optimizer_step is not None: + self.deepspeed_engine.step = self._ori_optimizer_step + + def patch_engine_step(self, before_step_tasks: List[Callable], after_step_tasks: List[Callable]): + assert self.deepspeed_engine is not None + old_step = self.deepspeed_engine.step + + def patched_step(_, *args, **kwargs): + for task in before_step_tasks: + task() + # call origin optimizer step method + output = old_step(*args, **kwargs) + for task in after_step_tasks: + task() + return output + + self.deepspeed_engine.step = types.MethodType(patched_step, self.deepspeed_engine) + + def train(self, max_steps: int | None = None, max_epochs: int | None = None): + # deepspeed init + deepspeed_engine, optimizer, lr_scheduler = self.deepspeed_init(inference=False) + self.deepspeed_engine = deepspeed_engine + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.model = deepspeed_engine.module + + assert self.deepspeed_engine is not None + assert self.optimizer is not None + assert self._training_step is not None + + if hasattr(self, 'is_patch_optim_for_ds') and self.is_patch_optim_for_ds: + self.patch_engine_step(self.before_step_tasks, self.after_step_tasks) + + self.training_func(self.deepspeed_engine, self.optimizer, self._training_step, self.lr_scheduler, max_steps, max_epochs) + + def finetune(self): + self.train() + + def evaluate(self) -> float | None | Tuple[float, Dict[str, Any]] | Tuple[None, Dict[str, Any]]: + # assert self.model is not None + if self.evaluating_func is None: + warn_msg = f'Did not pass evaluation_func to {self.__class__.__name__}, will return None for calling evaluate()' + _logger.warning(warn_msg) + return None + + if self.deepspeed_engine is None: + deepspeed_engine, _, _ = self.deepspeed_init(inference=True) + self.deepspeed_engine = deepspeed_engine + self.model = self.deepspeed_engine.module + + assert self.deepspeed_engine is not None + + metric = self.evaluating_func(self.deepspeed_engine) + if isinstance(metric, dict): + nni_used_metric = metric.get('default', None) + if nni_used_metric is None: + warn_msg = f'Evaluation function returns a dict metric without key `default`,' + \ + 'will return None as the model evaluation metric value.' + _logger.warning(warn_msg) + return nni_used_metric, metric + else: + return metric + + def get_dummy_input(self) -> Any: + return self.dummy_input + + def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True): + """ + Save training checkpoint + + Parameters + ---------- + save_dir + Required. Directory for saving the checkpoint + tag + Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is + used if not provided. Tag name must be the same across all ranks. + client_state + Optional. State dictionary used for saving required training states in the client code. + save_latest + Optional. 
Save a file 'latest' pointing to the latest saved checkpoint. + + Notes + ----- + Important: all processes must call this method and not just the process with rank 0. It is + because each process needs to save its master weights and scheduler+optimizer states. This + method will hang waiting to synchronize with other processes if it's called just for the + process with rank 0. + """ + + # copyed from deepspeed + assert self.deepspeed_engine is not None + return self.deepspeed_engine.save_checkpoint(save_dir, tag, client_state=client_state, + save_latest=save_latest) + + def load_checkpoint(self, + load_dir, + tag=None, + load_module_strict=True, + load_optimizer_states=True, + load_lr_scheduler_states=True, + load_module_only=False, + custom_load_fn=None): + """ + Load training checkpoint + + Parameters + ---------- + load_dir + Required. Directory to load the checkpoint from + tag + Checkpoint tag used as a unique identifier for checkpoint, if not provided will attempt to load tag in 'latest' file + load_module_strict + Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match. + load_optimizer_states + Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance + load_lr_scheduler_states + Optional. Boolean to add the learning rate scheduler states from Checkpoint. + load_module_only + Optional. Boolean to load only the model weights from the checkpoint. Ex. warmstarting. + custom_load_fn + Optional. Custom model load function. + + Returns + ------- + load_path + Path of the loaded checkpoint. None if loading the checkpoint failed. + client_state + State dictionary used for loading required training states in the client code. + + Notes + ----- + Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right + after ``engine.save_checkpoint()``. It is because ``engine.module`` is partitioned, and + ``load_checkpoint()`` wants a pristine model. If insisting to do so, please reinitialize engine + before ``load_checkpoint()``. 
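+
+        A minimal, illustrative call (the checkpoint directory and tag below are hypothetical):
+
+        .. code-block:: python
+
+            load_path, client_state = evaluator.load_checkpoint('./checkpoints', tag='global_step1000')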
+ """ + + # copyed from deepspeed + assert self.deepspeed_engine is not None + return self.deepspeed_engine.load_checkpoint(load_dir, + tag=tag, + load_module_strict=load_module_strict, + load_optimizer_states=load_optimizer_states, + load_lr_scheduler_states=load_lr_scheduler_states, + load_module_only=load_module_only, + custom_load_fn=custom_load_fn) \ No newline at end of file diff --git a/pipelines/templates/install-dependencies.yml b/pipelines/templates/install-dependencies.yml index 28a5f8986..44811b080 100644 --- a/pipelines/templates/install-dependencies.yml +++ b/pipelines/templates/install-dependencies.yml @@ -33,14 +33,6 @@ steps: displayName: (Ubuntu) Downgrade swig condition: and(succeeded(), contains('${{ parameters.platform }}', 'ubuntu')) -- script: | - set -e - brew install 'swig@3' - rm /usr/local/bin/swig - ln -s '/usr/local/opt/swig@3/bin/swig' /usr/local/bin/swig - displayName: (macOS) Downgrade swig - condition: and(succeeded(), contains('${{ parameters.platform }}', 'macos')) - - script: | set -e azcopy copy 'https://nni.blob.core.windows.net/cache/dependencies-${{ parameters.platform }}.zip?$(sas_cache)' dependencies.zip diff --git a/test/vso_tools/build_vm/setup_linux.sh b/test/vso_tools/build_vm/setup_linux.sh index f64a9f887..e2906d3c0 100755 --- a/test/vso_tools/build_vm/setup_linux.sh +++ b/test/vso_tools/build_vm/setup_linux.sh @@ -74,4 +74,4 @@ sudo systemctl disable apt-daily-upgrade.service # Deprovision sudo /usr/sbin/waagent -force -deprovision -sudo HISTSIZE=0 sync +sudo HISTSIZE=0 sync \ No newline at end of file