зеркало из https://github.com/microsoft/DeepSpeed.git
Merge branch 'master' into jomayeri/aio-file-offset
This commit is contained in:
Коммит
2f5446a584
|
@ -12,6 +12,7 @@ on:
|
|||
type: string
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/nv-ds-chat.yml"
|
||||
- "deepspeed/runtime/zero/stage_1_and_2.py"
|
||||
- "deepspeed/runtime/zero/stage3.py"
|
||||
- "deepspeed/runtime/hybrid_engine.py"
|
||||
|
@ -42,6 +43,7 @@ jobs:
|
|||
|
||||
- name: Install deepspeed
|
||||
run: |
|
||||
pip install transformers==4.45.2
|
||||
pip install .[dev]
|
||||
ds_report
|
||||
|
||||
|
|
|
@ -142,7 +142,7 @@ DeepSpeed has been integrated with several different popular open-source DL fram
|
|||
| PyTorch Nightly | [![nv-torch-nightly-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml) |
|
||||
| Integrations | [![nv-transformers-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-transformers-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-transformers-v100.yml) [![nv-lightning-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-lightning-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-lightning-v100.yml) [![nv-accelerate-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-accelerate-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-accelerate-v100.yml) [![nv-mii](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-mii.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-mii.yml) [![nv-ds-chat](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-ds-chat.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-ds-chat.yml) [![nv-sd](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-sd.yml/badge.svg)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-sd.yml) |
|
||||
| Misc | [![Formatting](https://github.com/microsoft/DeepSpeed/actions/workflows/formatting.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/formatting.yml) [![pages-build-deployment](https://github.com/microsoft/DeepSpeed/actions/workflows/pages/pages-build-deployment/badge.svg)](https://github.com/microsoft/DeepSpeed/actions/workflows/pages/pages-build-deployment) [![Documentation Status](https://readthedocs.org/projects/deepspeed/badge/?version=latest)](https://deepspeed.readthedocs.io/en/latest/?badge=latest)[![python](https://github.com/microsoft/DeepSpeed/actions/workflows/python.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/python.yml) |
|
||||
| Huawei Ascend NPU | [![Huawei Ascend NPU](https://github.com/cosdt/DeepSpeed/actions/workflows/huawei-ascend-npu.yml/badge.svg?branch=master)](https://github.com/cosdt/DeepSpeed/actions/workflows/huawei-ascend-npu.yml) |
|
||||
| Huawei Ascend NPU | [![Huawei Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/deepspeed.yaml/badge.svg?branch=main)](https://github.com/Ascend/Ascend-CI/actions/workflows/deepspeed.yaml) |
|
||||
|
||||
# Installation
|
||||
|
||||
|
|
|
@ -101,7 +101,15 @@ __global__ void apply_rotary_pos_half(T* mixed_query,
|
|||
|
||||
#if defined(__HIP_PLATFORM_AMD__) and ROCM_WAVEFRONT_SIZE == 64
|
||||
#define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \
|
||||
if (threads_per_head == 64) { \
|
||||
if (threads_per_head == 4) { \
|
||||
LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \
|
||||
} else if (threads_per_head == 8) { \
|
||||
LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \
|
||||
} else if (threads_per_head == 16) { \
|
||||
LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \
|
||||
} else if (threads_per_head == 32) { \
|
||||
LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \
|
||||
} else if (threads_per_head == 64) { \
|
||||
LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \
|
||||
} else { \
|
||||
assert(false); \
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import deepspeed
|
||||
from deepspeed.ops.op_builder import InferenceBuilder
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
|
||||
pytest.skip("Inference ops are not available on this system", allow_module_level=True)
|
||||
|
||||
|
||||
@pytest.mark.inference_ops
|
||||
@pytest.mark.parametrize("num_heads", [64, 32, 16, 8])
|
||||
def test_rope_warp_size_alignment(num_heads):
|
||||
if get_accelerator().device_name() != "cuda":
|
||||
pytest.skip("This test runs only on GPU")
|
||||
|
||||
batch = 1
|
||||
head = 8
|
||||
seq_len = 1024
|
||||
head_dim = 32
|
||||
rotary_dim = 32
|
||||
offset = 8
|
||||
rotate_half = False
|
||||
rope_theta = 2
|
||||
|
||||
cuda0 = torch.device('cuda:0')
|
||||
query = torch.randn(batch, head, seq_len, head_dim, device=cuda0)
|
||||
key = torch.randn(batch, head, seq_len, head_dim, device=cuda0)
|
||||
|
||||
inference = InferenceBuilder().load()
|
||||
# For num_heads values of 64, 32, 16, 8
|
||||
# corresponding threads_per_head (defined in apply_rotary_pos_emb.cu) values are 4, 8, 16, 32
|
||||
inference.apply_rotary_pos_emb(query, key, rotary_dim, offset, num_heads, rotate_half, rope_theta)
|
Загрузка…
Ссылка в новой задаче