зеркало из https://github.com/microsoft/DeepSpeed.git
Re-enable GPT-J unit tests and refactor inference tests (#3618)
This commit is contained in:
Родитель
7726fc8d54
Коммит
78b7693591
|
@ -0,0 +1,5 @@
|
|||
[flake8]
|
||||
ignore = E,F403,F405,F541,F841,W
|
||||
select = E9,F,W6
|
||||
per-file-ignores =
|
||||
__init__.py:F401
|
|
@ -67,7 +67,7 @@ repos:
|
|||
rev: 4.0.1
|
||||
hooks:
|
||||
- id: flake8
|
||||
args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401']
|
||||
args: ['--config=.flake8']
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
|
|
|
@ -6,3 +6,4 @@ markers =
|
|||
inference_ops:Individual inference operator tests
|
||||
seq_inference:Inference model tests to run sequentially
|
||||
nightly:Tests that should be run nightly
|
||||
world_size:Change world size of individual tests in a class
|
||||
|
|
|
@ -49,75 +49,57 @@ _gpt_models = [
|
|||
"gpt2",
|
||||
"distilgpt2",
|
||||
"Norod78/hebrew-bad_wiki-gpt_neo-tiny",
|
||||
"EleutherAI/gpt-j-6B", # bring back this model as we did not catch an error before by merging some changes! TODO: we need to fix the OOM issue later!
|
||||
"EleutherAI/gpt-j-6b",
|
||||
"EleutherAI/pythia-70m-deduped",
|
||||
"bigscience/bloom-560m",
|
||||
]
|
||||
_opt_models = [
|
||||
"facebook/opt-125m", # 125m, 1.7B, ..., 175B variants have the same model architecture.
|
||||
"facebook/opt-350m", # 350m applies layer norm after attention layer which is different than other variants.
|
||||
]
|
||||
_all_models = HfApi().list_models()
|
||||
|
||||
test_models = set(_bert_models + _roberta_models + _gpt_models + _opt_models)
|
||||
test_tasks = [
|
||||
_test_models = set(_bert_models + _roberta_models + _gpt_models + _opt_models)
|
||||
_test_tasks = [
|
||||
"fill-mask", "question-answering", "text-classification", "token-classification", "text-generation",
|
||||
"text2text-generation", "summarization", "translation"
|
||||
]
|
||||
pytest.all_models = {task: [m.modelId for m in _all_models if m.pipeline_tag == task] for task in test_tasks}
|
||||
|
||||
_model_w_tasks = itertools.product(*[test_models, test_tasks])
|
||||
# Get a list of all models and mapping from task to supported models
|
||||
_hf_models = HfApi().list_models()
|
||||
_hf_model_names = [m.modelId for m in _hf_models]
|
||||
_hf_task_to_models = {task: [m.modelId for m in _hf_models if m.pipeline_tag == task] for task in _test_tasks}
|
||||
|
||||
# Get all combinations of task:model to test
|
||||
_model_w_tasks = [(m, t) for m, t in itertools.product(*[_test_models, _test_tasks]) if m in _hf_task_to_models[t]]
|
||||
|
||||
# Assign to pytest variables for testing
|
||||
pytest.model_w_tasks = _model_w_tasks
|
||||
pytest.mt_names = [f"{m}-{t}" for m, t in pytest.model_w_tasks]
|
||||
|
||||
|
||||
def _valid_model_task(model_task):
|
||||
m, t = model_task
|
||||
return m in pytest.all_models[t]
|
||||
|
||||
|
||||
pytest.models_w_tasks = list(filter(_valid_model_task, _model_w_tasks))
|
||||
pytest.mt_names = [f"{m}-{t}" for m, t in pytest.models_w_tasks]
|
||||
"""
|
||||
These fixtures iterate all combinations of tasks and models, dtype, & cuda_graph
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture(params=pytest.models_w_tasks, ids=pytest.mt_names)
|
||||
def model_w_task(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[torch.float, torch.half], ids=["fp32", "fp16"])
|
||||
def dtype(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[True, False], ids=["CG", "noCG"])
|
||||
def enable_cuda_graph(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[True, False], ids=["Triton", "noTriton"])
|
||||
def enable_triton(request):
|
||||
return request.param
|
||||
|
||||
|
||||
"""
|
||||
This fixture will validate the configuration
|
||||
"""
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def verify_models():
|
||||
# Verify all test models are registered in HF
|
||||
_test_models_not_found = [m for m in _test_models if m not in _hf_model_names]
|
||||
if _test_models_not_found:
|
||||
pytest.fail(f"Model(s) not found in HuggingFace: {_test_models_not_found}")
|
||||
|
||||
# Verify all models are assigned to at least one task
|
||||
_models_to_be_tested = set(m for m, t in _model_w_tasks)
|
||||
_missing_task_models = _models_to_be_tested.difference(_test_models)
|
||||
if _missing_task_models:
|
||||
pytest.fail(f"Model(s) do not have an assigned task: {_missing_task_models}")
|
||||
|
||||
|
||||
# Fixture to add skips for certain configurations
|
||||
@pytest.fixture()
|
||||
def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph, enable_triton):
|
||||
def invalid_test(model_w_task, dtype, enable_cuda_graph, enable_triton):
|
||||
model, task = model_w_task
|
||||
msg = ""
|
||||
if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
|
||||
msg = "DS inference injection doesn't work well on older torch versions"
|
||||
elif model not in pytest.all_models[task]:
|
||||
msg = f"Not a valid model / task combination: {model} / {task}"
|
||||
elif enable_cuda_graph and (torch_info["cuda_version"] == "0.0"):
|
||||
if enable_cuda_graph and (torch_info["cuda_version"] == "0.0"):
|
||||
msg = "CUDA not detected, cannot use CUDA Graph"
|
||||
elif enable_cuda_graph and pkg_version.parse(torch.__version__) < pkg_version.parse("1.10"):
|
||||
msg = "CUDA Graph is only available in torch versions >= 1.10"
|
||||
elif "gpt-j-6B" in model:
|
||||
elif "gpt-j-6b" in model:
|
||||
if dtype != torch.half:
|
||||
msg = f"Not enough GPU memory to run {model} with dtype {dtype}"
|
||||
elif enable_cuda_graph:
|
||||
|
@ -139,10 +121,30 @@ def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph, enable_tri
|
|||
return msg
|
||||
|
||||
|
||||
"""
|
||||
These fixtures can be used to customize the query, inference args, and assert
|
||||
statement for each combination of model /task
|
||||
"""
|
||||
""" Fixtures for inference config """
|
||||
|
||||
|
||||
@pytest.fixture(params=pytest.model_w_tasks, ids=pytest.mt_names)
|
||||
def model_w_task(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[torch.float, torch.half], ids=["fp32", "fp16"])
|
||||
def dtype(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[True, False], ids=["CG", "noCG"])
|
||||
def enable_cuda_graph(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[True, False], ids=["Triton", "noTriton"])
|
||||
def enable_triton(request):
|
||||
return request.param
|
||||
|
||||
|
||||
""" Fixtures for running query """
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -178,7 +180,7 @@ def query(model_w_task):
|
|||
def inf_kwargs(model_w_task):
|
||||
model, task = model_w_task
|
||||
if task == "text-generation":
|
||||
if model == "EleutherAI/gpt-j-6B":
|
||||
if model == "EleutherAI/gpt-j-6b":
|
||||
# This model on V100 is hitting memory problems that limit the number of output tokens
|
||||
return {"do_sample": False, "max_length": 12}
|
||||
return {"do_sample": False, "max_length": 20}
|
||||
|
@ -186,6 +188,9 @@ def inf_kwargs(model_w_task):
|
|||
return {}
|
||||
|
||||
|
||||
""" Assertion fixture for verifying model outputs """
|
||||
|
||||
|
||||
def fill_mask_assert(x, y):
|
||||
return set(res["token_str"] for res in x) == set(res["token_str"] for res in y)
|
||||
|
||||
|
@ -237,6 +242,7 @@ def assert_fn(model_w_task):
|
|||
return assert_fn
|
||||
|
||||
|
||||
# Used to verify DeepSpeed kernel injection worked with a model
|
||||
def check_injection(model):
|
||||
|
||||
def verify_injection(module):
|
||||
|
@ -251,27 +257,24 @@ def check_injection(model):
|
|||
verify_injection(model)
|
||||
|
||||
|
||||
"""
|
||||
Tests
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.inference
|
||||
class TestModelTask(DistributedTest):
|
||||
world_size = 1
|
||||
|
||||
def test(self,
|
||||
model_w_task,
|
||||
dtype,
|
||||
enable_cuda_graph,
|
||||
enable_triton,
|
||||
query,
|
||||
inf_kwargs,
|
||||
assert_fn,
|
||||
invalid_model_task_config,
|
||||
perf_meas=True):
|
||||
if invalid_model_task_config:
|
||||
pytest.skip(invalid_model_task_config)
|
||||
def test(
|
||||
self,
|
||||
model_w_task,
|
||||
dtype,
|
||||
enable_cuda_graph,
|
||||
enable_triton,
|
||||
query,
|
||||
inf_kwargs,
|
||||
assert_fn,
|
||||
invalid_test,
|
||||
perf_meas=True,
|
||||
):
|
||||
if invalid_test:
|
||||
pytest.skip(invalid_test)
|
||||
|
||||
model, task = model_w_task
|
||||
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
|
@ -338,10 +341,10 @@ class TestModelTask(DistributedTest):
|
|||
@pytest.mark.parametrize("model_w_task", [("EleutherAI/gpt-neo-1.3B", "text-generation"),
|
||||
("EleutherAI/gpt-neox-20b", "text-generation"),
|
||||
("bigscience/bloom-3b", "text-generation"),
|
||||
("EleutherAI/gpt-j-6B", "text-generation")],
|
||||
("EleutherAI/gpt-j-6b", "text-generation")],
|
||||
ids=["gpt-neo", "gpt-neox", "bloom", "gpt-j"])
|
||||
class TestMPSize(DistributedTest):
|
||||
world_size = 4
|
||||
world_size = 2
|
||||
|
||||
def test(
|
||||
self,
|
||||
|
@ -350,10 +353,10 @@ class TestMPSize(DistributedTest):
|
|||
query,
|
||||
inf_kwargs,
|
||||
assert_fn,
|
||||
invalid_model_task_config,
|
||||
invalid_test,
|
||||
):
|
||||
if invalid_model_task_config:
|
||||
pytest.skip(invalid_model_task_config)
|
||||
if invalid_test:
|
||||
pytest.skip(invalid_test)
|
||||
|
||||
model, task = model_w_task
|
||||
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
|
@ -402,12 +405,12 @@ class TestInjectionPolicy(DistributedTest):
|
|||
query,
|
||||
inf_kwargs,
|
||||
assert_fn,
|
||||
invalid_model_task_config,
|
||||
invalid_test,
|
||||
dtype,
|
||||
enable_cuda_graph,
|
||||
):
|
||||
if invalid_model_task_config:
|
||||
pytest.skip(invalid_model_task_config)
|
||||
if invalid_test:
|
||||
pytest.skip(invalid_test)
|
||||
|
||||
model, task = model_w_task
|
||||
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
|
@ -452,12 +455,12 @@ class TestAutoTensorParallelism(DistributedTest):
|
|||
query,
|
||||
inf_kwargs,
|
||||
assert_fn,
|
||||
invalid_model_task_config,
|
||||
invalid_test,
|
||||
dtype,
|
||||
enable_cuda_graph,
|
||||
):
|
||||
if invalid_model_task_config:
|
||||
pytest.skip(invalid_model_task_config)
|
||||
if invalid_test:
|
||||
pytest.skip(invalid_test)
|
||||
|
||||
model, task = model_w_task
|
||||
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
|
@ -483,7 +486,7 @@ class TestAutoTensorParallelism(DistributedTest):
|
|||
"model_family, model_name",
|
||||
(
|
||||
["gpt2", "EleutherAI/gpt-neo-2.7B"],
|
||||
["gpt2", "EleutherAI/gpt-j-6B"],
|
||||
["gpt2", "EleutherAI/gpt-j-6b"],
|
||||
["gpt2", "gpt2-xl"],
|
||||
),
|
||||
)
|
||||
|
@ -503,7 +506,7 @@ class TestLMCorrectness(DistributedTest):
|
|||
dtype = torch.float
|
||||
task_dict = lm_eval.tasks.get_task_dict([task])
|
||||
|
||||
if 'gpt-j-6B' in model_name:
|
||||
if 'gpt-j-6b' in model_name:
|
||||
dtype = torch.half
|
||||
lm = lm_eval.models.get_model(model_family).create_from_arg_string(f"pretrained={model_name}",
|
||||
{"device": "cpu"})
|
||||
|
|
|
@ -13,43 +13,17 @@ from unit.common import DistributedTest
|
|||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def query(model, task):
|
||||
if task == "text-generation":
|
||||
return "DeepSpeed is"
|
||||
elif task == "fill-mask":
|
||||
if "roberta" in model:
|
||||
return "I am a <mask> model"
|
||||
else:
|
||||
return "I am a [MASK] model"
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def inf_kwargs(task):
|
||||
if task == "text-generation":
|
||||
return {"do_sample": False, "min_length": 50, "max_length": 50}
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.mark.inference
|
||||
@pytest.mark.parametrize("model,task", [
|
||||
("bert-base-cased", "fill-mask"),
|
||||
("roberta-base", "fill-mask"),
|
||||
("gpt2", "text-generation"),
|
||||
("facebook/opt-125m", "text-generation"),
|
||||
("bigscience/bloom-560m", "text-generation"),
|
||||
])
|
||||
@pytest.mark.parametrize("cuda_graphs", [True, False])
|
||||
@pytest.mark.parametrize("use_cuda_events", [True, False])
|
||||
@pytest.mark.parametrize("enable_cuda_graph", [True, False])
|
||||
class TestModelProfiling(DistributedTest):
|
||||
world_size = 1
|
||||
|
||||
def test(self, model, task, query, inf_kwargs, cuda_graphs, use_cuda_events, dtype=torch.float16):
|
||||
if cuda_graphs and "bert" not in model:
|
||||
pytest.skip(f"CUDA Graph not supported for {model}")
|
||||
def test(self, enable_cuda_graph, use_cuda_events):
|
||||
task = "fill-mask"
|
||||
model = "bert-base-cased"
|
||||
dtype = torch.float16
|
||||
query = "I am a [MASK] model"
|
||||
|
||||
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
|
@ -59,7 +33,7 @@ class TestModelProfiling(DistributedTest):
|
|||
dtype=dtype,
|
||||
mp_size=world_size,
|
||||
replace_with_kernel_inject=True,
|
||||
enable_cuda_graph=cuda_graphs)
|
||||
enable_cuda_graph=enable_cuda_graph)
|
||||
pipe.model.profile_model_time(use_cuda_events=use_cuda_events)
|
||||
|
||||
e2e_times = []
|
||||
|
@ -68,7 +42,7 @@ class TestModelProfiling(DistributedTest):
|
|||
get_accelerator().synchronize()
|
||||
start = time.perf_counter_ns()
|
||||
|
||||
r = pipe(query, **inf_kwargs)
|
||||
r = pipe(query)
|
||||
|
||||
get_accelerator().synchronize()
|
||||
end = time.perf_counter_ns()
|
||||
|
|
Загрузка…
Ссылка в новой задаче