[DeepSpeed in notebooks] Jupyter + Colab (#10130)
* init devices/setup explicitly * docs + test * simplify * cleanup * cleanup * cleanup * correct the required dist setup * derive local_rank from env LOCAL_RANK
This commit is contained in:
Родитель
6710d1d5ef
Коммит
b54cb0bd82
|
@ -429,6 +429,88 @@ Notes:
|
|||
In this example, we tell DeepSpeed to use GPU 1.
|
||||
|
||||
|
||||
|
||||
Deployment in Notebooks
|
||||
=======================================================================================================================
|
||||
|
||||
The problem with notebooks is that there is no normal ``deepspeed`` launcher to rely on, so under certain setups we
|
||||
have to emulate it.
|
||||
|
||||
Here is how you'd have to adjust your training code in the notebook to use DeepSpeed.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# DeepSpeed requires a distributed environment even when only one process is used.
|
||||
# This emulates a launcher in the notebook
|
||||
import os
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
os.environ['MASTER_PORT'] = '9994' # modify if RuntimeError: Address already in use
|
||||
os.environ['RANK'] = "0"
|
||||
os.environ['LOCAL_RANK'] = "0"
|
||||
os.environ['WORLD_SIZE'] = "1"
|
||||
|
||||
# Now proceed as normal, plus pass the deepspeed config file
|
||||
training_args = TrainingArguments(..., deepspeed="ds_config.json")
|
||||
trainer = Trainer(...)
|
||||
trainer.train()
|
||||
|
||||
Note: `...` stands for the normal arguments that you'd pass to the functions.
|
||||
|
||||
If you want to create the config file on the fly in the notebook in the current directory, you could have a dedicated
|
||||
cell with:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
%%bash
|
||||
cat <<'EOT' > ds_config.json
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 2e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 2e8,
|
||||
"contiguous_gradients": true,
|
||||
"cpu_offload": true
|
||||
},
|
||||
|
||||
"zero_allow_untested_optimizer": true,
|
||||
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 3e-5,
|
||||
"betas": [0.8, 0.999],
|
||||
"eps": 1e-8,
|
||||
"weight_decay": 3e-7
|
||||
}
|
||||
},
|
||||
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 0,
|
||||
"warmup_max_lr": 3e-5,
|
||||
"warmup_num_steps": 500
|
||||
}
|
||||
},
|
||||
|
||||
"steps_per_print": 2000,
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
EOT
|
||||
|
||||
|
||||
|
||||
Configuration
|
||||
=======================================================================================================================
|
||||
|
||||
|
|
|
@ -14,13 +14,16 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from transformers.integrations import is_deepspeed_available
|
||||
from transformers.testing_utils import (
|
||||
CaptureStd,
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
mockenv,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
|
@ -52,6 +55,20 @@ def require_deepspeed(test_case):
|
|||
@require_deepspeed
|
||||
@require_torch_gpu
|
||||
class TestDeepSpeed(TestCasePlus):
|
||||
|
||||
# this setup emulates a notebook where a launcher needs to be emulated by hand
|
||||
@mockenv(MASTER_ADDR="localhost", MASTER_PORT="109999", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1")
|
||||
def test_fake_notebook_no_launcher(self):
|
||||
sys.path.append(self.tests_dir_str)
|
||||
from test_trainer import get_regression_trainer
|
||||
|
||||
del sys.path[-1] # restore
|
||||
ds_config_file = f"{self.test_file_dir_str}/ds_config.json"
|
||||
with CaptureStd() as cs:
|
||||
trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_file)
|
||||
trainer.train()
|
||||
assert "DeepSpeed info" in cs.out, "expected DeepSpeed logger output but got none"
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_basic_distributed(self):
|
||||
self.run_quick(distributed=True)
|
||||
|
|
|
@ -239,6 +239,9 @@ class Trainer:
|
|||
self.hp_name = None
|
||||
self.deepspeed = None
|
||||
|
||||
# force device and distributed setup init explicitly
|
||||
args._setup_devices
|
||||
|
||||
if model is None:
|
||||
if model_init is not None:
|
||||
self.model_init = model_init
|
||||
|
|
|
@ -561,6 +561,12 @@ class TrainingArguments:
|
|||
import deepspeed
|
||||
|
||||
deepspeed.init_distributed()
|
||||
|
||||
# workaround for setups like notebooks where the launcher can't be used,
|
||||
# but deepspeed requires a dist env.
|
||||
# env LOCAL_RANK could be set manually by the user, or via init_distributed if mpi4py is installed
|
||||
self.local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
|
||||
|
||||
device = torch.device("cuda", self.local_rank)
|
||||
self._n_gpu = 1
|
||||
elif self.local_rank == -1:
|
||||
|
|
Загрузка…
Ссылка в новой задаче