Mirror of https://github.com/microsoft/DeepSpeed.git
[testing] 3x faster unit tests (#1636)
Parent: 1d295ff5f8
Commit: 7f58853c2e
@@ -38,7 +38,7 @@ jobs:
         run: |
           if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
           cd tests
-          TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose unit/
+          TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 unit/
 
   nv-torch18-v100:
     runs-on: [self-hosted, nvidia, torch18, v100]
@@ -65,7 +65,8 @@ jobs:
           unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
           if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
           cd tests
-          TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose unit/
+          TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 -m 'not sequential' unit/
+          TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'sequential' unit/
 
   nv-transformers-v100:
     runs-on: [self-hosted, nvidia, torch18, v100]
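Note: the hunk above splits the run into two passes, everything not marked `sequential` across four pytest-xdist workers, then the `sequential` tests alone. For `-m 'sequential'` / `-m 'not sequential'` selection to work without unknown-mark warnings, the marker has to be registered with pytest; the registration is not part of this diff, but a minimal, hypothetical sketch of it in the test suite's conftest.py could look like this:

# conftest.py -- hypothetical sketch; the actual marker registration is not shown in this diff.
def pytest_configure(config):
    # Register the custom marker so -m 'sequential' / -m 'not sequential'
    # selections do not emit PytestUnknownMarkWarning.
    config.addinivalue_line(
        "markers",
        "sequential: test must not run concurrently with other xdist workers")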
@@ -6,6 +6,7 @@ pre-commit
 pytest
 pytest-forked
 pytest-randomly
+pytest-xdist
 recommonmark
 sphinx
 sphinx-rtd-theme
@@ -15,6 +15,44 @@ from pathlib import Path
 DEEPSPEED_UNIT_WORKER_TIMEOUT = 120
 
 
+def get_xdist_worker_id():
+    xdist_worker = os.environ.get('PYTEST_XDIST_WORKER', None)
+    if xdist_worker is not None:
+        xdist_worker_id = xdist_worker.replace('gw', '')
+        return int(xdist_worker_id)
+    return None
+
+
+def get_master_port():
+    master_port = os.environ.get('DS_TEST_PORT', '29503')
+    xdist_worker_id = get_xdist_worker_id()
+    if xdist_worker_id is not None:
+        master_port = str(int(master_port) + xdist_worker_id)
+    return master_port
+
+
+def set_cuda_visibile():
+    cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+    xdist_worker_id = get_xdist_worker_id()
+    if xdist_worker_id is None:
+        xdist_worker_id = 0
+    if cuda_visible is None:
+        # CUDA_VISIBLE_DEVICES is not set, discover it from nvidia-smi instead
+        import subprocess
+        nvidia_smi = subprocess.check_output(['nvidia-smi', '--list-gpus'])
+        num_gpus = len(nvidia_smi.decode('utf-8').strip().split('\n'))
+        cuda_visible = ",".join(map(str, range(num_gpus)))
+
+    # rotate list based on xdist worker id, example below
+    # wid=0 -> ['0', '1', '2', '3']
+    # wid=1 -> ['1', '2', '3', '0']
+    # wid=2 -> ['2', '3', '0', '1']
+    # wid=3 -> ['3', '0', '1', '2']
+    dev_id_list = cuda_visible.split(",")
+    dev_id_list = dev_id_list[xdist_worker_id:] + dev_id_list[:xdist_worker_id]
+    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(dev_id_list)
+
+
 def distributed_test(world_size=2, backend='nccl'):
     """A decorator for executing a function (e.g., a unit test) in a distributed manner.
     This decorator manages the spawning and joining of processes, initialization of
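Note: the net effect of the helpers above is that each pytest-xdist worker sees the GPUs in a different order, so tests that grab the first one or two visible devices land on different physical GPUs even when several workers run concurrently. A self-contained sketch of the rotation for a hypothetical 4-GPU node (illustrative only, not taken from the CI runners):

# Illustration of the CUDA_VISIBLE_DEVICES rotation above for workers gw0..gw3 on 4 GPUs.
def rotated_devices(worker_id, num_gpus=4):
    devs = [str(d) for d in range(num_gpus)]
    return ",".join(devs[worker_id:] + devs[:worker_id])

for wid in range(4):
    print(f"gw{wid} -> CUDA_VISIBLE_DEVICES={rotated_devices(wid)}")
    # gw0 -> 0,1,2,3   gw1 -> 1,2,3,0   gw2 -> 2,3,0,1   gw3 -> 3,0,1,2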
@@ -36,7 +74,7 @@ def distributed_test(world_size=2, backend='nccl'):
         def dist_init(local_rank, num_procs, *func_args, **func_kwargs):
             """Initialize torch.distributed and execute the user function. """
             os.environ['MASTER_ADDR'] = '127.0.0.1'
-            os.environ['MASTER_PORT'] = os.environ.get('DS_TEST_PORT', '29503')
+            os.environ['MASTER_PORT'] = get_master_port()
             os.environ['LOCAL_RANK'] = str(local_rank)
             # NOTE: unit tests don't support multi-node so local_rank == global rank
             os.environ['RANK'] = str(local_rank)
@@ -45,6 +83,8 @@ def distributed_test(world_size=2, backend='nccl'):
             # turn off NCCL logging if set
             os.environ.pop('NCCL_DEBUG', None)
 
+            set_cuda_visibile()
+
             deepspeed.init_distributed(dist_backend=backend)
 
             if torch.cuda.is_available():
@@ -197,6 +197,7 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
 
 
 # FP16 test cases can only run on the devices support FP16.
+@pytest.mark.sequential
 @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
                          [
                              (64,160,128,2,24,False,True),
@@ -850,6 +850,7 @@ def test_onebitlamb_fp16_pipeline(topo, tmpdir):
     _helper(topo, tmpdir)
 
 
+@pytest.mark.sequential
 def test_compressed_allreduce_basic(tmpdir):
     @distributed_test(world_size=[1, 2])
     def _test_compressed_allreduce_basic():
@@ -58,14 +58,13 @@ def test_zero_unbalanced_gradients(tmpdir, zero_stage):
         }
     }
 
-    args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 4
 
     model = SimpleModel(hidden_dim=hidden_dim)
 
     @distributed_test(world_size=[1])
-    def _test_zero_unbalanced_gradients(args, model, hidden_dim):
-        model, _, _, _ = deepspeed.initialize(args=args,
+    def _test_zero_unbalanced_gradients(model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(config=config_dict,
                                               model=model,
                                               model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
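Note: across the ZeRO tests that follow, the change is the same: drop args_from_dict (which passed the config indirectly through parsed args) and hand the config dict straight to deepspeed.initialize, so the nested test functions no longer need an args parameter. A minimal, self-contained sketch of the new call shape, with placeholder config values and a stand-in model rather than the suite's SimpleModel:

# Sketch of the config-dict initialization pattern used throughout this commit (placeholder values).
import torch
import deepspeed

config_dict = {
    "train_micro_batch_size_per_gpu": 1,              # illustrative, not the test's exact config
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 2},
}
model = torch.nn.Linear(4, 4)                         # stand-in for the test suite's SimpleModel
engine, _, _, _ = deepspeed.initialize(config=config_dict,
                                       model=model,
                                       model_parameters=model.parameters())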
@@ -75,7 +74,7 @@ def test_zero_unbalanced_gradients(tmpdir, zero_stage):
 
         run_unbalanced_gradients(model, data_loader)
 
-    _test_zero_unbalanced_gradients(args=args, model=model, hidden_dim=hidden_dim)
+    _test_zero_unbalanced_gradients(model=model, hidden_dim=hidden_dim)
 
 
 # testing the fix https://github.com/microsoft/DeepSpeed/pull/1227
@@ -103,7 +102,6 @@ def test_zero3_repeat_forward_loop(tmpdir, zero_stage):
         }
     }
 
-    args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 4
 
     class AlbertLikeModel(torch.nn.Module):
@@ -122,8 +120,8 @@ def test_zero3_repeat_forward_loop(tmpdir, zero_stage):
     model = AlbertLikeModel(hidden_dim=hidden_dim)
 
     @distributed_test(world_size=[1])
-    def _test_zero3_repeat_forward_loop(args, model, hidden_dim):
-        model, _, _, _ = deepspeed.initialize(args=args,
+    def _test_zero3_repeat_forward_loop(model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(config=config_dict,
                                               model=model,
                                               model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
@@ -136,7 +134,7 @@ def test_zero3_repeat_forward_loop(tmpdir, zero_stage):
             model.backward(loss)
             model.step()
 
-    _test_zero3_repeat_forward_loop(args=args, model=model, hidden_dim=hidden_dim)
+    _test_zero3_repeat_forward_loop(model=model, hidden_dim=hidden_dim)
 
 
 # testing the fix https://github.com/microsoft/DeepSpeed/pull/1227
@@ -189,7 +187,6 @@ def test_zero_to_fp32_1_param_group(tmpdir, zero_stage):
                 hidden = l(hidden)
             return self.cross_entropy_loss(hidden, y)
 
-    args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 3  # do not change
 
     world_size = dist.get_world_size()
@@ -197,7 +194,7 @@ def test_zero_to_fp32_1_param_group(tmpdir, zero_stage):
         n_layers = world_size * 2
         model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers)
 
-        model, _, _, _ = deepspeed.initialize(args=args,
+        model, _, _, _ = deepspeed.initialize(config=config_dict,
                                               model=model,
                                               model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
@@ -284,7 +281,6 @@ def test_zero_to_fp32_2_param_groups(tmpdir, zero_stage):
                 hidden = l(hidden)
             return self.cross_entropy_loss(hidden, y)
 
-    args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 3
 
     world_size = dist.get_world_size()
@@ -303,10 +299,10 @@ def test_zero_to_fp32_2_param_groups(tmpdir, zero_stage):
         ]
         optim = torch.optim.SGD(optim_groups, lr=0.1)
 
-        model, _, _, _ = deepspeed.initialize(args=args,
-                                              model=model,
+        model, _, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
-                                              optimizer = optim,
+                                              optimizer=optim,
+                                              config=config_dict
                                               )
         data_loader = random_dataloader(model=model,
                                         total_samples=16,
@@ -370,26 +366,25 @@ def test_incorrect_allgather_bucket_size(tmpdir, zero_stage, allgather_bucket_size):
         }
     }
 
-    args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 4
 
     model = SimpleModel(hidden_dim=hidden_dim)
 
     @distributed_test(world_size=[1])
-    def _test_incorrect_allgather_bucket_size(args, model, hidden_dim):
+    def _test_incorrect_allgather_bucket_size(model, hidden_dim):
         if allgather_bucket_size % 2 == 0:
-            model, _, _, _ = deepspeed.initialize(args=args,
+            model, _, _, _ = deepspeed.initialize(config=config_dict,
                                                   model=model,
                                                   model_parameters=model.parameters())
         else:
             with pytest.raises(AssertionError) as assertinfo:
-                model, _, _, _ = deepspeed.initialize(args=args,
+                model, _, _, _ = deepspeed.initialize(config=config_dict,
                                                       model=model,
                                                       model_parameters=model.parameters())
             assert "allgather_bucket_size must be a multiple of nccl_start_alignment_factor" in str(
                 assertinfo)
 
-    _test_incorrect_allgather_bucket_size(args=args, model=model, hidden_dim=hidden_dim)
+    _test_incorrect_allgather_bucket_size(model=model, hidden_dim=hidden_dim)
 
 
 @pytest.mark.parametrize('zero_stage, world_size', [(2, 2), (2, 3), (2, 4)])
@@ -413,14 +408,13 @@ def test_partition_nccl_alignment(tmpdir, zero_stage, world_size):
         }
     }
 
-    args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 4
 
     model = SimpleModel(hidden_dim=hidden_dim)
 
     @distributed_test(world_size=world_size)
-    def _test_partition_nccl_alignment(args, model, hidden_dim):
-        model, _, _, _ = deepspeed.initialize(args=args,
+    def _test_partition_nccl_alignment(model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(config=config_dict,
                                               model=model,
                                               model_parameters=model.parameters())
 
@@ -434,4 +428,4 @@ def test_partition_nccl_alignment(tmpdir, zero_stage, world_size):
             assert (partitioned_data.data_ptr() %
                     (2 * nccl_start_alignment_factor) == 0)
 
-    _test_partition_nccl_alignment(args=args, model=model, hidden_dim=hidden_dim)
+    _test_partition_nccl_alignment(model=model, hidden_dim=hidden_dim)
@@ -8,13 +8,13 @@ import pytest
 import deepspeed
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape
 
-from common import distributed_test
+from common import distributed_test, get_master_port
 
 
 def setup_serial_env():
     # Setup for a serial run
     os.environ['MASTER_ADDR'] = '127.0.0.1'
-    os.environ['MASTER_PORT'] = '29503'
+    os.environ['MASTER_PORT'] = get_master_port()
     os.environ['LOCAL_RANK'] = '0'
     os.environ['RANK'] = '0'
     os.environ['WORLD_SIZE'] = '1'
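Note: setup_serial_env configures single-process tests that initialize torch.distributed in place rather than through the spawning decorator; under -n 4 several of those can run at the same time, so they appear to need the same per-worker port offset as the spawned workers. A hedged, illustration-only sketch of what get_master_port() yields for two concurrent workers (assuming the helper is importable as in the diff and DS_TEST_PORT is unset):

# Illustration only: per-worker MASTER_PORT values produced by the helper added in this commit.
import os
from common import get_master_port   # helper introduced earlier in this commit

os.environ['PYTEST_XDIST_WORKER'] = 'gw0'
assert get_master_port() == '29503'
os.environ['PYTEST_XDIST_WORKER'] = 'gw1'
assert get_master_port() == '29504'   # offset avoids a rendezvous port clash between workers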