Re-enable GPT-J unit tests and refactor inference tests (#3618)

Michael Wyatt 2023-06-28 12:33:55 -07:00 committed by GitHub
Parent 7726fc8d54
Commit 78b7693591
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
5 changed files: 102 additions and 119 deletions

.flake8 Normal file
View file

@@ -0,0 +1,5 @@
[flake8]
ignore = E,F403,F405,F541,F841,W
select = E9,F,W6
per-file-ignores =
__init__.py:F401
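
For illustration, the per-file-ignores entry above lets package __init__.py files re-export names without tripping F401 ("imported but unused"). A minimal hypothetical sketch (the module and symbol names below are made up, not from this repository):

# __init__.py (hypothetical example)
# Imports used purely to re-export names would normally be flagged as F401;
# the per-file-ignores rule above suppresses that check for __init__.py only,
# so explicit "# noqa: F401" comments are no longer needed here.
from .submodule import public_api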

View file

@@ -67,7 +67,7 @@ repos:
rev: 4.0.1
hooks:
- id: flake8
args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401']
args: ['--config=.flake8']
- repo: local
hooks:

View file

@@ -6,3 +6,4 @@ markers =
inference_ops:Individual inference operator tests
seq_inference:Inference model tests to run sequentially
nightly:Tests that should be run nightly
world_size:Change world size of individual tests in a class
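
A minimal sketch of how the new world_size marker might be consumed by a test class (hypothetical usage; the decorator value and test below are illustrative only, not part of this diff):

import pytest
from unit.common import DistributedTest

class TestExample(DistributedTest):
    world_size = 4  # class-level default used by most tests

    @pytest.mark.world_size(2)  # hypothetical per-test override enabled by the new marker
    def test_with_smaller_world_size(self):
        pass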

View file

@@ -49,75 +49,57 @@ _gpt_models = [
"gpt2",
"distilgpt2",
"Norod78/hebrew-bad_wiki-gpt_neo-tiny",
"EleutherAI/gpt-j-6B", # bring back this model as we did not catch an error before by merging some changes! TODO: we need to fix the OOM issue later!
"EleutherAI/gpt-j-6b",
"EleutherAI/pythia-70m-deduped",
"bigscience/bloom-560m",
]
_opt_models = [
"facebook/opt-125m", # 125m, 1.7B, ..., 175B variants have the same model architecture.
"facebook/opt-350m", # 350m applies layer norm after attention layer which is different than other variants.
]
_all_models = HfApi().list_models()
test_models = set(_bert_models + _roberta_models + _gpt_models + _opt_models)
test_tasks = [
_test_models = set(_bert_models + _roberta_models + _gpt_models + _opt_models)
_test_tasks = [
"fill-mask", "question-answering", "text-classification", "token-classification", "text-generation",
"text2text-generation", "summarization", "translation"
]
pytest.all_models = {task: [m.modelId for m in _all_models if m.pipeline_tag == task] for task in test_tasks}
_model_w_tasks = itertools.product(*[test_models, test_tasks])
# Get a list of all models and mapping from task to supported models
_hf_models = HfApi().list_models()
_hf_model_names = [m.modelId for m in _hf_models]
_hf_task_to_models = {task: [m.modelId for m in _hf_models if m.pipeline_tag == task] for task in _test_tasks}
# Get all combinations of task:model to test
_model_w_tasks = [(m, t) for m, t in itertools.product(*[_test_models, _test_tasks]) if m in _hf_task_to_models[t]]
# Assign to pytest variables for testing
pytest.model_w_tasks = _model_w_tasks
pytest.mt_names = [f"{m}-{t}" for m, t in pytest.model_w_tasks]
def _valid_model_task(model_task):
m, t = model_task
return m in pytest.all_models[t]
pytest.models_w_tasks = list(filter(_valid_model_task, _model_w_tasks))
pytest.mt_names = [f"{m}-{t}" for m, t in pytest.models_w_tasks]
"""
These fixtures iterate all combinations of tasks and models, dtype, & cuda_graph
"""
@pytest.fixture(params=pytest.models_w_tasks, ids=pytest.mt_names)
def model_w_task(request):
return request.param
@pytest.fixture(params=[torch.float, torch.half], ids=["fp32", "fp16"])
def dtype(request):
return request.param
@pytest.fixture(params=[True, False], ids=["CG", "noCG"])
def enable_cuda_graph(request):
return request.param
@pytest.fixture(params=[True, False], ids=["Triton", "noTriton"])
def enable_triton(request):
return request.param
"""
This fixture will validate the configuration
"""
@pytest.fixture(scope="module", autouse=True)
def verify_models():
# Verify all test models are registered in HF
_test_models_not_found = [m for m in _test_models if m not in _hf_model_names]
if _test_models_not_found:
pytest.fail(f"Model(s) not found in HuggingFace: {_test_models_not_found}")
# Verify all models are assigned to at least one task
_models_to_be_tested = set(m for m, t in _model_w_tasks)
_missing_task_models = _models_to_be_tested.difference(_test_models)
if _missing_task_models:
pytest.fail(f"Model(s) do not have an assigned task: {_missing_task_models}")
# Fixture to add skips for certain configurations
@pytest.fixture()
def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph, enable_triton):
def invalid_test(model_w_task, dtype, enable_cuda_graph, enable_triton):
model, task = model_w_task
msg = ""
if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
msg = "DS inference injection doesn't work well on older torch versions"
elif model not in pytest.all_models[task]:
msg = f"Not a valid model / task combination: {model} / {task}"
elif enable_cuda_graph and (torch_info["cuda_version"] == "0.0"):
if enable_cuda_graph and (torch_info["cuda_version"] == "0.0"):
msg = "CUDA not detected, cannot use CUDA Graph"
elif enable_cuda_graph and pkg_version.parse(torch.__version__) < pkg_version.parse("1.10"):
msg = "CUDA Graph is only available in torch versions >= 1.10"
elif "gpt-j-6B" in model:
elif "gpt-j-6b" in model:
if dtype != torch.half:
msg = f"Not enough GPU memory to run {model} with dtype {dtype}"
elif enable_cuda_graph:
@@ -139,10 +121,30 @@ def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph, enable_tri
return msg
"""
These fixtures can be used to customize the query, inference args, and assert
statement for each combination of model /task
"""
""" Fixtures for inference config """
@pytest.fixture(params=pytest.model_w_tasks, ids=pytest.mt_names)
def model_w_task(request):
return request.param
@pytest.fixture(params=[torch.float, torch.half], ids=["fp32", "fp16"])
def dtype(request):
return request.param
@pytest.fixture(params=[True, False], ids=["CG", "noCG"])
def enable_cuda_graph(request):
return request.param
@pytest.fixture(params=[True, False], ids=["Triton", "noTriton"])
def enable_triton(request):
return request.param
""" Fixtures for running query """
@pytest.fixture
@@ -178,7 +180,7 @@ def query(model_w_task):
def inf_kwargs(model_w_task):
model, task = model_w_task
if task == "text-generation":
if model == "EleutherAI/gpt-j-6B":
if model == "EleutherAI/gpt-j-6b":
# This model on V100 is hitting memory problems that limit the number of output tokens
return {"do_sample": False, "max_length": 12}
return {"do_sample": False, "max_length": 20}
@@ -186,6 +188,9 @@ def inf_kwargs(model_w_task):
return {}
""" Assertion fixture for verifying model outputs """
def fill_mask_assert(x, y):
return set(res["token_str"] for res in x) == set(res["token_str"] for res in y)
@@ -237,6 +242,7 @@ def assert_fn(model_w_task):
return assert_fn
# Used to verify DeepSpeed kernel injection worked with a model
def check_injection(model):
def verify_injection(module):
@@ -251,27 +257,24 @@ def check_injection(model):
verify_injection(model)
"""
Tests
"""
@pytest.mark.inference
class TestModelTask(DistributedTest):
world_size = 1
def test(self,
model_w_task,
dtype,
enable_cuda_graph,
enable_triton,
query,
inf_kwargs,
assert_fn,
invalid_model_task_config,
perf_meas=True):
if invalid_model_task_config:
pytest.skip(invalid_model_task_config)
def test(
self,
model_w_task,
dtype,
enable_cuda_graph,
enable_triton,
query,
inf_kwargs,
assert_fn,
invalid_test,
perf_meas=True,
):
if invalid_test:
pytest.skip(invalid_test)
model, task = model_w_task
local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -338,10 +341,10 @@ class TestModelTask(DistributedTest):
@pytest.mark.parametrize("model_w_task", [("EleutherAI/gpt-neo-1.3B", "text-generation"),
("EleutherAI/gpt-neox-20b", "text-generation"),
("bigscience/bloom-3b", "text-generation"),
("EleutherAI/gpt-j-6B", "text-generation")],
("EleutherAI/gpt-j-6b", "text-generation")],
ids=["gpt-neo", "gpt-neox", "bloom", "gpt-j"])
class TestMPSize(DistributedTest):
world_size = 4
world_size = 2
def test(
self,
@@ -350,10 +353,10 @@ class TestMPSize(DistributedTest):
query,
inf_kwargs,
assert_fn,
invalid_model_task_config,
invalid_test,
):
if invalid_model_task_config:
pytest.skip(invalid_model_task_config)
if invalid_test:
pytest.skip(invalid_test)
model, task = model_w_task
local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -402,12 +405,12 @@ class TestInjectionPolicy(DistributedTest):
query,
inf_kwargs,
assert_fn,
invalid_model_task_config,
invalid_test,
dtype,
enable_cuda_graph,
):
if invalid_model_task_config:
pytest.skip(invalid_model_task_config)
if invalid_test:
pytest.skip(invalid_test)
model, task = model_w_task
local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -452,12 +455,12 @@ class TestAutoTensorParallelism(DistributedTest):
query,
inf_kwargs,
assert_fn,
invalid_model_task_config,
invalid_test,
dtype,
enable_cuda_graph,
):
if invalid_model_task_config:
pytest.skip(invalid_model_task_config)
if invalid_test:
pytest.skip(invalid_test)
model, task = model_w_task
local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -483,7 +486,7 @@ class TestAutoTensorParallelism(DistributedTest):
"model_family, model_name",
(
["gpt2", "EleutherAI/gpt-neo-2.7B"],
["gpt2", "EleutherAI/gpt-j-6B"],
["gpt2", "EleutherAI/gpt-j-6b"],
["gpt2", "gpt2-xl"],
),
)
@@ -503,7 +506,7 @@ class TestLMCorrectness(DistributedTest):
dtype = torch.float
task_dict = lm_eval.tasks.get_task_dict([task])
if 'gpt-j-6B' in model_name:
if 'gpt-j-6b' in model_name:
dtype = torch.half
lm = lm_eval.models.get_model(model_family).create_from_arg_string(f"pretrained={model_name}",
{"device": "cpu"})

View file

@@ -13,43 +13,17 @@ from unit.common import DistributedTest
from deepspeed.accelerator import get_accelerator
@pytest.fixture
def query(model, task):
if task == "text-generation":
return "DeepSpeed is"
elif task == "fill-mask":
if "roberta" in model:
return "I am a <mask> model"
else:
return "I am a [MASK] model"
else:
raise NotImplementedError
@pytest.fixture
def inf_kwargs(task):
if task == "text-generation":
return {"do_sample": False, "min_length": 50, "max_length": 50}
else:
return {}
@pytest.mark.inference
@pytest.mark.parametrize("model,task", [
("bert-base-cased", "fill-mask"),
("roberta-base", "fill-mask"),
("gpt2", "text-generation"),
("facebook/opt-125m", "text-generation"),
("bigscience/bloom-560m", "text-generation"),
])
@pytest.mark.parametrize("cuda_graphs", [True, False])
@pytest.mark.parametrize("use_cuda_events", [True, False])
@pytest.mark.parametrize("enable_cuda_graph", [True, False])
class TestModelProfiling(DistributedTest):
world_size = 1
def test(self, model, task, query, inf_kwargs, cuda_graphs, use_cuda_events, dtype=torch.float16):
if cuda_graphs and "bert" not in model:
pytest.skip(f"CUDA Graph not supported for {model}")
def test(self, enable_cuda_graph, use_cuda_events):
task = "fill-mask"
model = "bert-base-cased"
dtype = torch.float16
query = "I am a [MASK] model"
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -59,7 +33,7 @@ class TestModelProfiling(DistributedTest):
dtype=dtype,
mp_size=world_size,
replace_with_kernel_inject=True,
enable_cuda_graph=cuda_graphs)
enable_cuda_graph=enable_cuda_graph)
pipe.model.profile_model_time(use_cuda_events=use_cuda_events)
e2e_times = []
@@ -68,7 +42,7 @@ class TestModelProfiling(DistributedTest):
get_accelerator().synchronize()
start = time.perf_counter_ns()
r = pipe(query, **inf_kwargs)
r = pipe(query)
get_accelerator().synchronize()
end = time.perf_counter_ns()
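
A hedged, self-contained sketch of how end-to-end timings like the one measured above are typically checked against the model-only timings gathered by profile_model_time (the model_times() accessor and the final assertion are assumptions about how the test concludes; they are not shown in this hunk):

def check_profiled_times(pipe, e2e_times_ms):
    # Hypothetical helper: model_times() is assumed to return one model-only
    # latency (in ms) per forward pass recorded while profiling was enabled.
    model_times_ms = pipe.model.model_times()
    for e2e_t, model_t in zip(e2e_times_ms, model_times_ms):
        # The model-only time should never exceed the end-to-end wall-clock time.
        assert e2e_t >= model_t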