Merge branch 'master' into reduce_scatter_coalesced
|
@ -2,6 +2,10 @@ name: amd-mi200
|
|||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
paths:
|
||||
- '.github/workflows/amd-mi200.yml'
|
||||
- 'requirements/**'
|
||||
schedule:
|
||||
- cron: "0 0 * * *"
|
||||
|
||||
|
@ -21,7 +25,7 @@ jobs:
|
|||
# Steps represent a sequence of tasks that will be executed as part of the job
|
||||
steps:
|
||||
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- id: setup-venv
|
||||
uses: ./.github/workflows/setup-venv
|
||||
|
|
|
@ -24,6 +24,8 @@ jobs:
|
|||
unit-tests:
|
||||
runs-on: [self-hosted, cpu]
|
||||
|
||||
env: {ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true} # Allow using Node16 actions
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
|
@ -97,5 +99,5 @@ jobs:
|
|||
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
|
||||
cd tests
|
||||
# LOCAL_SIZE=2 enforce CPU to report 2 devices, this helps run the test on github default runner
|
||||
LOCAL_SIZE=2 COLUMNS=240 TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'seq_inference' unit/
|
||||
LOCAL_SIZE=2 COLUMNS=240 TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference_ops' -m 'inference' unit/
|
||||
LOCAL_SIZE=2 COLUMNS=240 HF_HOME=~/tmp/hf_home/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'seq_inference' unit/
|
||||
LOCAL_SIZE=2 COLUMNS=240 HF_HOME=~/tmp/hf_home/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference_ops' -m 'inference' unit/
|
||||
|
|
|
@ -22,7 +22,7 @@ jobs:
|
|||
runs-on: ubuntu-20.04
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- id: setup-venv
|
||||
uses: ./.github/workflows/setup-venv
|
||||
|
@ -50,5 +50,5 @@ jobs:
|
|||
run: |
|
||||
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
|
||||
cd tests
|
||||
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.2"
|
||||
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.2"
|
||||
HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.3"
|
||||
HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.3"
|
||||
|
|
|
@ -21,7 +21,7 @@ jobs:
|
|||
runs-on: ubuntu-20.04
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: environment
|
||||
run: |
|
||||
|
|
|
@ -8,7 +8,23 @@ on:
|
|||
paths:
|
||||
- ".github/workflows/hpu-gaudi2.yml"
|
||||
- "accelerator/hpu_accelerator.py"
|
||||
|
||||
- "op_builder/hpu/**"
|
||||
- "deepspeed/runtime/engine.py"
|
||||
- "deepspeed/runtime/bf16_optimizer.py"
|
||||
- "deepspeed/runtime/zero/stage_1_and_2.py"
|
||||
- "deepspeed/runtime/zero/stage3.py"
|
||||
- "deepspeed/runtime/zero/partition_parameters.py"
|
||||
- "deepspeed/runtime/zero/partitioned_param_coordinator.py"
|
||||
- "deepspeed/runtime/zero/parameter_offload.py"
|
||||
- "deepspeed/runtime/pipe/engine.py"
|
||||
- "deepspeed/runtime/utils.py"
|
||||
- "deepspeed/inference/engine.py"
|
||||
- "deepspeed/module_inject/auto_tp.py"
|
||||
- "deepspeed/module_inject/replace_module.py"
|
||||
- "deepspeed/module_inject/load_checkpoint.py"
|
||||
- "deepspeed/module_inject/inject.py"
|
||||
- "deepspeed/ops/transformer/**"
|
||||
- "deepspeed/ops/adam/**"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
@ -23,7 +39,7 @@ jobs:
|
|||
# The type of runner that the job will run on
|
||||
runs-on: [self-hosted, intel, gaudi2]
|
||||
container:
|
||||
image: vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest
|
||||
image: vault.habana.ai/gaudi-docker/1.15.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest
|
||||
ports:
|
||||
- 80
|
||||
options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice
|
||||
|
@ -36,7 +52,6 @@ jobs:
|
|||
test_compression.py
|
||||
test_dist.py
|
||||
test_elastic.py
|
||||
(test_intX_quantization.py and test_quantized_linear)
|
||||
test_ds_arguments.py
|
||||
test_run.py
|
||||
test_multinode_runner.py
|
||||
|
@ -83,7 +98,7 @@ jobs:
|
|||
# Steps represent a sequence of tasks that will be executed as part of the job
|
||||
steps:
|
||||
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Check container state
|
||||
run: |
|
||||
|
|
|
@ -29,7 +29,7 @@ jobs:
|
|||
options: --gpus all --shm-size "8G"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Check container state
|
||||
run: |
|
||||
|
|
|
@ -22,7 +22,7 @@ jobs:
|
|||
runs-on: [self-hosted, nvidia, cu117, v100]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- id: setup-venv
|
||||
uses: ./.github/workflows/setup-venv
|
||||
|
|
|
@ -10,6 +10,11 @@ on:
|
|||
required: false
|
||||
default: 'master'
|
||||
type: string
|
||||
pull_request:
|
||||
paths:
|
||||
- "deepspeed/runtime/zero/stage_1_and_2.py"
|
||||
- "deepspeed/runtime/zero/stage3.py"
|
||||
- "deepspeed/runtime/hybrid_engine.py"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
@ -24,7 +29,7 @@ jobs:
|
|||
runs-on: [self-hosted, nvidia, cu117, v100]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- id: setup-venv
|
||||
uses: ./.github/workflows/setup-venv
|
||||
|
|
|
@ -23,7 +23,7 @@ jobs:
|
|||
options: --gpus all --shm-size "8G"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Check container state
|
||||
run: |
|
||||
|
|
|
@ -17,7 +17,7 @@ jobs:
|
|||
options: --gpus all --shm-size "8G"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Check container state
|
||||
run: |
|
||||
|
|
|
@ -25,7 +25,7 @@ jobs:
|
|||
runs-on: [self-hosted, nvidia, cu117, v100]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- id: setup-venv
|
||||
uses: ./.github/workflows/setup-venv
|
||||
|
|
|
@ -21,6 +21,8 @@ jobs:
|
|||
unit-tests:
|
||||
runs-on: [self-hosted, nvidia, cu111, v100]
|
||||
|
||||
env: {ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true} # Allow using Node16 actions
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ jobs:
|
|||
runs-on: [self-hosted, nvidia, cu117, v100]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- id: setup-venv
|
||||
uses: ./.github/workflows/setup-venv
|
||||
|
@ -46,7 +46,7 @@ jobs:
|
|||
git clone https://github.com/huggingface/transformers
|
||||
cd transformers
|
||||
# if needed switch to the last known good SHA until transformers@master is fixed
|
||||
# git checkout 1cc453d33
|
||||
git checkout bdf36dc
|
||||
git rev-parse --short HEAD
|
||||
pip install .
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ jobs:
|
|||
runs-on: [self-hosted, nvidia, cu117, v100]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- id: setup-venv
|
||||
uses: ./.github/workflows/setup-venv
|
||||
|
|
|
@ -26,7 +26,7 @@ jobs:
|
|||
image: deepspeed/gh-builder:ubuntu1804-py38-torch1131-cu116
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: environment
|
||||
run: |
|
||||
|
|
|
@ -33,7 +33,7 @@ jobs:
|
|||
options: --gpus all --shm-size "8G"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Check container state
|
||||
run: |
|
||||
|
|
|
@ -22,7 +22,7 @@ jobs:
|
|||
runs-on: [self-hosted, nvidia, cu117, v100]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- id: setup-venv
|
||||
uses: ./.github/workflows/setup-venv
|
||||
|
@ -55,5 +55,5 @@ jobs:
|
|||
run: |
|
||||
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
|
||||
cd tests
|
||||
pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.2" --cuda_ver="11.8"
|
||||
pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.2" --cuda_ver="11.8"
|
||||
pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.3" --cuda_ver="11.8"
|
||||
pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.3" --cuda_ver="11.8"
|
||||
|
|
|
@ -18,7 +18,7 @@ jobs:
|
|||
runs-on: [self-hosted, nvidia, cu117, v100]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- id: setup-venv
|
||||
uses: ./.github/workflows/setup-venv
|
||||
|
|
|
@ -17,6 +17,8 @@ jobs:
|
|||
unit-tests:
|
||||
runs-on: [self-hosted, nvidia, cu111, p40]
|
||||
|
||||
env: {ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true} # Allow using Node16 actions
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
|
|
|
@ -17,6 +17,8 @@ jobs:
|
|||
unit-tests:
|
||||
runs-on: [self-hosted, nvidia, cu111, v100]
|
||||
|
||||
env: {ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true} # Allow using Node16 actions
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ jobs:
|
|||
runs-on: [self-hosted, nvidia, cu117, v100]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- id: setup-venv
|
||||
uses: ./.github/workflows/setup-venv
|
||||
|
|
|
@ -21,7 +21,7 @@ jobs:
|
|||
unit-tests:
|
||||
strategy:
|
||||
matrix:
|
||||
pyVersion: ["3.6", "3.7", "3.8", "3.9", "3.10"]
|
||||
pyVersion: ["3.7", "3.8", "3.9", "3.10"]
|
||||
fail-fast: false
|
||||
|
||||
runs-on: ubuntu-20.04
|
||||
|
@ -29,7 +29,7 @@ jobs:
|
|||
image: deepspeed/gh-builder:py${{ matrix.pyVersion }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: environment
|
||||
run: |
|
||||
|
|
|
@ -11,7 +11,7 @@ jobs:
|
|||
environment: release-env
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: "master"
|
||||
- id: setup-venv
|
||||
|
@ -35,7 +35,7 @@ jobs:
|
|||
run: |
|
||||
python release/bump_patch_version.py --current_version ${{ env.RELEASE_VERSION }}
|
||||
- name: Create Pull Request
|
||||
uses: peter-evans/create-pull-request@v4
|
||||
uses: peter-evans/create-pull-request@v6
|
||||
with:
|
||||
token: ${{ secrets.GH_PAT }}
|
||||
add-paths: |
|
||||
|
|
|
@ -22,7 +22,7 @@ runs:
|
|||
- id: set-env-vars
|
||||
run: |
|
||||
echo TEST_DATA_DIR=/blob/ >> $GITHUB_ENV
|
||||
echo TRANSFORMERS_CACHE=/blob/transformers_cache/ >> $GITHUB_ENV
|
||||
echo HF_HOME=/blob/hf_home/ >> $GITHUB_ENV
|
||||
echo TORCH_EXTENSIONS_DIR=./torch-extensions/ >> $GITHUB_ENV
|
||||
echo TORCH_CACHE=/blob/torch_cache/ >> $GITHUB_ENV
|
||||
echo HF_DATASETS_CACHE=/blob/datasets_cache/ >> $GITHUB_ENV
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
name: xpu-max1100
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 0 * * *"
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/xpu-max1100.yml"
|
||||
- "accelerator/xpu_accelerator.py"
|
||||
- "accelerator/abstract_accelerator.py"
|
||||
- "accelerator/cpu_accelerator.py"
|
||||
- "accelerator/real_accelerator.py"
|
||||
- "csrc/xpu/**"
|
||||
- "deepspeed/runtime/engine.py"
|
||||
- "deepspeed/runtime/bf16_optimizer.py"
|
||||
- "deepspeed/runtime/zero/stage_1_and_2.py"
|
||||
- "deepspeed/runtime/zero/stage3.py"
|
||||
- "deepspeed/runtime/zero/partition_parameters.py"
|
||||
- "deepspeed/runtime/zero/partitioned_param_coordinator.py"
|
||||
- "deepspeed/runtime/zero/parameter_offload.py"
|
||||
- "deepspeed/runtime/pipe/engine.py"
|
||||
- "deepspeed/runtime/utils.py"
|
||||
- "opbuilder/xpu/**"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
|
||||
|
||||
jobs:
|
||||
unit-tests:
|
||||
runs-on: [self-hosted, intel, xpu]
|
||||
container:
|
||||
image: intel/intel-extension-for-pytorch:2.1.30-xpu
|
||||
ports:
|
||||
- 80
|
||||
options: --privileged -it --rm --device /dev/dri:/dev/dri -v /dev/dri/by-path:/dev/dri/by-path --ipc=host --cap-add=ALL
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Check container state
|
||||
shell: bash
|
||||
run: |
|
||||
ldd --version
|
||||
python -c "import torch; print('torch:', torch.__version__, torch)"
|
||||
python -c "import torch; import intel_extension_for_pytorch; print('XPU available:', torch.xpu.is_available())"
|
||||
|
||||
- name: Install deepspeed
|
||||
run: |
|
||||
pip install py-cpuinfo
|
||||
pip install .[dev,autotuning]
|
||||
ds_report
|
||||
python -c "from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name)"
|
||||
|
||||
- name: Python environment
|
||||
run: |
|
||||
pip list
|
||||
|
||||
- name: Unit tests
|
||||
run: |
|
||||
pip install pytest pytest-timeout tabulate tensorboard wandb
|
||||
export ONEAPI_ROOT=/opt/intel/oneapi/redist
|
||||
export FI_PROVIDER_PATH=$ONEAPI_ROOT/opt/mpi/libfabric/lib/prov
|
||||
export LD_LIBRARY_PATH=$ONEAPI_ROOT/opt/mpi/libfabric/lib:$LD_LIBRARY_PATH
|
||||
export LD_LIBRARY_PATH=$ONEAPI_ROOT/lib:$LD_LIBRARY_PATH
|
||||
cd tests/unit
|
||||
pytest --verbose accelerator/*
|
||||
pytest --verbose autotuning/*
|
||||
pytest --verbose checkpoint/test_reshape_checkpoint.py
|
||||
pytest --verbose checkpoint/test_moe_checkpoint.py
|
||||
pytest --verbose checkpoint/test_shared_weights.py
|
||||
pytest --verbose launcher/test_ds_arguments.py launcher/test_run.py
|
||||
pytest --verbose moe/test_moe_tp.py
|
||||
pytest --verbose monitor/*
|
||||
pytest --verbose runtime/test_ds_config_model.py
|
||||
pytest --verbose runtime/pipe/test_pipe_schedule.py
|
||||
pytest --verbose runtime/zero/test_zero_config.py
|
||||
pytest --verbose runtime/zero/test_zero_tiled.py
|
||||
pytest --verbose runtime/zero/test_zeropp.py
|
||||
pytest --verbose runtime/test_autocast.py
|
||||
pytest --verbose runtime/test_data.py
|
||||
pytest --verbose runtime/test_runtime_utils.py
|
||||
pytest --verbose runtime/activation_checkpointing/*
|
||||
pytest --verbose runtime/utils/*
|
||||
pytest --verbose runtime/zero/test_zero_dynamic_class.py
|
|
@ -2,8 +2,8 @@ include *.txt README.md
|
|||
include deepspeed/inference/v2/kernels/ragged_ops/libs/*.so
|
||||
include deepspeed/inference/v2/kernels/cutlass_ops/libs/*.so
|
||||
recursive-include requirements *.txt
|
||||
recursive-include deepspeed *.cpp *.h *.cu *.hip *.tr *.cuh *.cc *.json
|
||||
recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc
|
||||
recursive-include deepspeed *.cpp *.h *.hpp *.cu *.hip *.tr *.cuh *.cc *.json
|
||||
recursive-include csrc *.cpp *.h *.hpp *.cu *.tr *.cuh *.cc
|
||||
recursive-include op_builder *.py
|
||||
recursive-include benchmarks *.py
|
||||
recursive-include accelerator *.py
|
||||
|
|
18
README.md
|
@ -15,8 +15,9 @@
|
|||
## Latest News
|
||||
<b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>
|
||||
|
||||
* [2024/07] [DeepSpeed Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/README.md) [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/japanese/README.md)]
|
||||
* [2024/03] [DeepSpeed-FP6:The power of FP6-Centric Serving for Large Language Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)]
|
||||
* [2024/01] [DeepSpeed-FastGen: Introducting Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19)
|
||||
* [2024/01] [DeepSpeed-FastGen: Introducing Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19)
|
||||
* [2023/11] [Llama 2 Inference on 4th Gen Intel® Xeon® Scalable Processor with DeepSpeed](https://github.com/microsoft/DeepSpeed/tree/master/blogs/intel-inference) [[Intel version]](https://www.intel.com/content/www/us/en/developer/articles/technical/xllama-2-on-xeon-scalable-processor-with-deepspeed.html)
|
||||
* [2023/11] [DeepSpeed ZeRO-Offload++: 6x Higher Training Throughput via Collaborative CPU/GPU Twin-Flow](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-offloadpp)
|
||||
* [2023/11] [DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/japanese/README.md)]
|
||||
|
@ -133,6 +134,7 @@ DeepSpeed has been integrated with several different popular open-source DL fram
|
|||
| AMD | [![amd-mi200](https://github.com/microsoft/DeepSpeed/actions/workflows/amd-mi200.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/amd-mi200.yml) |
|
||||
| CPU | [![torch-latest-cpu](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-torch-latest.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-torch-latest.yml) [![cpu-inference](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-inference.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-inference.yml) |
|
||||
| Intel Gaudi | [![hpu-gaudi2](https://github.com/microsoft/DeepSpeed/actions/workflows/hpu-gaudi2.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/hpu-gaudi2.yml) |
|
||||
| Intel XPU | [![xpu-max1100](https://github.com/microsoft/DeepSpeed/actions/workflows/xpu-max1100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/xpu-max1100.yml) |
|
||||
| PyTorch Nightly | [![nv-torch-nightly-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml) |
|
||||
| Integrations | [![nv-transformers-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-transformers-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-transformers-v100.yml) [![nv-lightning-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-lightning-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-lightning-v100.yml) [![nv-accelerate-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-accelerate-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-accelerate-v100.yml) [![nv-mii](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-mii.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-mii.yml) [![nv-ds-chat](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-ds-chat.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-ds-chat.yml) [![nv-sd](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-sd.yml/badge.svg)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-sd.yml) |
|
||||
| Misc | [![Formatting](https://github.com/microsoft/DeepSpeed/actions/workflows/formatting.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/formatting.yml) [![pages-build-deployment](https://github.com/microsoft/DeepSpeed/actions/workflows/pages/pages-build-deployment/badge.svg)](https://github.com/microsoft/DeepSpeed/actions/workflows/pages/pages-build-deployment) [![Documentation Status](https://readthedocs.org/projects/deepspeed/badge/?version=latest)](https://deepspeed.readthedocs.io/en/latest/?badge=latest)[![python](https://github.com/microsoft/DeepSpeed/actions/workflows/python.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/python.yml) |
|
||||
|
@ -158,11 +160,12 @@ dynamically link them at runtime.
|
|||
## Contributed HW support
|
||||
* DeepSpeed now support various HW accelerators.
|
||||
|
||||
| Contributor | Hardware | Accelerator Name | Contributor validated | Upstream validated |
|
||||
| ----------- | -------- | ---------------- | --------------------- | ------------------ |
|
||||
| Intel | Intel(R) Gaudi(R) 2 AI accelerator | hpu | Yes | Yes |
|
||||
| Intel | Intel(R) Xeon(R) Processors | cpu | Yes | Yes |
|
||||
| Intel | Intel(R) Data Center GPU Max series | xpu | Yes | No |
|
||||
| Contributor | Hardware | Accelerator Name | Contributor validated | Upstream validated |
|
||||
|-------------|-------------------------------------|------------------| --------------------- |--------------------|
|
||||
| Huawei | Huawei Ascend NPU | npu | Yes | No |
|
||||
| Intel | Intel(R) Gaudi(R) 2 AI accelerator | hpu | Yes | Yes |
|
||||
| Intel | Intel(R) Xeon(R) Processors | cpu | Yes | Yes |
|
||||
| Intel | Intel(R) Data Center GPU Max series | xpu | Yes | Yes |
|
||||
|
||||
## PyPI
|
||||
We regularly push releases to [PyPI](https://pypi.org/project/deepspeed/) and encourage users to install from there in most cases.
|
||||
|
@ -268,6 +271,9 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
|
|||
30. Xiaoxia Wu, Haojun Xia, Stephen Youn, Zhen Zheng, Shiyang Chen, Arash Bakhtiari, Michael Wyatt, Reza Yazdani Aminabadi, Yuxiong He, Olatunji Ruwase, Leon Song, Zhewei Yao (2023) ZeroQuant(4+2): Redefining LLMs Quantization with a New FP6-Centric Strategy for Diverse Generative Tasks [arXiv:2312.08583](https://arxiv.org/abs/2312.08583)
|
||||
|
||||
31. Haojun Xia, Zhen Zheng, Xiaoxia Wu, Shiyang Chen, Zhewei Yao, Stephen Youn, Arash Bakhtiari, Michael Wyatt, Donglin Zhuang, Zhongzhu Zhou, Olatunji Ruwase, Yuxiong He, Shuaiwen Leon Song. (2024) FP6-LLM: Efficiently Serving Large Language Models Through FP6-Centric Algorithm-System Co-Design [arXiv:2401.14112](https://arxiv.org/abs/2401.14112)
|
||||
32. Sam Ade Jacobs, Masahiro Tanaka, Chengming Zhang, Minjia Zhang, Reza Yazdani Aminadabi, Shuaiwen Leon Song, Samyam Rajbhandari, Yuxiong He. (2024) [System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://dl.acm.org/doi/10.1145/3662158.3662806)
|
||||
33. Xinyu Lian, Sam Ade Jacobs, Lev Kurilenko, Masahiro Tanaka, Stas Bekman, Olatunji Ruwase, Minjia Zhang. (2024) Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training [arXiv:2406.18820](https://arxiv.org/abs/2406.18820)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ class DeepSpeedAccelerator(ABC):
|
|||
def __init__(self):
|
||||
self._name = None
|
||||
self._communication_backend_name = None
|
||||
self._compile_backend = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_synchronized_device(self):
|
||||
|
@ -80,7 +81,7 @@ class DeepSpeedAccelerator(ABC):
|
|||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def initial_seed(self, seed):
|
||||
def initial_seed(self):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
|
@ -287,3 +288,19 @@ class DeepSpeedAccelerator(ABC):
|
|||
@abc.abstractmethod
|
||||
def export_envs(self):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def visible_devices_envs(self):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_visible_devices_envs(self, current_env, local_accelerator_ids):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_compile_backend(self):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_compile_backend(self, backend):
|
||||
...
|
||||
|
|
|
@ -20,6 +20,7 @@ class CPU_Accelerator(DeepSpeedAccelerator):
|
|||
|
||||
def __init__(self):
|
||||
self._name = 'cpu'
|
||||
self._compile_backend = "inductor"
|
||||
if oneccl_imported_p:
|
||||
self._communication_backend_name = 'ccl'
|
||||
else:
|
||||
|
@ -99,8 +100,8 @@ class CPU_Accelerator(DeepSpeedAccelerator):
|
|||
def manual_seed_all(self, seed):
|
||||
return torch.manual_seed(seed)
|
||||
|
||||
def initial_seed(self, seed):
|
||||
return torch.initial_seed(seed)
|
||||
def initial_seed(self):
|
||||
return torch.initial_seed()
|
||||
|
||||
def default_generator(self, device_index):
|
||||
return torch.default_generator
|
||||
|
@ -300,12 +301,14 @@ class CPU_Accelerator(DeepSpeedAccelerator):
|
|||
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
|
||||
# if successful this also means we're doing a local install and not JIT compile path
|
||||
from op_builder import __deepspeed__ # noqa: F401 # type: ignore
|
||||
from op_builder.cpu import CCLCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder
|
||||
from op_builder.cpu import CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder
|
||||
except ImportError:
|
||||
from deepspeed.ops.op_builder.cpu import CCLCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder
|
||||
from deepspeed.ops.op_builder.cpu import CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder
|
||||
|
||||
if class_name == "CCLCommBuilder":
|
||||
return CCLCommBuilder
|
||||
elif class_name == "ShareMemCommBuilder":
|
||||
return ShareMemCommBuilder
|
||||
elif class_name == "FusedAdamBuilder":
|
||||
return FusedAdamBuilder
|
||||
elif class_name == "CPUAdamBuilder":
|
||||
|
@ -320,3 +323,22 @@ class CPU_Accelerator(DeepSpeedAccelerator):
|
|||
|
||||
def export_envs(self):
|
||||
return []
|
||||
|
||||
# TODO: cpu's visible envs is confirmed, keep as CUDA_VISIBLE_DEVICES
|
||||
def visible_devices_envs(self):
|
||||
return ['CUDA_VISIBLE_DEVICES']
|
||||
|
||||
def set_visible_devices_envs(self, current_env, local_accelerator_ids):
|
||||
for env in self.visible_devices_envs():
|
||||
current_env[env] = ",".join(map(str, local_accelerator_ids))
|
||||
|
||||
def get_compile_backend(self):
|
||||
return self._compile_backend
|
||||
|
||||
def set_compile_backend(self, backend):
|
||||
supported_backends = torch._dynamo.list_backends(exclude_tags=())
|
||||
if backend in supported_backends:
|
||||
self._compile_backend = backend
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
|
||||
|
|
|
@ -7,6 +7,7 @@ import functools
|
|||
import os
|
||||
import pkgutil
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
from .abstract_accelerator import DeepSpeedAccelerator
|
||||
# During setup stage torch may not be installed, pass on no torch will
|
||||
|
@ -24,7 +25,8 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
|
|||
|
||||
def __init__(self):
|
||||
self._name = 'cuda'
|
||||
self._communication_backend_name = 'nccl'
|
||||
self._communication_backend_name = 'nccl' if sys.platform != 'win32' else 'gloo'
|
||||
self._compile_backend = "inductor"
|
||||
if pynvml is None:
|
||||
self._init_pynvml()
|
||||
|
||||
|
@ -98,8 +100,8 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
|
|||
def manual_seed_all(self, seed):
|
||||
return torch.cuda.manual_seed_all(seed)
|
||||
|
||||
def initial_seed(self, seed):
|
||||
return torch.cuda.initial_seed(seed)
|
||||
def initial_seed(self):
|
||||
return torch.cuda.initial_seed()
|
||||
|
||||
def default_generator(self, device_index):
|
||||
return torch.cuda.default_generators[device_index]
|
||||
|
@ -360,3 +362,21 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
|
|||
|
||||
def export_envs(self):
|
||||
return ['NCCL']
|
||||
|
||||
def visible_devices_envs(self):
|
||||
return ['CUDA_VISIBLE_DEVICES']
|
||||
|
||||
def set_visible_devices_envs(self, current_env, local_accelerator_ids):
|
||||
for env in self.visible_devices_envs():
|
||||
current_env[env] = ",".join(map(str, local_accelerator_ids))
|
||||
|
||||
def get_compile_backend(self):
|
||||
return self._compile_backend
|
||||
|
||||
def set_compile_backend(self, backend):
|
||||
supported_backends = torch._dynamo.list_backends(exclude_tags=())
|
||||
if backend in supported_backends:
|
||||
self._compile_backend = backend
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
|
||||
|
|
|
@ -16,6 +16,7 @@ class HPU_Accelerator(DeepSpeedAccelerator):
|
|||
def __init__(self):
|
||||
self._name = 'hpu'
|
||||
self._communication_backend_name = 'hccl'
|
||||
self._compile_backend = "hpu_backend"
|
||||
try:
|
||||
import habana_frameworks.torch.hpu as hpu
|
||||
hpu.setDeterministic(True)
|
||||
|
@ -73,13 +74,13 @@ class HPU_Accelerator(DeepSpeedAccelerator):
|
|||
return self.hpu.random.get_rng_state()
|
||||
|
||||
def manual_seed(self, seed):
|
||||
self.hpu.random.manual_seed(seed)
|
||||
return self.hpu.random.manual_seed(seed)
|
||||
|
||||
def manual_seed_all(self, seed):
|
||||
self.hpu.random.manual_seed_all(seed)
|
||||
|
||||
def initial_seed(self, seed):
|
||||
self.hpu.random.initial_seed(seed)
|
||||
def initial_seed(self):
|
||||
return self.hpu.random.initial_seed()
|
||||
|
||||
def default_generator(self, device_index):
|
||||
return self.hpu.random.default_generators[device_index]
|
||||
|
@ -294,3 +295,21 @@ class HPU_Accelerator(DeepSpeedAccelerator):
|
|||
|
||||
def export_envs(self):
|
||||
return []
|
||||
|
||||
def visible_devices_envs(self):
|
||||
return ['HABANA_VISIBLE_MODULES']
|
||||
|
||||
def set_visible_devices_envs(self, current_env, local_accelerator_ids):
|
||||
for env in self.visible_devices_envs():
|
||||
current_env[env] = ",".join(map(str, local_accelerator_ids))
|
||||
|
||||
def get_compile_backend(self):
|
||||
return self._compile_backend
|
||||
|
||||
def set_compile_backend(self, backend):
|
||||
supported_backends = torch._dynamo.list_backends(exclude_tags=())
|
||||
if backend in supported_backends:
|
||||
self._compile_backend = backend
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
|
||||
|
|
|
@ -20,6 +20,7 @@ class MPS_Accelerator(DeepSpeedAccelerator):
|
|||
def __init__(self):
|
||||
self._name = "mps"
|
||||
self._communication_backend_name = None
|
||||
self._compile_backend = "inductor"
|
||||
|
||||
def is_synchronized_device(self):
|
||||
return False
|
||||
|
@ -76,7 +77,7 @@ class MPS_Accelerator(DeepSpeedAccelerator):
|
|||
def seed(self):
|
||||
return torch.mps.seed()
|
||||
|
||||
def initial_seed(self, seed):
|
||||
def initial_seed(self):
|
||||
return
|
||||
|
||||
def default_generator(self, device_index):
|
||||
|
@ -258,3 +259,23 @@ class MPS_Accelerator(DeepSpeedAccelerator):
|
|||
|
||||
def export_envs(self):
|
||||
return []
|
||||
|
||||
# TODO: mpu's visible envs is confirmed, keep as CUDA_VISIBLE_DEVICES
|
||||
def visible_devices_envs(self):
|
||||
# TODO: could not find visible devices env for mps
|
||||
return ['CUDA_VISIBLE_DEVICES']
|
||||
|
||||
def set_visible_devices_envs(self, current_env, local_accelerator_ids):
|
||||
for env in self.visible_devices_envs():
|
||||
current_env[env] = ",".join(map(str, local_accelerator_ids))
|
||||
|
||||
def get_compile_backend(self):
|
||||
return self._compile_backend
|
||||
|
||||
def set_compile_backend(self, backend):
|
||||
supported_backends = torch._dynamo.list_backends(exclude_tags=())
|
||||
if backend in supported_backends:
|
||||
self._compile_backend = backend
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
|
||||
|
|
|
@ -20,6 +20,7 @@ class NPU_Accelerator(DeepSpeedAccelerator):
|
|||
super().__init__()
|
||||
self._name = 'npu'
|
||||
self._communication_backend_name = 'hccl'
|
||||
self._compile_backend = "inductor"
|
||||
# dict that holds class name <--> class type mapping i.e.
|
||||
# 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
|
||||
# this dict will be filled at init stage
|
||||
|
@ -83,8 +84,8 @@ class NPU_Accelerator(DeepSpeedAccelerator):
|
|||
def manual_seed_all(self, seed):
|
||||
return torch.npu.manual_seed_all(seed)
|
||||
|
||||
def initial_seed(self, seed):
|
||||
return torch.npu.initial_seed(seed)
|
||||
def initial_seed(self):
|
||||
return torch.npu.initial_seed()
|
||||
|
||||
def default_generator(self, device_index):
|
||||
return torch.npu.default_generators[device_index]
|
||||
|
@ -278,3 +279,21 @@ class NPU_Accelerator(DeepSpeedAccelerator):
|
|||
|
||||
def export_envs(self):
|
||||
return ['ASCEND', 'HCCL', 'LD_LIBRARY', 'PATH']
|
||||
|
||||
def visible_devices_envs(self):
|
||||
return ['ASCEND_RT_VISIBLE_DEVICES']
|
||||
|
||||
def set_visible_devices_envs(self, current_env, local_accelerator_ids):
|
||||
for env in self.visible_devices_envs():
|
||||
current_env[env] = ",".join(map(str, local_accelerator_ids))
|
||||
|
||||
def get_compile_backend(self):
|
||||
return self._compile_backend
|
||||
|
||||
def set_compile_backend(self, backend):
|
||||
supported_backends = torch._dynamo.list_backends(exclude_tags=())
|
||||
if backend in supported_backends:
|
||||
self._compile_backend = backend
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends }")
|
||||
|
|
|
@ -7,6 +7,10 @@ import torch
|
|||
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
|
||||
import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
|
||||
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
|
||||
import functools
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
|
||||
class XPU_Accelerator(DeepSpeedAccelerator):
|
||||
|
@ -14,7 +18,9 @@ class XPU_Accelerator(DeepSpeedAccelerator):
|
|||
def __init__(self):
|
||||
self._name = 'xpu'
|
||||
self._communication_backend_name = 'ccl'
|
||||
self._compile_backend = "inductor"
|
||||
self.aligned_tensors = []
|
||||
self.class_dict = None
|
||||
|
||||
def is_synchronized_device(self):
|
||||
return False
|
||||
|
@ -72,8 +78,8 @@ class XPU_Accelerator(DeepSpeedAccelerator):
|
|||
def manual_seed_all(self, seed):
|
||||
return torch.xpu.manual_seed_all(seed)
|
||||
|
||||
def initial_seed(self, seed):
|
||||
return torch.xpu.initial_seed(seed)
|
||||
def initial_seed(self):
|
||||
return torch.xpu.initial_seed()
|
||||
|
||||
def default_generator(self, device_index):
|
||||
return torch.xpu.default_generators[device_index]
|
||||
|
@ -157,7 +163,10 @@ class XPU_Accelerator(DeepSpeedAccelerator):
|
|||
return
|
||||
|
||||
def lazy_call(self, callback):
|
||||
return torch.xpu.lazy_init._lazy_call(callback)
|
||||
if hasattr(torch.xpu, "_lazy_call"):
|
||||
return torch.xpu._lazy_call(callback)
|
||||
else:
|
||||
return torch.xpu.lazy_init._lazy_call(callback)
|
||||
|
||||
def communication_backend_name(self):
|
||||
return self._communication_backend_name
|
||||
|
@ -190,37 +199,37 @@ class XPU_Accelerator(DeepSpeedAccelerator):
|
|||
|
||||
@property
|
||||
def BFloat16Tensor(self):
|
||||
return torch.xpu.BFloat16Tensor
|
||||
return functools.partial(torch.tensor, dtype=torch.bfloat16, device=self._name)
|
||||
|
||||
@property
|
||||
def ByteTensor(self):
|
||||
return torch.xpu.ByteTensor
|
||||
return functools.partial(torch.tensor, dtype=torch.uint8, device=self._name)
|
||||
|
||||
@property
|
||||
def DoubleTensor(self):
|
||||
return torch.xpu.DoubleTensor
|
||||
return functools.partial(torch.tensor, dtype=torch.double, device=self._name)
|
||||
|
||||
@property
|
||||
def FloatTensor(self):
|
||||
return torch.xpu.FloatTensor
|
||||
return functools.partial(torch.tensor, dtype=torch.float, device=self._name)
|
||||
|
||||
@property
|
||||
def HalfTensor(self):
|
||||
return torch.xpu.HalfTensor
|
||||
return functools.partial(torch.tensor, dtype=torch.half, device=self._name)
|
||||
|
||||
@property
|
||||
def IntTensor(self):
|
||||
return torch.xpu.IntTensor
|
||||
return functools.partial(torch.tensor, dtype=torch.int, device=self._name)
|
||||
|
||||
@property
|
||||
def LongTensor(self):
|
||||
return torch.xpu.LongTensor
|
||||
return functools.partial(torch.tensor, dtype=torch.long, device=self._name)
|
||||
|
||||
def pin_memory(self, tensor, align_bytes=1):
|
||||
if align_bytes == 1:
|
||||
return tensor.pin_memory(device=self.current_device_name())
|
||||
elif align_bytes == 0:
|
||||
from intel_extension_for_deepspeed.op_builder.async_io import AsyncIOBuilder
|
||||
from deepspeed.ops.op_builder.xpu import AsyncIOBuilder
|
||||
self.aio_handle = AsyncIOBuilder().load().aio_handle(128 * 1024, 8, False, False, False)
|
||||
aligned_t = self.aio_handle.new_cpu_locked_tensor(tensor.numel(), tensor)
|
||||
aligned_t = aligned_t[:tensor.numel()].copy_(tensor)
|
||||
|
@ -252,33 +261,29 @@ class XPU_Accelerator(DeepSpeedAccelerator):
|
|||
else:
|
||||
return False
|
||||
|
||||
def _lazy_init_class_dict(self):
|
||||
if self.class_dict:
|
||||
return
|
||||
|
||||
op_builder_module = importlib.import_module(self.op_builder_dir())
|
||||
|
||||
# get op builder class from op_builder/xpu/__init__.py
|
||||
self.class_dict = {}
|
||||
for class_name, class_obj in inspect.getmembers(op_builder_module, inspect.isclass):
|
||||
self.class_dict[class_name] = class_obj
|
||||
|
||||
# create an instance of op builder and return, name specified by class_name
|
||||
def create_op_builder(self, op_name):
|
||||
builder_class = self.get_op_builder(op_name)
|
||||
if builder_class != None:
|
||||
return builder_class()
|
||||
return None
|
||||
def create_op_builder(self, class_name):
|
||||
builder_class = self.get_op_builder(class_name)
|
||||
return builder_class()
|
||||
|
||||
# return an op builder class, name specified by class_name
|
||||
def get_op_builder(self, class_name):
|
||||
try:
|
||||
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
|
||||
# if successful this also means we're doing a local install and not JIT compile path
|
||||
from op_builder import __deepspeed__ # noqa: F401 # type: ignore
|
||||
from op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder
|
||||
except ImportError:
|
||||
from deepspeed.ops.op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder
|
||||
|
||||
if class_name == "AsyncIOBuilder":
|
||||
return AsyncIOBuilder
|
||||
elif class_name == "CPUAdagradBuilder":
|
||||
return CPUAdagradBuilder
|
||||
elif class_name == "CPUAdamBuilder":
|
||||
return CPUAdamBuilder
|
||||
elif class_name == "FusedAdamBuilder":
|
||||
return FusedAdamBuilder
|
||||
self._lazy_init_class_dict()
|
||||
if class_name in self.class_dict:
|
||||
return self.class_dict[class_name]
|
||||
else:
|
||||
return None
|
||||
return self.class_dict['NotImplementedBuilder']
|
||||
|
||||
def build_extension(self):
|
||||
try:
|
||||
|
@ -289,3 +294,21 @@ class XPU_Accelerator(DeepSpeedAccelerator):
|
|||
|
||||
def export_envs(self):
|
||||
return []
|
||||
|
||||
def visible_devices_envs(self):
|
||||
return ['ZE_AFFINITY_MASK']
|
||||
|
||||
def set_visible_devices_envs(self, current_env, local_accelerator_ids):
|
||||
for env in self.visible_devices_envs():
|
||||
current_env[env] = ",".join(map(str, local_accelerator_ids))
|
||||
|
||||
def get_compile_backend(self):
|
||||
return self._compile_backend
|
||||
|
||||
def set_compile_backend(self, backend):
|
||||
supported_backends = torch._dynamo.list_backends(exclude_tags=())
|
||||
if backend in supported_backends:
|
||||
self._compile_backend = backend
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
@echo off
|
||||
python "%~dp0\ds" %*
|
|
@ -0,0 +1,2 @@
|
|||
@echo off
|
||||
python "%~dp0\ds_report" %*
|
|
@ -0,0 +1,273 @@
|
|||
<div align="center">
|
||||
|
||||
# DeepSpeed Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training
|
||||
|
||||
</div>
|
||||
|
||||
<img src="./media/image1.png" style="width:6.5in;height:3.65625in" />
|
||||
|
||||
To cite DeepSpeed Universal Checkpoint, please cite our [arxiv report](https://arxiv.org/abs/2406.18820):
|
||||
|
||||
```
|
||||
@article{lian2024-ucp,
|
||||
title={Universal Checkpointing: Efficient and Flexible Checkpointing for
|
||||
Large Scale Distributed Training},
|
||||
author={Xinyu Lian and Sam Ade Jacobs and Lev Kurilenko and Masahiro Tanaka
|
||||
and Stas Bekman and Olatunji Ruwase and Minjia Zhang},
|
||||
journal={arxiv preprint arxiv:406.18820},
|
||||
year={2024},
|
||||
|
||||
}
|
||||
```
|
||||
|
||||
# Introduction
|
||||
|
||||
Checkpointing is a crucial technique for reducing the cost of training
|
||||
machine learning models, as it enables saving the model state during the process.
|
||||
This way, if the system fails, the training can resume from the most recent checkpoint
|
||||
instead of from the beginning. Additionally, checkpointing allows for
|
||||
evaluating the model performance at various stages of training, which
|
||||
facilitates hyperparameter tuning and finetuning for different and
|
||||
varied downstream tasks.
|
||||
|
||||
However, there are challenges in the design, implementation and usage of
|
||||
checkpointing especially in distributed training and finetuning
|
||||
scenarios. Parallel training methods such as ZeRO data parallelism (ZeRO-DP),
|
||||
pipeline parallelism (PP), tensor parallelism (TP) and sequence
|
||||
parallelism (SP) are popular technologies for accelerating LLMs training.
|
||||
However, elastic and flexible composition of these different parallelism
|
||||
topologies with checkpointing is not currently available, in part, because
|
||||
these techniques shard model and/or optimizer states making it difficult to
|
||||
resume training with a checkpoint that was created on a different number of GPUs or
|
||||
accelerators.
|
||||
|
||||
In this release, we are excited to introduce DeepSpeed Universal
|
||||
Checkpointing (*UCP*), a most comprehensive solution to the problem of
|
||||
distributed checkpointing. *UCP* enables efficient checkpoint creation
|
||||
while providing the flexibility of resuming on arbitrary parallelism
|
||||
strategies and hardware configurations. *UCP* also unlocks unprecedented
|
||||
capabilities for large-scale training such as improved resilience to
|
||||
hardware failures through continued training on remaining healthy
|
||||
hardware, and reduced training time through opportunistic exploitation
|
||||
of elastic capacity.
|
||||
|
||||
In summary, this release of *UCP* unlocks the following capabilities:
|
||||
|
||||
- Flexible checkpoints reshape along any of the training parallelism
|
||||
techniques (i.e., PP, TP, DP, ZeRO-DP, SP, MoE)
|
||||
|
||||
- Elastic resource management, scale up or scale down of training and
|
||||
finetuning accelerator resources
|
||||
|
||||
- Real world examples with support for multiple commercial-scale models
|
||||
(i.e., BLOOM, Megatron GPT, LLAMA, Microsoft Phi)
|
||||
|
||||
# Core Design
|
||||
|
||||
The key insight of DeepSpeed *UCP* is the selection of the optimal
|
||||
representation in each phase of the checkpointing life cycle:
|
||||
distributed representation for saving, and consolidated representation
|
||||
for loading. This is achieved using two key mechanisms. First, the
|
||||
universal checkpoint format, which consists of a consolidated
|
||||
representation of each model parameter, and metadata for mapping
|
||||
parameter fragments to the ranks of an arbitrary parallel training
|
||||
configuration. Second, the universal checkpoint language, a simple but
|
||||
powerful and robust specification language for converting distributed
|
||||
checkpoints into the universal checkpoint format.
|
||||
|
||||
## Universal Checkpoint Format
|
||||
|
||||
<img src="./media/image2.png" style="width:6.5in;height:3.42153in" />
|
||||
|
||||
Figure 1: UCP overview: top row and bottom row are Source and Target
|
||||
parallelism configurations respectively. The middle row shows UCP as
|
||||
an intermediate format of translating from Source to Target.
|
||||
|
||||
Figure 1 shows high level schematic description of *UCP* conversion
|
||||
process and format. Conversion starts with top block of checkpointing in
|
||||
any parallel format e.g, DP, TP, PP, SP. Saving in the native format of parallel training avoids any overhead of
|
||||
consolidating into a single global checkpoint. To ensure that
|
||||
a checkpoint saved in one parallel configuration (herein called *Source*) can be
|
||||
easily converted and loaded for continuous training in another parallel configuration (herein called *Target*),
|
||||
we introduce the idea of atomic checkpoint as an intermediate format.
|
||||
|
||||
The concept of atomic checkpoint is central to *UCP*. These are
|
||||
fine-grained files containing the consolidated representation of each
|
||||
model parameter, along with optimizer states. The atomic checkpoint
|
||||
format is useful for three reasons. First, the atomic representation of
|
||||
checkpoints decouples the dependencies of distributed checkpoints and
|
||||
specific parallelism techniques and hardware configurations. As such,
|
||||
one does not need to implement individual converters for each *Source*
|
||||
and *Target* pair. Instead, *UCP* can act as a common interchange format
|
||||
between different distributed training techniques, which then can be
|
||||
easily transformed into other distributed training strategies, as shown
|
||||
in Fig 2. By keeping the consolidated representation of each model
|
||||
parameter, *UCP* enables easy splitting and flexible mapping of model states
|
||||
or fragmented states to different GPUs on a parameter-by-parameter
|
||||
basis, effectively reducing the working memory needed to load large
|
||||
model checkpoints. Second, the *UCP* conversion happens lazily and
|
||||
on-demand, e.g., when a training process detects a change of parallelism
|
||||
technique and hardware configuration. In other words, the existing
|
||||
distributed checkpoint saving logic does not need any change. Third, the
|
||||
structure of the *UCP* also makes it easy to handle advanced techniques
|
||||
in distributed training, such as mixed-precision training. In practice,
|
||||
researchers and practitioners may switch between fp16 and bfloat16 mixed
|
||||
precision training. By keeping the fp32 weight/optimizer values, the
|
||||
training can resume either with fp16 or bfloat16.
|
||||
|
||||
## Universal Checkpoint Language
|
||||
|
||||
<img src="./media/flowchart.png" style="width:6.5in;height:2.22222in" />
|
||||
|
||||
Figure 2: UCP language helps transform distributed checkpoints into the
|
||||
UCP format and load UCP checkpoints based on the Target parallel
|
||||
technique and new hardware configuration.
|
||||
|
||||
|
||||
While *UCP* provides a common interface for different parallelism
|
||||
strategies, the development of transformation from arbitrary distributed
|
||||
checkpoints to *UCP* can still incur a high engineering and
|
||||
implementation cost. This is because the number of distributed checkpoint files
|
||||
and their contents can vary across the different parallel training techniques.
|
||||
|
||||
To tackle this challenge, *UCP* provides *UCP* language, which is a
|
||||
simple but powerful specification language for converting a distributed checkpoint
|
||||
into the atomic checkpoint format, described in previous
|
||||
section. *UCP* does this in two ways. First, it provides a declarative
|
||||
system with pre-defined *parameter patterns*, which cover a wide range
|
||||
of parallelism strategies for model states. Parameter patterns contain
|
||||
runtime information about how a parameter is partitioned across GPUs.
|
||||
For instance, *nopattern* means that a parameter is uniquely associated
|
||||
with a GPU rank, which is the most common pattern seen in techniques
|
||||
such as ZeRO-1/2 and PP (see our technical report for a completed list
|
||||
of currently supported parameter patterns). Second, *UCP* language
|
||||
provides a set of common operators that facilitate the transformation of
|
||||
distributed checkpoints into consolidated atomic checkpoints. At a
|
||||
high-level, as illustrated in Figure 3, *UCP* language is invoked when
|
||||
support for a new *Target* is needed or the hardware
|
||||
configuration changes. It first transforms distributed checkpoints into
|
||||
the *UCP* format. It then loads the *UCP* checkpoints based on the
|
||||
*Target* parallel technique and new hardware configuration.
|
||||
|
||||
# Key Results
|
||||
|
||||
We evaluate *UCP* through a series of experiments on training LLMs. We
|
||||
focus on the decoder-only Transformers: an architecture chosen due to
|
||||
its state-of-the-art performance. Some of the largest models are also
|
||||
decoder-based, making flexible and efficient checkpointing especially
|
||||
important. In this blog, we present results of correctness verification
|
||||
across different models and parallel strategies. For more results on
|
||||
parallel efficiency analysis, detailed system and model architectures
|
||||
and training hyperparameters, please see our technical report referenced
|
||||
above.
|
||||
|
||||
*UCP* provides flexible checkpointing from a *Source* parallelism
|
||||
strategy to a different *Target* with different hardware configurations.
|
||||
To verify this capability, we conduct correctness tests of *UCP* with
|
||||
two groups of experiments.
|
||||
|
||||
## Single Source to Multiple Target
|
||||
|
||||
<img src="./media/image4.png" style="width:4.85477in;height:4in" />
|
||||
|
||||
Figure 3: Training curves of loading UCP checkpoints into different
|
||||
Target at iteration 101 with various GPU counts and parallelism
|
||||
strategies
|
||||
|
||||
To test if UCP allows resuming training with different parallelism
|
||||
strategies and hardware configuration, we first train the GPT-3 model
|
||||
using a configuration of TP=2, PP=2, DP=2 (ZeRO-1), and SP=1. Due to
|
||||
constraints in time and resources, we limited the experiment to the
|
||||
first 200 iterations. We convert the checkpoints saved at the 100th
|
||||
iteration to *UCP* checkpoints and resume training with these *UCP*
|
||||
checkpoints using different GPU counts and parallelism strategies. We
|
||||
record the LM loss (average losses across the data parallel group) for
|
||||
each iteration. Figure 3 illustrates that the training can be seamlessly
|
||||
resumed with *UCP* checkpoints using different *Target* parallelism
|
||||
strategies, achieving consistent convergence if the training were to
|
||||
continue with the *Source* strategy.
|
||||
|
||||
## Multiple Source to Single Target
|
||||
|
||||
<img src="./media/image5.png" style="width:4.85477in;height:4in" />
|
||||
|
||||
Figure 4: Training curves of transforming different Source parallelism
|
||||
strategies at iteration 100 to UCP and loading UCP with a different
|
||||
Target.
|
||||
|
||||
Figure 4 shows the training curves from multiple *Source* configurations
|
||||
to a single *Target*. Given a fixed random seed, we first train the
|
||||
GPT-3 model using different *Source* configurations. We then convert
|
||||
their distributed checkpoints saved at the 100th iteration to *UCP*
|
||||
checkpoints and resume training with a configuration of TP=2, PP=2,
|
||||
DP=1, and SP=1. The results show that regardless of the different
|
||||
*Source* configurations, their checkpoints can all be converted into
|
||||
*UCP* and resume training with a different configuration. Most
|
||||
importantly, the resumed training curves match the curves from the
|
||||
*Source* at iterations 101--200. These results validate the
|
||||
effectiveness of *UCP* of converting an arbitrary configuration to a
|
||||
different configuration for resumed training.
|
||||
|
||||
## Varying Model Architectures
|
||||
|
||||
*UCP* is model architecture agnostic. As such, it is not only compatible
|
||||
with GPT models but also flexible enough to support various other model
|
||||
architectures and sizes. Figures 5, 6 and 7 show the training
|
||||
convergence for LLaMA 7B, BLOOM 176B, and a variant of Mixtral-7x8B MoE,
|
||||
when resuming from *UCP* at the middle of training with new parallelism
|
||||
strategies. These figures show that training is seamlessly resumed with
|
||||
*UCP*, achieving consistent convergence that aligns with the initial
|
||||
training phase across these diverse models. These results suggest that
|
||||
*UCP* is quite flexible for various model architectures and sizes.
|
||||
|
||||
<img src="./media/image6.png" style="width:5in;height:4in"
|
||||
alt="A graph of training step Description automatically generated" />
|
||||
|
||||
Figure 5: Training curve with LLaMA model architecture. Source is
|
||||
TP=PP=DP=2. Training is resumed at iteration 101 with new Targets
|
||||
TP=DP=2, PP=1 and TP=PP=2, DP=1
|
||||
|
||||
<img src="./media/image7.png" style="width:5in;height:4in"
|
||||
alt="A graph with numbers and lines Description automatically generated" />
|
||||
|
||||
Figure 6: Training curve of BLOOM model architecture. Source is TP=2,
|
||||
PP=24, DP=8. Training is resumed at iteration 94767 with a new Targets
|
||||
TP=2, DP=4, PP=24.
|
||||
|
||||
<img src="./media/image8.png" style="width:5in;height:4in"
|
||||
alt="A graph of training step Description automatically generated" />
|
||||
|
||||
Figure 7: Training curve with a variant of the Mixtral-MoE model
|
||||
architecture. Source is TP=1, PP=2, DP=4. Training is resumed at
|
||||
iteration 501 with a new Target TP=PP=DP=2.
|
||||
|
||||
# General Availability of DeepSpeed Universal Checkpoint
|
||||
|
||||
We are excited to release DeepSpeed Universal Checkpoint. DeepSpeed
|
||||
Universal Checkpoint is available in DeepSpeed versions >=
|
||||
[0.14.4](https://github.com/microsoft/DeepSpeed/releases/tag/v0.14.4),
|
||||
has been fully integrated with [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed) ([commit c3a13be](https://github.com/microsoft/Megatron-DeepSpeed/commit/c3a13be721da0d0de16c338d0d665b0f7d13d14f)).
|
||||
Detailed tutorial on usage is available on
|
||||
[DeepSpeed tutorial page](https://www.deepspeed.ai/tutorials/universal-checkpointing/).
|
||||
|
||||
We welcome contributions and collaboration from the broader open-source
|
||||
community. DeepSpeed Universal Checkpoint is part of the bigger
|
||||
DeepSpeed ecosystem of large-scale AI training and inference. For more
|
||||
details on all DeepSpeed technologies and innovations, please visit our
|
||||
[website]((https://www.deepspeed.ai/)) and follow us
|
||||
on X, formerly Twitter, ([English](https://twitter.com/MSFTDeepSpeed),
|
||||
[Japanese](https://twitter.com/MSFTDeepSpeedJP))
|
||||
and [Chinese Zhihu](https://www.zhihu.com/people/deepspeed).
|
||||
|
||||
# Acknowledgements and Contributions
|
||||
We thank the collaboration of University of Illinois at Urbana-Champaign,
|
||||
Statosphere, and Intel Habana.
|
||||
|
||||
Contributions:
|
||||
Xinyu Lian $^1$, Sam Ade Jacobs $^2$, Lev Kurilenko $^2$, Masahiro Tanaka $^2$,
|
||||
Stas Bekman $^3$, Olatunji Ruwase $^2$, Minjia Zhang $^1$, Moshe Island $^4$
|
||||
|
||||
1: University of Illinois at Urbana-Champaign
|
||||
2: Microsoft
|
||||
3: StasoSphere
|
||||
4: Intel Habana
|
|
@ -0,0 +1,124 @@
|
|||
|
||||
<div align="center">
|
||||
|
||||
# DeepSpeed通用检查点:用于大规模分布式训练的高效灵活检查点系统
|
||||
|
||||
</div>
|
||||
|
||||
<img src="../media/image1.png" style="width:6.5in;height:3.65625in" />
|
||||
|
||||
要引用DeepSpeed通用检查点,请引用我们的[arxiv报告](https://arxiv.org/abs/2406.18820):
|
||||
|
||||
```
|
||||
@article{lian2024-ucp,
|
||||
title={Universal Checkpointing: Efficient and Flexible Checkpointing for
|
||||
Large Scale Distributed Training},
|
||||
author={Xinyu Lian and Sam Ade Jacobs and Lev Kurilenko and Masahiro Tanaka
|
||||
and Stas Bekman and Olatunji Ruwase and Minjia Zhang},
|
||||
journal={arxiv preprint arxiv:406.18820},
|
||||
year={2024},
|
||||
|
||||
}
|
||||
```
|
||||
|
||||
# 引言
|
||||
|
||||
检查点是降低训练大型语言模型成本的关键技术,它使我们在训练过程中可以保存模型状态。这样,如果训练失败,训练可以从最后保存的点继续,而不是从头开始。此外,检查点还允许在训练的不同阶段评估模型性能,从而便于进行超参数调整以及针对不同和多样化下游任务的微调。
|
||||
|
||||
然而,在分布式训练和微调场景中设计、实施和使用检查点存在困难。ZeRO数据并行(ZeRO-DP)、流水线并行(PP)、张量并行(TP)和序列并行(SP)等方法是加速大型语言模型训练的出色技术,但与传统的默认(Torch)保存和加载检查点机制不兼容。此外,目前尚无技术支持将这些不同的并行拓扑与检查点灵活组合,部分原因是这些技术将模型和/或优化器状态分片,使得在不同GPU或加速器数量上创建的检查点难以用于恢复训练。
|
||||
|
||||
在此,我们很高兴地发布DeepSpeed通用检查点(*UCP*),这是解决分布式检查点问题的最全面的解决方案。*UCP*在高效创建检查点的同时,提供了在任意并行策略和硬件配置上恢复的灵活性。*UCP*还解锁了大规模训练的前所未有的能力,例如通过在剩余健康硬件上继续训练来提高对硬件故障的抵抗力,以及通过机会性利用弹性容量来减少训练时间。
|
||||
|
||||
简单来说,当前版本的*UCP*解锁了以下功能:
|
||||
|
||||
- 灵活的检查点可沿任何训练并行技术(即PP、TP、DP、ZeRO-DP、SP、MoE)重塑训练
|
||||
|
||||
- 弹性资源管理,在训练和微调中随意增加或减少硬件资源
|
||||
|
||||
- 支持多种商业规模模型的真实世界用例(例如BLOOM、Megatron GPT、LLAMA、Microsoft Phi)
|
||||
|
||||
# 核心设计
|
||||
|
||||
DeepSpeed *UCP*的关键洞察是在检查点生命周期的每个阶段选择最佳表示:分布式表示用于保存,合并表示用于加载。这通过两个关键机制实现。首先,通用检查点格式,它包括每个模型参数的合并表示和用于将参数片段映射到任意模型并行配置的训练级别的元数据。其次,通用检查点语言,这是一个简单但强大且健壮的规范语言,用于将分布式检查点转换为通用检查点格式。
|
||||
|
||||
## 通用检查点格式
|
||||
|
||||
<img src="../media/image2.png" style="width:6.5in;height:3.42153in" />
|
||||
|
||||
图1:UCP概述:顶部行和底部行分别为源并行配置和目标并行配置。中间行显示UCP作为从源到目标的转换中介块。
|
||||
|
||||
图1显示了*UCP*转换过程和格式的整体概念性描述。转换从任何并行策略格式的检查点顶部块开始。允许以训练的本地格式保存消除了可能因同步全局检查点保存而产生的任何开销。为确保保存的检查点(称为*源*)可以轻松转换并加载到任何并行策略以进行连续训练(称为*目标*),我们引入了作为中介块的原子检查点格式的概念。
|
||||
|
||||
原子检查点是*UCP*的核心概念。这些是包含每个模型参数的合并表示及其优化器状态的细粒度文件。原子检查点格式有三个用途。首先,原子检查点的表示解除了分布式检查点与特定并行技术和硬件配置的依赖。因此,无需为每个*源*到*目标*实现单独的转换器。相反,*UCP*可以充当不同分布式训练技术之间的通用交换格式,然后可以轻松地转换为其他分布式训练策略,如图2所示。通过保持每个模型参数的合并表示,*UCP*可以轻松地将模型状态或片段状态拆分并灵活地映射到不同GPU上,有效减少加载大型模型检查点所需的工作内存。其次,*UCP*转换是懒惰和按需进行的,例如,当训练过程检测到并行技术和硬件配置的变化时。换句话说,现有的分布式检查点保存逻辑不需要任何改变。第三,*UCP*的结构还易于处理分布式训练中的高级技术,例如混合精度训练。在实践中,研究人员和从业者可能在fp16和bfloat16混合精度训练之间切换。通过保持fp32的权重/优化器值,训练可以继续使用fp16或bfloat16恢复。
|
||||
|
||||
## 通用检查点语言
|
||||
|
||||
<img src="../media/flowchart.png" style="width:6.5in;height:2.22222in" />
|
||||
|
||||
图2:UCP语言帮助将分布式检查点转换为UCP格式,并根据目标并行技术和新硬件配置加载UCP检查点。
|
||||
|
||||
|
||||
虽然*UCP*为不同的并行策略提供了一个公共接口,但从任意分布式检查点到*UCP*的转换仍然可能具有不菲的工程和实施成本。这是因为分布式训练中的每个GPU都调用一个持久方法(例如,在PyTorch中使用torch.save())将其拥有的GPU模型状态保存到磁盘上的检查点文件中,而每个检查点的具体内容在不同技术之间会有所不同。
|
||||
|
||||
为了应对这一挑战,*UCP*提供了*UCP*语言,这是一个简单但强大的规范语言,用于将几种类型的分布式检查点转换为前一节中描述的通用格式。*UCP*以两种方式实现这一点。首先,它提供了一个具有预定义*参数模式*的声明式系统,这些模式涵盖了模型状态的广泛并行
|
||||
|
||||
策略。参数模式包含有关参数如何在GPU之间分区的运行时信息。例如,*nopattern*表示一个参数与某个GPU唯一相关,这是ZeRO-1/2和PP等技术中最常见的模式(参见我们的技术报告,以获得当前支持的参数模式完整列表)。其次,*UCP*语言提供了一组常见操作符,以便将分布式检查点转换为合并的原子检查点。从高层次来看,如图3所示,当需要新的*目标*并行技术或硬件配置发生变化时,将调用*UCP*语言。它首先将分布式检查点转换为*UCP*格式。然后根据*目标*并行技术和新硬件配置加载*UCP*检查点。
|
||||
|
||||
# 关键结果
|
||||
|
||||
我们通过一系列实验评估*UCP*,专注于仅解码器的Transformers架构,这是由于其最先进的性能。一些最大的模型也是基于解码器的,这使得灵活高效的检查点尤为重要。在本博客中,我们展示了在不同模型和并行策略下正确性验证的结果。有关并行效率分析、详细的系统和模型架构以及训练超参数的更多结果,请参阅上面引用的技术报告。
|
||||
|
||||
*UCP*提供了从一个*源*并行策略到不同的*目标*和不同硬件配置的灵活检查点。为验证这一能力,我们进行了正确性测试的两组实验。
|
||||
|
||||
## 单源到多目标
|
||||
|
||||
<img src="../media/image4.png" style="width:4.85477in;height:4in" />
|
||||
|
||||
图3:在第101次迭代时使用不同目标加载UCP检查点的训练曲线,具有不同GPU数量和并行策略
|
||||
|
||||
为测试UCP是否允许使用不同并行策略和硬件配置恢复训练,我们首先使用TP=2、PP=2、DP=2(ZeRO-1)和SP=1的配置训练GPT-3模型。由于时间和资源的限制,我们将实验限制在前200次迭代。我们将在第100次迭代保存的检查点转换为*UCP*检查点,并使用不同GPU数量和并行策略恢复训练。我们记录了每次迭代的LM损失(数据并行组的平均损失)。图3显示,训练可以使用不同的*目标*并行策略无缝地使用*UCP*检查点恢复,如果训练继续使用*源*策略,将实现一致的收敛。
|
||||
|
||||
## 多源到单目标
|
||||
|
||||
<img src="../media/image5.png" style="width:4.85477in;height:4in" />
|
||||
|
||||
图4:在第100次迭代将不同源并行策略转换为UCP并加载UCP的训练曲线,具有不同的目标。
|
||||
|
||||
图4显示了从多个*源*配置到单一*目标*的训练曲线。在固定随机种子的情况下,我们首先使用不同的*源*配置训练GPT-3模型。然后我们将它们在第100次迭代保存的分布式检查点转换为*UCP*检查点,并使用TP=2、PP=2、DP=1和SP=1的配置恢复训练。结果显示,无论不同的*源*配置如何,它们的检查点都可以转换为*UCP*并使用不同的配置恢复训练。最重要的是,恢复的训练曲线与第101--200次迭代的*源*曲线匹配。这些结果验证了*UCP*将任意配置转换为不同配置以恢复训练的有效性。
|
||||
|
||||
## 不同模型架构的变化
|
||||
|
||||
*UCP*与模型架构无关。因此,它不仅与GPT模型兼容,而且足够灵活,可以支持各种其他模型架构和大小。图5、6和7显示了使用新并行策略从*UCP*中恢复训练时的训练收敛情况。这些图表显示,训练可以使用*UCP*无缝恢复,实现与初始训练阶段一致的收敛,这与这些不同模型相符。这些结果表明,*UCP*对于各种模型架构和大小都非常灵活。
|
||||
|
||||
<img src="../media/image6.png" style="width:5in;height:4in"
|
||||
alt="A graph of training step Description automatically generated" />
|
||||
|
||||
图5:使用LLaMA模型架构的训练曲线。源是TP=PP=DP=2。训练在第101次迭代时使用新目标TP=DP=2, PP=1和TP=PP=2, DP=1恢复
|
||||
|
||||
<img src="../media/image7.png" style="width:5in;height:4in"
|
||||
alt="A graph with numbers and lines Description automatically generated" />
|
||||
|
||||
图6:使用BLOOM模型架构的训练曲线。源是TP=2, PP=24, DP=8。训练在第94767次迭代时使用新目标TP=2, DP=4, PP=24恢复。
|
||||
|
||||
<img src="../media/image8.png" style="width:5in;height:4in"
|
||||
alt="A graph of training step Description automatically generated" />
|
||||
|
||||
图7:使用Mixtral-MoE模型架构变种的训练曲线。源是TP=1, PP=2, DP=4。训练在第501次迭代时使用新目标TP=PP=DP=2恢复。
|
||||
|
||||
# DeepSpeed通用检查点的普遍可用性
|
||||
|
||||
我们很高兴发布DeepSpeed通用检查点。DeepSpeed通用检查点已与Megatron-DeepSpeed的重构版本完全集成,并可通过DeepSpeed和Megatron-DeepSpeed的GitHub仓库访问。详细的使用教程可在[DeepSpeed教程页面](https://www.deepspeed.ai/tutorials/universal-checkpointing/)上找到。
|
||||
|
||||
我们欢迎来自更广泛开源社区的贡献和合作。DeepSpeed通用检查点是大规模AI训练和推理DeepSpeed生态系统的一部分。有关所有DeepSpeed技术和创新的更多详细信息,请访问我们的[网站](https://www.deepspeed.ai/)并在X(前Twitter)([英文](https://twitter.com/MSFTDeepSpeed),[日文](https://twitter.com/MSFTDeepSpeedJP))和[中文知乎](https://www.zhihu.com/people/deepspeed)上关注我们。
|
||||
|
||||
# 致谢和贡献
|
||||
我们感谢伊利诺伊大学厄巴纳-香槟分校、Statosphere和英特尔Habana的合作。
|
||||
|
||||
贡献者:
|
||||
Xinyu Lian $^1$, Sam Ade Jacobs $^2$, Lev Kurilenko $^2$, Masahiro Tanaka $^2$,
|
||||
Stas Bekman $^3$, Olatunji Ruwase $^2$, Minjia Zhang $^1$, Moshe Island $^4$
|
||||
|
||||
1: 伊利诺伊大学厄巴纳-香槟分校
|
||||
2: 微软
|
||||
3: Statosphere
|
||||
4: 英特尔Habana
|
|
@ -0,0 +1,115 @@
|
|||
<div align="center">
|
||||
|
||||
# DeepSpeed Universal Checkpointing: 大規模分散学習のための効率的かつ柔軟なチェックポイント
|
||||
|
||||
</div>
|
||||
|
||||
<img src="../media/image1.png" style="width:6.5in;height:3.65625in" />
|
||||
|
||||
DeepSpeed Universal Checkpointを引用する際は、こちらの[arxiv report](https://arxiv.org/abs/2406.18820)を参照してください。
|
||||
|
||||
```
|
||||
@article{lian2024-ucp,
|
||||
title={Universal Checkpointing: Efficient and Flexible Checkpointing for
|
||||
Large Scale Distributed Training},
|
||||
author={Xinyu Lian and Sam Ade Jacobs and Masahiro Tanaka and Lev
|
||||
Kurilenko and Stas Bekman and Olatunji Ruwase and Minjia Zhang},
|
||||
journal={arxiv preprint arxiv:406.18820},
|
||||
year={2024},
|
||||
|
||||
}
|
||||
```
|
||||
|
||||
# はじめに
|
||||
|
||||
モデルの状態を保存するをチェックポイントは、システム障害が発生した場合に途中から学習を再開するために、LLMのトレーニングコストを削減するための重要な技術です。さらに、学習のさまざまな段階でモデルのパフォーマンスを評価することができるため、ハイパーパラメータの調整や異なる下流タスクのためのファインチューニングが容易になります。
|
||||
|
||||
しかし、特に分散学習やファインチューニングのシナリオにおいて、チェックポイントの設計、実装、および使用には多くの課題があります。DeepSpeedが備えるZeROを用いたデータ並列化(ZeRO-DP)、パイプライン並列化(PP)、テンソル並列化(TP)、およびシーケンス並列化(SP)などのいくつかの方法は、LLM学習を加速するための優れた技術ですが、一般的なチェックポイント保存と読み込みのメカニズムと互換性がありません。さらに、これらの異なる並列化を用いたエラスティックで柔軟な組み合わせは、現在サポートされていません。主な理由の一つは、こうした並列化技術がモデルおよび/またはオプティマイザの状態を分割するため、異なるGPUまたはアクセラレータの数に基づいて作成されたチェックポイントから学習を再開することが困難であるためです。
|
||||
|
||||
このリリースでは、分散チェックポイントの問題に対する包括的なソリューションであるDeepSpeed Universal Checkpointing (*UCP*) を紹介します。*UCP*は、任意の並列化戦略とハードウェア構成で再開する柔軟性を提供しながら、効率的なチェックポイント作成を可能にします。また、*UCP*は、ハードウェア障害の際にも、残りの正常なハードウェアでのトレーニングの継続を可能にするため、キャパシティがエラスティックに変化するハードウェアを活用でき、トレーニング時間を短縮するなど、大規模学習を最大限に効率化できます。
|
||||
|
||||
現在のリリースには、*UCP*の次の機能が含まれます。
|
||||
|
||||
- 任意のトレーニング並列技術(例:PP、TP、DP、ZeRO-DP、SP、MoE)に沿った柔軟なチェックポイントの再構成
|
||||
- ファインチューニングを含む学習およびアクセラレータリソースのエラスティックなリソース管理、スケールアップまたはスケールダウン
|
||||
- BLOOM、Megatron GPT、LLAMA、Microsoft Phiなどの複数の商用規模モデルのサポートを伴う実利用例
|
||||
|
||||
# UCPの設計
|
||||
|
||||
DeepSpeed *UCP*における中心的な考え方は、チェックポイントライフサイクルの各段階で最適な表現を選択することです。保存のための分散表現と、読み込みのための統合表現です。これは、2つの重要なメカニズムを使用して実現されます。一つ目は、各モデルパラメータの統合表現と、パラメータのフラグメントを任意のモデル並列化構成におけるランク(プロセスのインデックス)にマッピングするためのメタデータからなるユニバーサルチェックポイントフォーマットです。二つ目は、分散チェックポイントをユニバーサルチェックポイント形式に変換するためのシンプルで強力かつ堅牢な仕様言語であるユニバーサルチェックポイント言語です。
|
||||
|
||||
## ユニバーサルチェックポイントフォーマット
|
||||
|
||||
<img src="../media/image2.png" style="width:6.5in;height:3.42153in" />
|
||||
|
||||
図1:*UCP*の概要:上段と下段はそれぞれソースとターゲットの並列化構成です。中央の段は、ソースからターゲットへの翻訳の仲介ブロックとしての*UCP*を示しています。
|
||||
|
||||
図1は、*UCP*の変換プロセスとフォーマットの抽象レベルの概略図を示しています。変換は、DP、TP、PP、SPなどの任意の並列戦略形式のチェックポイントから始まります。訓練結果のモデルやオプティマイザ状態をネイティブ形式で保存することで、同期されたグローバルチェックポイントの保存に伴うオーバーヘッドを回避します。保存されたチェックポイント(以下、*ソース*と呼びます)を任意の並列戦略に簡単に変換してロードできるようにするために、中間ブロックとして原子チェックポイント (atomic checkpoint) 形式のアイデアを導入します。
|
||||
|
||||
原子チェックポイントの概念は、*UCP*の中心となるものです。これらは、各モデルパラメータの統合表現とオプティマイザ状態を含む細粒度のファイルです。原子チェックポイント形式は、次の3つの理由で有用です。まず、チェックポイントの原子表現は、分散チェックポイントと特定の並列技術およびハードウェア構成の依存関係を切り離します。そのため、*ソース*から*ターゲット*への個別のコンバータを実装する必要はありません。代わりに、*UCP*は異なる分散トレーニング技術間の共通交換形式として機能し、他の分散トレーニング戦略に簡単に変換できます(図2参照)。各モデルパラメータの統合表現を保持することで、*UCP*はモデル状態またはフラグメント状態をパラメータごとに異なるGPUに柔軟にマッピングし、大規模モデルチェックポイントを読み込むために必要な作業メモリを効果的に削減します。第二に、*UCP*の変換は遅延してオンデマンドで行われます。たとえば、トレーニングプロセスが並列技術とハードウェア構成の変更を検出したときです。つまり、既存の分散チェックポイント保存ロジックには変更が必要ありません。第三に、*UCP*の構造により、混合精度トレーニングなどの高度な技術を分散トレーニングで簡単に処理できます。実際には、研究者や実務者はfp16とbfloat16の混合精度トレーニングを切り替えることがあります。fp32の重み/オプティマイザの値を保持することで、トレーニングはfp16またはbfloat16のいずれかで再開できます。
|
||||
|
||||
## ユニバーサルチェックポイント言語
|
||||
|
||||
<img src="../media/flowchart.png" style="width:6.5in;height:2.22222in" />
|
||||
|
||||
図2:*UCP*言語は、分散チェックポイントを*UCP*形式に変換し、新しいハードウェア構成とターゲットの並列技術に基づいて*UCP*チェックポイントを読み込みます。
|
||||
|
||||
*UCP*は異なる並列戦略に対する共通インターフェースを提供しますが、任意の分散チェックポイントから*UCP*への変換の開発には依然として高いエンジニアリングおよび実装コストがかかる場合があります。これは、分散トレーニングの各GPUが保存のためのメソッド(例:PyTorchのtorch.save())を呼び出して、所有するGPUモデル状態のチェックポイントファイルをディスクに保存し、各チェックポイントの正確な内容が異なる技術によって異なるためです。
|
||||
|
||||
この課題に取り組むために、*UCP*は*UCP*言語を提供します。これは、前述の共通形式にいくつかの種類の分散チェックポイントを変換するためのシンプルで強力な仕様言語です。*UCP*はこれを2つの方法で行います。まず、モデル状態の並列戦略の広範な範囲をカバーする事前定義された*パラメータパターン*を持つ宣言型システムを提供します。パラメータパターンには、パラメータがGPU間でどのように分割されているかについてのランタイム情報が含まれています。たとえば、*nopattern*は、パラメータがGPUランクに一意に関連付けられていることを意味し、これはZeRO-1/2やPPなどの技術で最も一般的に見られるパターンです(現在サポートされているパラメータパターンの完全なリストについては、技術レポートを参照してください)。第二に、*UCP*言語は、分散チェックポイントを統合された原子チェックポイントに変換するための一般的な演算子のセットを提供します。抽象的なレベルで見ると、図2に示すように、ターゲットへの移行後に新しい並列技術が必要な場合やハードウェア構成が変更された場合に、*UCP*言語が使用されます。最初に、分散チェックポイントを*UCP*形式に変換し、次にターゲットの並列技術と新しいハードウェア構成に基づいて*UCP*チェックポイントを読み込みます。
|
||||
|
||||
# 主要な結果
|
||||
|
||||
我々は、LLMの訓練に関する一連の実験を通じて*UCP*を評価します。デコーダーのみのトランスフォーマーに焦点を当てました。これは最先端のパフォーマンスを持つアーキテクチャです。いくつかの最大のモデルもデコーダーベースであるため、柔軟で効率的なチェックポイントは特に重要です。このブログでは、さまざまなモデルと並列戦略にわたる正確性の検証結果を紹介します。並列効率分析、詳細なシステムおよびモデルアーキテクチャ、および訓練のハイパーパラメータに関する詳細な結果については、上記の技術レポートを参照してください。
|
||||
|
||||
*UCP*は、異なるハードウェア構成を持つ異なる*ターゲットの*並列戦略に対する*ソース*並列戦略からの柔軟なチェックポイントを提供します。この能力を検証するために、2つの実験グループで*UCP*の正確さを確認しました。
|
||||
|
||||
## シングルソースから複数のターゲットへ
|
||||
|
||||
<img src="../media/image4.png" style="width:4.85477in;height:4in" />
|
||||
|
||||
図3:さまざまなGPU数と並列戦略で*ターゲット*に*UCP*チェックポイントをロードする訓練lossの曲線(イテレーション100で保存・ロード)
|
||||
|
||||
*UCP*が異なる並列戦略とハードウェア構成での訓練再開を可能にするかどうかをテストするために、まずTP=2、PP=2、DP=2(ZeRO-1)、SP=1の構成でGPT-3モデルを訓練します。時間とリソースの制約のため、この実験は最初の200イテレーションに限定しました。100イテレーション目で保存されたチェックポイントを*UCP*チェックポイントに変換し、異なるGPU数と並列戦略を使用してこれらの*UCP*チェックポイントで訓練を再開します。各イテレーションのLM損失(データ並列グループ全体の平均損失)を記録しました。図3は、異なる*ターゲット*並列戦略を使用して*UCP*チェックポイントで訓練をシームレスに再開し、*ソース*戦略を継続して訓練する場合と一致する収束を達成することを示しています。
|
||||
|
||||
## 複数ソースからシングルターゲットへ
|
||||
|
||||
<img src="../media/image5.png" style="width:4.85477in;height:4in" />
|
||||
|
||||
図4:100イテレーション目で異なるソース並列戦略を*UCP*に変換し、異なるターゲットで*UCP*をロードする訓練lossの曲線
|
||||
|
||||
図4は、複数の*ソース*構成から単一の*ターゲット*へのlossの曲線を示しています。固定されたランダムシードを使用して、まずGPT-3モデルを異なる*ソース*構成で訓練します。次に、100イテレーション目で保存された分散チェックポイントを*UCP*チェックポイントに変換し、TP=2、PP=2、DP=1、SP=1の構成でトレーニングを再開します。結果は、異なる*ソース*構成にもかかわらず、そのチェックポイントはすべて*UCP*に変換され、異なる構成で訓練を再開できることを示しています。最も重要なのは、再開されたlossの曲線が、イテレーション101~200での*ソース*の曲線と一致することです。これらの結果は、訓練再開時に任意の構成を異なる構成に変換する*UCP*の効果を検証しています。
|
||||
|
||||
## 異なるモデルアーキテクチャへの対応
|
||||
|
||||
*UCP*はモデルアーキテクチャに依存しません。したがって、GPTモデルとの互換性だけでなく、さまざまなモデルアーキテクチャとサイズをサポートする柔軟性も備えています。図5、6、7は、新しい並列戦略で*UCP*から訓練を再開したときのLLaMA 7B、BLOOM 176B、およびMixtral-7x8B MoEを元にしたモデルのトレーニング収束を示しています。これらの図は、トレーニングが*UCP*でシームレスに再開され、これらの多様なモデル全体で訓練の初期フェーズと一致する収束を達成することを示しています。これらの結果は、さまざまなモデルアーキテクチャとサイズに対する*UCP*の柔軟性を示しています。
|
||||
|
||||
<img src="../media/image6.png" style="width:5in;height:4in" alt="A graph of training step Description automatically generated" />
|
||||
|
||||
図5:LLaMAモデルアーキテクチャの訓練lossの曲線。ソースはTP=PP=DP=2。訓練はイテレーション101で新しいターゲットTP=DP=2、PP=1およびTP=PP=2、DP=1で再開しました。
|
||||
|
||||
<img src="../media/image7.png" style="width:5in;height:4in" alt="A graph with numbers and lines Description automatically generated" />
|
||||
|
||||
図6:BLOOMモデルアーキテクチャの訓練lossの曲線。ソースはTP=2、PP=24、DP=8。訓練はイテレーション94767で新しいターゲットTP=2、DP=4、PP=24で再開しました。
|
||||
|
||||
<img src="../media/image8.png" style="width:5in;height:4in" alt="A graph of training step Description automatically generated" />
|
||||
|
||||
図7:Mixtral-MoEモデルアーキテクチャに基づくモデルの訓練lossの曲線。ソースはTP=1、PP=2、DP=4。訓練はイテレーション501で新しいターゲットTP=PP=DP=2で再開しました。
|
||||
|
||||
# DeepSpeed Universal Checkpointの一般公開
|
||||
|
||||
DeepSpeed Universal Checkpointは、リベースされたMegatron-DeepSpeedバージョンに完全に統合されており、DeepSpeedおよびMegatron-DeepSpeedのGitHubリポジトリを通じてアクセスできます。使用に関する詳細なチュートリアルは、[DeepSpeedチュートリアルページ](https://www.deepspeed.ai/tutorials/universal-checkpointing/)にあります。
|
||||
|
||||
DeepSpeedでは、広範なオープンソースコミュニティからの貢献とコラボレーションを受け入れています。DeepSpeed Universal Checkpointは、大規模AIトレーニングおよび推論のためのDeepSpeedエコシステムの一部です。すべてのDeepSpeed技術とイノベーションについての詳細は、[ウェブサイト](https://www.deepspeed.ai/)をご覧いただき、X(旧Twitter)での[英語](https://twitter.com/MSFTDeepSpeed)、[日本語](https://twitter.com/MSFTDeepSpeedJP)、および[中国のZhihu](https://www.zhihu.com/people/deepspeed)をフォローしてください。
|
||||
|
||||
# 謝辞と貢献
|
||||
|
||||
University of Illinois at Urbana-Champaign、Statosphere、およびIntel Habanaとの協力に感謝します。
|
||||
|
||||
コントリビュータ:
|
||||
Xinyu Lian $^1$, Sam Ade Jacobs $^2$, Lev Kurilenko $^2$, Masahiro Tanaka $^2$, Stas Bekman $^3$, Olatunji Ruwase $^2$, Minjia Zhang $^1$, Moshe Island $^4$
|
||||
|
||||
1: University of Illinois at Urbana-Champaign
|
||||
2: Microsoft
|
||||
3: StasoSphere
|
||||
4: Intel Habana
|
После Ширина: | Высота: | Размер: 57 KiB |
После Ширина: | Высота: | Размер: 140 KiB |
После Ширина: | Высота: | Размер: 175 KiB |
После Ширина: | Высота: | Размер: 41 KiB |
После Ширина: | Высота: | Размер: 130 KiB |
После Ширина: | Высота: | Размер: 130 KiB |
После Ширина: | Высота: | Размер: 80 KiB |
После Ширина: | Высота: | Размер: 144 KiB |
После Ширина: | Высота: | Размер: 91 KiB |
|
@ -1,19 +1,15 @@
|
|||
@echo off
|
||||
|
||||
set CUDA_HOME=%CUDA_PATH%
|
||||
set DISTUTILS_USE_SDK=1
|
||||
|
||||
set DS_BUILD_AIO=0
|
||||
set DS_BUILD_CUTLASS_OPS=0
|
||||
set DS_BUILD_EVOFORMER_ATTN=0
|
||||
set DS_BUILD_FP_QUANTIZER=0
|
||||
set DS_BUILD_RAGGED_DEVICE_OPS=0
|
||||
set DS_BUILD_SPARSE_ATTN=0
|
||||
|
||||
echo Administrative permissions required. Detecting permissions...
|
||||
|
||||
net session >nul 2>&1
|
||||
if %errorLevel% == 0 (
|
||||
echo Success: Administrative permissions confirmed.
|
||||
) else (
|
||||
echo Failure: Current permissions inadequate.
|
||||
goto end
|
||||
)
|
||||
|
||||
|
||||
python setup.py bdist_wheel
|
||||
|
||||
:end
|
||||
|
|
|
@ -5,55 +5,38 @@
|
|||
|
||||
#include "cpu_adagrad.h"
|
||||
#include <torch/extension.h>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#include <cuda_runtime_api.h>
|
||||
#include "cublas_v2.h"
|
||||
#include "cuda.h"
|
||||
#include "curand.h"
|
||||
#include "custom_cuda_layers.h"
|
||||
#endif
|
||||
|
||||
using namespace std::string_literals;
|
||||
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
|
||||
|
||||
// C++ interface
|
||||
|
||||
void Adagrad_Optimizer::Step_1(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void Adagrad_Optimizer::Step_1(ds_params_percision_t* _params,
|
||||
ds_params_percision_t* grads,
|
||||
ds_state_precision_t* _exp_avg_sq,
|
||||
size_t _param_size)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<1>(
|
||||
&rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
|
||||
Step_AVX<1>(&rounded_size, _params, grads, _exp_avg_sq, _param_size);
|
||||
#endif
|
||||
if (_param_size > rounded_size) {
|
||||
float step_size = -1 * _alpha;
|
||||
ds_half_precision_t* grads_cast_h;
|
||||
ds_half_precision_t* params_cast_h;
|
||||
if (half_precision) {
|
||||
grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
|
||||
params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
|
||||
}
|
||||
for (size_t t = rounded_size; t < _param_size; t += TILE) {
|
||||
size_t copy_size = TILE;
|
||||
if ((t + TILE) > _param_size) copy_size = _param_size - t;
|
||||
size_t offset = copy_size + t;
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
|
||||
#endif
|
||||
#pragma omp parallel for
|
||||
for (size_t k = t; k < offset; k++) {
|
||||
float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
|
||||
float param = half_precision ? (float)params_cast_h[k] : _params[k];
|
||||
float grad = (float)grads[k];
|
||||
float param = (float)_params[k];
|
||||
float momentum = grads[k];
|
||||
float variance = _exp_avg_sq[k];
|
||||
if (_weight_decay > 0) { grad = param * _weight_decay + grad; }
|
||||
|
@ -64,58 +47,30 @@ void Adagrad_Optimizer::Step_1(float* _params,
|
|||
grad += _eps;
|
||||
grad = momentum / grad;
|
||||
param = grad * step_size + param;
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
|
||||
#endif
|
||||
if (half_precision)
|
||||
params_cast_h[k] = (ds_half_precision_t)param;
|
||||
else
|
||||
_params[k] = param;
|
||||
_params[k] = param;
|
||||
// STORE UPDATE TERM TO GRAD'S MEMORY
|
||||
grads[k] = grad * step_size;
|
||||
_exp_avg_sq[k] = variance;
|
||||
}
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if (dev_params) {
|
||||
launch_param_update(
|
||||
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
|
||||
aclrtMemcpy(dev_params + t,
|
||||
memcpy_size,
|
||||
_doubled_buffer[_buf_index],
|
||||
memcpy_size,
|
||||
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Adagrad_Optimizer::Step_4(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void Adagrad_Optimizer::Step_4(ds_params_percision_t* _params,
|
||||
ds_params_percision_t* grads,
|
||||
ds_state_precision_t* _exp_avg_sq,
|
||||
size_t _param_size)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<4>(
|
||||
&rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
|
||||
Step_AVX<4>(&rounded_size, _params, grads, _exp_avg_sq, _param_size);
|
||||
#endif
|
||||
if (_param_size > rounded_size)
|
||||
Step_1((_params + rounded_size),
|
||||
(grads + rounded_size),
|
||||
(_exp_avg_sq + rounded_size),
|
||||
(_param_size - rounded_size),
|
||||
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
|
||||
half_precision);
|
||||
(_param_size - rounded_size));
|
||||
}
|
||||
|
||||
int create_adagrad_optimizer(int optimizer_id,
|
||||
|
@ -149,25 +104,77 @@ int create_adagrad_optimizer(int optimizer_id,
|
|||
return 0;
|
||||
}
|
||||
|
||||
void Adagrad_Optimizer::Step_8(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void Adagrad_Optimizer::Step_8(ds_params_percision_t* _params,
|
||||
ds_params_percision_t* grads,
|
||||
ds_state_precision_t* _exp_avg_sq,
|
||||
size_t _param_size)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<8>(
|
||||
&rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
|
||||
Step_AVX<8>(&rounded_size, _params, grads, _exp_avg_sq, _param_size);
|
||||
#endif
|
||||
if (_param_size > rounded_size)
|
||||
Step_4((_params + rounded_size),
|
||||
(grads + rounded_size),
|
||||
(_exp_avg_sq + rounded_size),
|
||||
(_param_size - rounded_size),
|
||||
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
|
||||
half_precision);
|
||||
(_param_size - rounded_size));
|
||||
}
|
||||
|
||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void step_invoker(std::shared_ptr<Adagrad_Optimizer> opt,
|
||||
void* _params,
|
||||
void* grads,
|
||||
void* _exp_avg_sq,
|
||||
size_t _param_size)
|
||||
{
|
||||
opt->Step_8((ds_params_percision_t*)(_params),
|
||||
(ds_params_percision_t*)(grads),
|
||||
(ds_state_precision_t*)(_exp_avg_sq),
|
||||
_param_size);
|
||||
}
|
||||
|
||||
std::map<std::tuple<c10::ScalarType, c10::ScalarType>,
|
||||
std::function<void(std::shared_ptr<Adagrad_Optimizer>, void*, void*, void*, size_t)>>
|
||||
invokers;
|
||||
|
||||
// Fill map with template functions for each type
|
||||
template <class ds_params_percision_t, class ds_state_precision_t>
|
||||
void create_invoker()
|
||||
{
|
||||
invokers[std::tuple(c10::CppTypeToScalarType<ds_params_percision_t>(),
|
||||
c10::CppTypeToScalarType<ds_state_precision_t>())] =
|
||||
step_invoker<ds_params_percision_t, ds_state_precision_t>;
|
||||
}
|
||||
struct InvokerInitializer {
|
||||
InvokerInitializer()
|
||||
{
|
||||
create_invoker<c10::Half, float>();
|
||||
create_invoker<c10::Half, c10::Half>();
|
||||
create_invoker<c10::BFloat16, float>();
|
||||
create_invoker<c10::BFloat16, c10::BFloat16>();
|
||||
create_invoker<float, float>();
|
||||
}
|
||||
} _invoker_initializer;
|
||||
|
||||
void invoke(std::shared_ptr<Adagrad_Optimizer> opt,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg_sq,
|
||||
size_t param_size)
|
||||
{
|
||||
c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype());
|
||||
c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg_sq.options().dtype());
|
||||
|
||||
auto it = invokers.find(std::tuple(params_type, state_type));
|
||||
if (it == invokers.end()) {
|
||||
throw std::runtime_error("Adagrad optimizer with param type "s +
|
||||
c10::toString(params_type) + " and state type "s +
|
||||
c10::toString(state_type) +
|
||||
" is not supported on current hardware"s);
|
||||
}
|
||||
|
||||
it->second(opt, params.data_ptr(), grads.data_ptr(), exp_avg_sq.data_ptr(), param_size);
|
||||
}
|
||||
|
||||
int ds_adagrad_step(int optimizer_id,
|
||||
|
@ -183,58 +190,13 @@ int ds_adagrad_step(int optimizer_id,
|
|||
auto grads_c = grads.contiguous();
|
||||
auto exp_avg_sq_c = exp_avg_sq.contiguous();
|
||||
|
||||
float* params_ptr = (float*)params_c.data_ptr();
|
||||
float* grads_ptr = (float*)grads_c.data_ptr();
|
||||
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
|
||||
|
||||
std::shared_ptr<Adagrad_Optimizer> opt =
|
||||
std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
|
||||
opt->IncrementStep(step);
|
||||
opt->update_state(lr, epsilon, weight_decay);
|
||||
opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel());
|
||||
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
opt->SynchronizeStreams();
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
invoke(opt, params_c, grads_c, exp_avg_sq_c, params_c.numel());
|
||||
|
||||
int ds_adagrad_step_plus_copy(int optimizer_id,
|
||||
size_t step,
|
||||
float lr,
|
||||
float epsilon,
|
||||
float weight_decay,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg_sq,
|
||||
torch::Tensor& gpu_params)
|
||||
{
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
auto params_c = params.contiguous();
|
||||
auto gpu_params_c = gpu_params.contiguous();
|
||||
auto exp_avg_sq_c = exp_avg_sq.contiguous();
|
||||
auto grads_c = grads.contiguous();
|
||||
|
||||
float* params_ptr = (float*)params_c.data_ptr();
|
||||
float* grads_ptr = (float*)grads_c.data_ptr();
|
||||
ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
|
||||
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
|
||||
|
||||
std::shared_ptr<Adagrad_Optimizer> opt =
|
||||
std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
|
||||
opt->IncrementStep(step);
|
||||
opt->update_state(lr, epsilon, weight_decay);
|
||||
opt->Step_8(params_ptr,
|
||||
grads_ptr,
|
||||
exp_avg_sq_ptr,
|
||||
params_c.numel(),
|
||||
gpu_params_ptr,
|
||||
(params.options().dtype() == at::kHalf));
|
||||
|
||||
opt->SynchronizeStreams();
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -248,9 +210,6 @@ int destroy_adagrad_optimizer(int optimizer_id)
|
|||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("adagrad_update", &ds_adagrad_step, "DeepSpeed CPU Adagrad update (C++)");
|
||||
m.def("adagrad_update_copy",
|
||||
&ds_adagrad_step_plus_copy,
|
||||
"DeepSpeed CPU Adagrad update and param copy (C++)");
|
||||
m.def("create_adagrad", &create_adagrad_optimizer, "DeepSpeed CPU Adagrad (C++)");
|
||||
m.def("destroy_adagrad", &destroy_adagrad_optimizer, "DeepSpeed CPU Adagrad destroy (C++)");
|
||||
}
|
||||
|
|
|
@ -8,9 +8,6 @@
|
|||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
|
||||
m.def("adam_update_copy",
|
||||
&ds_adam_step_plus_copy,
|
||||
"DeepSpeed CPU Adam update and param copy (C++)");
|
||||
m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
|
||||
m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
|
||||
}
|
||||
|
|
|
@ -5,42 +5,29 @@
|
|||
|
||||
#include <torch/extension.h>
|
||||
#include <cassert>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include "cpu_adam.h"
|
||||
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#include <cuda_runtime_api.h>
|
||||
#include "cublas_v2.h"
|
||||
#include "cuda.h"
|
||||
#include "curand.h"
|
||||
#include "custom_cuda_layers.h"
|
||||
#endif
|
||||
|
||||
using namespace std::string_literals;
|
||||
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
|
||||
|
||||
// C++ interface
|
||||
|
||||
void Adam_Optimizer::Step_1(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void Adam_Optimizer::Step_1(ds_params_percision_t* _params,
|
||||
ds_params_percision_t* grads,
|
||||
ds_state_precision_t* _exp_avg,
|
||||
ds_state_precision_t* _exp_avg_sq,
|
||||
size_t _param_size)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<1>(&rounded_size,
|
||||
_params,
|
||||
grads,
|
||||
_exp_avg,
|
||||
_exp_avg_sq,
|
||||
_param_size,
|
||||
dev_params,
|
||||
half_precision);
|
||||
Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size);
|
||||
#endif
|
||||
if (_param_size > rounded_size) {
|
||||
float betta1_minus1 = 1 - _betta1;
|
||||
|
@ -48,26 +35,15 @@ void Adam_Optimizer::Step_1(float* _params,
|
|||
|
||||
float step_size = -1 * _alpha / _bias_correction1;
|
||||
float w_decay = -1 * _alpha * _weight_decay;
|
||||
ds_half_precision_t* grads_cast_h;
|
||||
ds_half_precision_t* params_cast_h;
|
||||
if (half_precision) {
|
||||
grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
|
||||
params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
|
||||
}
|
||||
|
||||
for (size_t t = rounded_size; t < _param_size; t += TILE) {
|
||||
size_t copy_size = TILE;
|
||||
if ((t + TILE) > _param_size) copy_size = _param_size - t;
|
||||
size_t offset = copy_size + t;
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
|
||||
#endif
|
||||
#pragma omp parallel for
|
||||
for (size_t k = t; k < offset; k++) {
|
||||
float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
|
||||
float param = half_precision ? (float)params_cast_h[k] : _params[k];
|
||||
float grad = (float)grads[k];
|
||||
float param = (float)_params[k];
|
||||
float momentum = _exp_avg[k];
|
||||
float variance = _exp_avg_sq[k];
|
||||
if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
|
||||
|
@ -83,66 +59,31 @@ void Adam_Optimizer::Step_1(float* _params,
|
|||
grad = momentum / grad;
|
||||
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
|
||||
param = grad * step_size + param;
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
|
||||
#endif
|
||||
if (half_precision)
|
||||
params_cast_h[k] = (ds_half_precision_t)param;
|
||||
else
|
||||
_params[k] = param;
|
||||
_params[k] = param;
|
||||
_exp_avg[k] = momentum;
|
||||
_exp_avg_sq[k] = variance;
|
||||
}
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if (dev_params) {
|
||||
launch_param_update(
|
||||
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
|
||||
aclrtMemcpy(dev_params + t,
|
||||
memcpy_size,
|
||||
_doubled_buffer[_buf_index],
|
||||
memcpy_size,
|
||||
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Adam_Optimizer::Step_4(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void Adam_Optimizer::Step_4(ds_params_percision_t* _params,
|
||||
ds_params_percision_t* grads,
|
||||
ds_state_precision_t* _exp_avg,
|
||||
ds_state_precision_t* _exp_avg_sq,
|
||||
size_t _param_size)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<4>(&rounded_size,
|
||||
_params,
|
||||
grads,
|
||||
_exp_avg,
|
||||
_exp_avg_sq,
|
||||
_param_size,
|
||||
dev_params,
|
||||
half_precision);
|
||||
Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size);
|
||||
#endif
|
||||
if (_param_size > rounded_size)
|
||||
Step_1((_params + rounded_size),
|
||||
(grads + rounded_size),
|
||||
(_exp_avg + rounded_size),
|
||||
(_exp_avg_sq + rounded_size),
|
||||
(_param_size - rounded_size),
|
||||
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
|
||||
half_precision);
|
||||
(_param_size - rounded_size));
|
||||
}
|
||||
|
||||
int create_adam_optimizer(int optimizer_id,
|
||||
|
@ -185,33 +126,86 @@ int create_adam_optimizer(int optimizer_id,
|
|||
return 0;
|
||||
}
|
||||
|
||||
void Adam_Optimizer::Step_8(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void Adam_Optimizer::Step_8(ds_params_percision_t* _params,
|
||||
ds_params_percision_t* grads,
|
||||
ds_state_precision_t* _exp_avg,
|
||||
ds_state_precision_t* _exp_avg_sq,
|
||||
size_t _param_size)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<8>(&rounded_size,
|
||||
_params,
|
||||
grads,
|
||||
_exp_avg,
|
||||
_exp_avg_sq,
|
||||
_param_size,
|
||||
dev_params,
|
||||
half_precision);
|
||||
Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size);
|
||||
#endif
|
||||
if (_param_size > rounded_size)
|
||||
Step_4((_params + rounded_size),
|
||||
(grads + rounded_size),
|
||||
(_exp_avg + rounded_size),
|
||||
(_exp_avg_sq + rounded_size),
|
||||
(_param_size - rounded_size),
|
||||
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
|
||||
half_precision);
|
||||
(_param_size - rounded_size));
|
||||
}
|
||||
|
||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void step_invoker(std::shared_ptr<Adam_Optimizer> opt,
|
||||
void* _params,
|
||||
void* grads,
|
||||
void* _exp_avg,
|
||||
void* _exp_avg_sq,
|
||||
size_t _param_size)
|
||||
{
|
||||
opt->Step_8((ds_params_percision_t*)(_params),
|
||||
(ds_params_percision_t*)(grads),
|
||||
(ds_state_precision_t*)(_exp_avg),
|
||||
(ds_state_precision_t*)(_exp_avg_sq),
|
||||
_param_size);
|
||||
}
|
||||
|
||||
std::map<std::tuple<c10::ScalarType, c10::ScalarType>,
|
||||
std::function<void(std::shared_ptr<Adam_Optimizer>, void*, void*, void*, void*, size_t)>>
|
||||
invokers;
|
||||
|
||||
// Fill map with template functions for each type
|
||||
template <class ds_params_percision_t, class ds_state_precision_t>
|
||||
void create_invoker()
|
||||
{
|
||||
invokers[std::tuple(c10::CppTypeToScalarType<ds_params_percision_t>(),
|
||||
c10::CppTypeToScalarType<ds_state_precision_t>())] =
|
||||
step_invoker<ds_params_percision_t, ds_state_precision_t>;
|
||||
}
|
||||
struct InvokerInitializer {
|
||||
InvokerInitializer()
|
||||
{
|
||||
create_invoker<c10::Half, float>();
|
||||
create_invoker<c10::Half, c10::Half>();
|
||||
create_invoker<c10::BFloat16, float>();
|
||||
create_invoker<c10::BFloat16, c10::BFloat16>();
|
||||
create_invoker<float, float>();
|
||||
}
|
||||
} _invoker_initializer;
|
||||
|
||||
void invoke(std::shared_ptr<Adam_Optimizer> opt,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg,
|
||||
torch::Tensor& exp_avg_sq,
|
||||
size_t param_size)
|
||||
{
|
||||
c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype());
|
||||
c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg.options().dtype());
|
||||
|
||||
auto it = invokers.find(std::tuple(params_type, state_type));
|
||||
if (it == invokers.end()) {
|
||||
throw std::runtime_error("Adam optimizer with param type "s + c10::toString(params_type) +
|
||||
" and state type "s + c10::toString(state_type) +
|
||||
" is not supported on current hardware"s);
|
||||
}
|
||||
|
||||
it->second(opt,
|
||||
params.data_ptr(),
|
||||
grads.data_ptr(),
|
||||
exp_avg.data_ptr(),
|
||||
exp_avg_sq.data_ptr(),
|
||||
param_size);
|
||||
}
|
||||
|
||||
int ds_adam_step(int optimizer_id,
|
||||
|
@ -232,75 +226,13 @@ int ds_adam_step(int optimizer_id,
|
|||
auto exp_avg_c = exp_avg.contiguous();
|
||||
auto exp_avg_sq_c = exp_avg_sq.contiguous();
|
||||
|
||||
// assert(params.options().dtype() == grads.options().dtype());
|
||||
|
||||
float* params_ptr = (float*)params_c.data_ptr();
|
||||
float* grads_ptr = (float*)grads_c.data_ptr();
|
||||
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
|
||||
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
|
||||
|
||||
std::shared_ptr<Adam_Optimizer> opt =
|
||||
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
|
||||
opt->IncrementStep(step, beta1, beta2);
|
||||
opt->update_state(lr, epsilon, weight_decay, bias_correction);
|
||||
|
||||
opt->Step_8(params_ptr,
|
||||
grads_ptr,
|
||||
exp_avg_ptr,
|
||||
exp_avg_sq_ptr,
|
||||
params_c.numel(),
|
||||
nullptr,
|
||||
(params.options().dtype() == at::kHalf));
|
||||
invoke(opt, params_c, grads_c, exp_avg_c, exp_avg_sq_c, params_c.numel());
|
||||
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
opt->SynchronizeStreams();
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ds_adam_step_plus_copy(int optimizer_id,
|
||||
size_t step,
|
||||
float lr,
|
||||
float beta1,
|
||||
float beta2,
|
||||
float epsilon,
|
||||
float weight_decay,
|
||||
bool bias_correction,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg,
|
||||
torch::Tensor& exp_avg_sq,
|
||||
torch::Tensor& device_params)
|
||||
{
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
auto params_c = params.contiguous();
|
||||
auto device_params_c = device_params.contiguous();
|
||||
auto exp_avg_c = exp_avg.contiguous();
|
||||
auto exp_avg_sq_c = exp_avg_sq.contiguous();
|
||||
auto grads_c = grads.contiguous();
|
||||
|
||||
float* params_ptr = (float*)params_c.data_ptr();
|
||||
float* grads_ptr = (float*)grads_c.data_ptr();
|
||||
ds_half_precision_t* device_params_ptr = (ds_half_precision_t*)device_params_c.data_ptr();
|
||||
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
|
||||
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
|
||||
|
||||
std::shared_ptr<Adam_Optimizer> opt =
|
||||
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
|
||||
opt->IncrementStep(step, beta1, beta2);
|
||||
opt->update_state(lr, epsilon, weight_decay, bias_correction);
|
||||
opt->Step_8(params_ptr,
|
||||
grads_ptr,
|
||||
exp_avg_ptr,
|
||||
exp_avg_sq_ptr,
|
||||
params_c.numel(),
|
||||
device_params_ptr,
|
||||
(params.options().dtype() == at::kHalf));
|
||||
|
||||
opt->SynchronizeStreams();
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ typedef enum : int {
|
|||
|
||||
using MATH_T = float;
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename index_t>
|
||||
struct AdamFunctor {
|
||||
__device__ __forceinline__ void operator()(int chunk_size,
|
||||
volatile int* noop_gmem,
|
||||
|
@ -48,13 +48,13 @@ struct AdamFunctor {
|
|||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
index_t tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
|
||||
// potentially use to pass in list of scalar
|
||||
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
|
||||
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
index_t n = tl.sizes[tensor_loc];
|
||||
|
||||
T* g = (T*)tl.addresses[0][tensor_loc];
|
||||
g += chunk_idx * chunk_size;
|
||||
|
@ -71,7 +71,8 @@ struct AdamFunctor {
|
|||
n -= chunk_idx * chunk_size;
|
||||
|
||||
// see note in multi_tensor_scale_kernel.cu
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
|
||||
for (index_t i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP) {
|
||||
MATH_T r_g[ILP];
|
||||
MATH_T r_p[ILP];
|
||||
MATH_T r_m[ILP];
|
||||
|
@ -146,23 +147,57 @@ void multi_tensor_adam_cuda(int chunk_size,
|
|||
bias_correction2 = 1 - std::pow(beta2, step);
|
||||
}
|
||||
|
||||
size_t max_size = 0;
|
||||
bool requires_64bit_indexing = false;
|
||||
for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
|
||||
for (auto it2 = it->begin(); it2 != it->end(); it2++) {
|
||||
if (it2->numel() > max_size) {
|
||||
max_size = it2->numel();
|
||||
if (max_size >= INT_MAX) {
|
||||
requires_64bit_indexing = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (requires_64bit_indexing) { break; }
|
||||
}
|
||||
|
||||
// Assume single type across p,g,m1,m2 now
|
||||
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
|
||||
0,
|
||||
"adam",
|
||||
multi_tensor_apply<4>(BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
AdamFunctor<scalar_t_0>(),
|
||||
beta1,
|
||||
beta2,
|
||||
bias_correction1,
|
||||
bias_correction2,
|
||||
epsilon,
|
||||
lr,
|
||||
(adamMode_t)mode,
|
||||
weight_decay);)
|
||||
if (requires_64bit_indexing) {
|
||||
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
|
||||
0,
|
||||
"adam",
|
||||
multi_tensor_apply<4>((int64_t)BLOCK_SIZE,
|
||||
(int64_t)chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
AdamFunctor<scalar_t_0, int64_t>(),
|
||||
beta1,
|
||||
beta2,
|
||||
bias_correction1,
|
||||
bias_correction2,
|
||||
epsilon,
|
||||
lr,
|
||||
(adamMode_t)mode,
|
||||
weight_decay);)
|
||||
} else {
|
||||
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
|
||||
0,
|
||||
"adam",
|
||||
multi_tensor_apply<4>(BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
AdamFunctor<scalar_t_0, int32_t>(),
|
||||
beta1,
|
||||
beta2,
|
||||
bias_correction1,
|
||||
bias_correction2,
|
||||
epsilon,
|
||||
lr,
|
||||
(adamMode_t)mode,
|
||||
weight_decay);)
|
||||
}
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ struct TensorListMetadata {
|
|||
};
|
||||
|
||||
template <typename T, typename U, typename... ArgTypes>
|
||||
__global__ void multi_tensor_apply_kernel(int chunk_size,
|
||||
__global__ void multi_tensor_apply_kernel(int64_t chunk_size,
|
||||
volatile int* noop_flag,
|
||||
T tl,
|
||||
U callable,
|
||||
|
@ -46,8 +46,8 @@ __global__ void multi_tensor_apply_kernel(int chunk_size,
|
|||
}
|
||||
|
||||
template <int depth, typename T, typename... ArgTypes>
|
||||
void multi_tensor_apply(int block_size,
|
||||
int chunk_size,
|
||||
void multi_tensor_apply(int64_t block_size,
|
||||
int64_t chunk_size,
|
||||
const at::Tensor& noop_flag,
|
||||
const std::vector<std::vector<at::Tensor>>& tensor_lists,
|
||||
T callable,
|
||||
|
@ -91,9 +91,9 @@ void multi_tensor_apply(int block_size,
|
|||
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
|
||||
loc_tensor_info++;
|
||||
|
||||
int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
|
||||
auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
|
||||
|
||||
for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
|
||||
for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {
|
||||
// std::cout << chunks_this_tensor << std::endl;
|
||||
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
|
||||
tl.block_to_chunk[loc_block_info] = chunk;
|
||||
|
|
|
@ -1,44 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "custom_cuda_layers.h"
|
||||
|
||||
__global__ void param_update_kernel(const float* input, __half* output, int size)
|
||||
{
|
||||
int id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (id < size) { output[id] = (__half)input[id]; }
|
||||
}
|
||||
|
||||
void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
|
||||
{
|
||||
int threads = 1024;
|
||||
|
||||
dim3 grid_dim((size - 1) / threads + 1);
|
||||
dim3 block_dim(threads);
|
||||
|
||||
param_update_kernel<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
|
||||
}
|
||||
|
||||
__global__ void param_update_kernel_half(const float* input, __half* output, int size)
|
||||
{
|
||||
int id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
__half2* output_cast = reinterpret_cast<__half2*>(output);
|
||||
if (id < size) {
|
||||
float input_f = input[id];
|
||||
__half2* input_h = reinterpret_cast<__half2*>(&input_f);
|
||||
output_cast[id] = *input_h;
|
||||
}
|
||||
}
|
||||
|
||||
void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream)
|
||||
{
|
||||
int threads = 1024;
|
||||
size /= 2;
|
||||
dim3 grid_dim((size - 1) / threads + 1);
|
||||
dim3 block_dim(threads);
|
||||
|
||||
param_update_kernel_half<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
|
||||
}
|
|
@ -247,7 +247,7 @@ void all_reduce_caching(torch::Tensor& data,
|
|||
.wait());
|
||||
}
|
||||
|
||||
void inference_all_reduce(torch::Tensor& data, py::object op, bool async_op)
|
||||
void inference_all_reduce(torch::Tensor& data, py::object op)
|
||||
{
|
||||
#ifdef DO_PROFILE
|
||||
static double total_time = 0.0;
|
||||
|
|
|
@ -21,9 +21,13 @@
|
|||
// states for collectives
|
||||
enum coll_state {
|
||||
coll_begin = 0,
|
||||
coll_allreduce_naive__copy_in_done, // this state is for rank != 0
|
||||
coll_allreduce_naive__reduce_done, // this state is for rank == 0
|
||||
coll_allreduce_naive__copy_out_done, // this state is for rank != 0
|
||||
coll_allreduce_naive__copy_in_done,
|
||||
coll_allreduce_naive__reduce_done,
|
||||
// alternative state when allreduce is working on alternative buffer
|
||||
// of the double buffer.
|
||||
coll_alt1_allreduce_naive__copy_in_done,
|
||||
coll_alt2_allreduce_naive__copy_in_done,
|
||||
coll_alt1_allreduce_naive__reduce_done,
|
||||
};
|
||||
|
||||
// SHM building blocks
|
||||
|
@ -71,6 +75,8 @@ void shared_close(SharedData* data)
|
|||
}
|
||||
}
|
||||
|
||||
static int world_size;
|
||||
|
||||
// SHM based allreduce helper functions
|
||||
// buffer that holds shm name
|
||||
#define NAME_BUF_SIZE 1000
|
||||
|
@ -78,64 +84,37 @@ void shared_close(SharedData* data)
|
|||
#define NAIVE_ALLREDUCE_THRESHOLD 1048576
|
||||
#define SHM_BUFFER_NAME "deepspeed_allreduce_buffer"
|
||||
struct allreduce_workspace {
|
||||
enum coll_state state;
|
||||
sem_t mutex;
|
||||
sem_t turnstile1;
|
||||
sem_t turnstile2;
|
||||
int counter;
|
||||
char buffer[MAX_BUF_SIZE];
|
||||
enum coll_state states[2]; // idx=0 -- state for symmetric_naive_all_reduce
|
||||
// idx=1 -- state for distributed_naive_all_reduce
|
||||
// double buffer to avoid syncing between rounds
|
||||
// offset=0 -- 2*NAIVE_ALLREDUCE_THRESHOLD : buffer for symmetric_naive_all_reduce
|
||||
// after that : buffer for distributed_naive_all_reduce
|
||||
char buffer[2 * NAIVE_ALLREDUCE_THRESHOLD + 2 * MAX_BUF_SIZE];
|
||||
};
|
||||
|
||||
#define BUFFER0_OFFSET(current_buffer) current_buffer* NAIVE_ALLREDUCE_THRESHOLD
|
||||
#define BUFFER1_OFFSET(current_buffer) 2 * NAIVE_ALLREDUCE_THRESHOLD + current_buffer* MAX_BUF_SIZE
|
||||
|
||||
struct allreduce_workspace** workspace;
|
||||
|
||||
void wait_buffer_state_until(int index, enum coll_state state)
|
||||
{
|
||||
volatile enum coll_state* state_ptr = &(workspace[index]->state);
|
||||
// buffer for small messages, double buffer
|
||||
char** symmetric_buffer[2];
|
||||
// buffer for large messages, double buffer
|
||||
char** distributed_buffer[2];
|
||||
|
||||
while (*state_ptr != state)
|
||||
;
|
||||
}
|
||||
|
||||
void wait_buffer_state_until_range(int index, enum coll_state start, int size)
|
||||
void wait_buffer_state_until_2(int index,
|
||||
enum coll_state state0,
|
||||
enum coll_state state1,
|
||||
int state_group)
|
||||
{
|
||||
volatile enum coll_state* state_ptr = &(workspace[index]->state);
|
||||
enum coll_state end = (enum coll_state)(start + size);
|
||||
volatile enum coll_state* state_ptr = &(workspace[index]->states[state_group]);
|
||||
|
||||
while (1) {
|
||||
volatile enum coll_state cur_state = *state_ptr;
|
||||
if (cur_state >= start and cur_state < end) break;
|
||||
if (cur_state == state0 || cur_state == state1) break;
|
||||
}
|
||||
}
|
||||
|
||||
void wait_buffer_state_until_not(int index, enum coll_state state)
|
||||
{
|
||||
volatile enum coll_state* state_ptr = &(workspace[index]->state);
|
||||
|
||||
while (*state_ptr == state)
|
||||
;
|
||||
}
|
||||
|
||||
void barrier_wait(int root_idx, int num_ranks)
|
||||
{
|
||||
// Phase 1: Wait for all threads to enter the barrier
|
||||
auto shared = workspace[root_idx];
|
||||
sem_wait(&shared->mutex);
|
||||
shared->counter++;
|
||||
if (shared->counter == num_ranks) {
|
||||
for (int i = 0; i < num_ranks; ++i) { sem_post(&shared->turnstile1); }
|
||||
}
|
||||
sem_post(&shared->mutex);
|
||||
sem_wait(&shared->turnstile1);
|
||||
|
||||
// Phase 2: Wait for all threads to exit the barrier
|
||||
sem_wait(&shared->mutex);
|
||||
shared->counter--;
|
||||
if (shared->counter == 0) {
|
||||
for (int i = 0; i < num_ranks; ++i) { sem_post(&shared->turnstile2); }
|
||||
}
|
||||
sem_post(&shared->mutex);
|
||||
sem_wait(&shared->turnstile2);
|
||||
}
|
||||
|
||||
__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
|
||||
inline __m512 cvt_bf16_to_fp32(const __m256i src)
|
||||
{
|
||||
|
@ -164,139 +143,58 @@ inline __m256i cvt_fp32_to_bf16(const __m512 src)
|
|||
return _mm512_cvtusepi32_epi16(t_value);
|
||||
}
|
||||
|
||||
void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out)
|
||||
__m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
|
||||
inline __m512 cvt_fp16_to_fp32(const __m256i src) { return _mm512_cvtph_ps(src); }
|
||||
|
||||
inline __m256i cvt_fp32_to_fp16(const __m512 src) __attribute__((target("avx512bw")));
|
||||
inline __m256i cvt_fp32_to_fp16(const __m512 src)
|
||||
{
|
||||
return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||
}
|
||||
|
||||
void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
|
||||
__attribute__((target("avx512bw")));
|
||||
|
||||
void reduce_bf16_buffers(int start_elements,
|
||||
int num_elements,
|
||||
int num_buffers,
|
||||
int to_buffer_idx,
|
||||
struct allreduce_workspace** workspace)
|
||||
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
|
||||
__attribute__((target("avx512bw")));
|
||||
|
||||
void reduce_2_fp32_buffers_iio(int num_elements, void* in0, void* in1, void* out)
|
||||
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
|
||||
__attribute__((target("avx512bw")));
|
||||
|
||||
void reduce_fp32_buffers(int start_elements,
|
||||
int num_elements,
|
||||
int num_buffers,
|
||||
int to_buffer_idx,
|
||||
struct allreduce_workspace** workspace)
|
||||
__attribute__((target("avx512bw")));
|
||||
|
||||
// N_REDUCE_LIMIT is the number of buffers that can be reduced together in one shot.
|
||||
// Compared with do N-1 2-reduces which needs 2*(N-1) read and N-1 write,
|
||||
// N-reduce only needs N read and 1 write, this saves 2/3 memory bandwidth.
|
||||
// When increase N_REDUCE_LIMIT to a bigger number, do the following steps
|
||||
// 1. Extend REPEAT_<X> macros list down below
|
||||
// 2. Extend switch cases which call "REPEAT(X, ...)" down below
|
||||
#define N_REDUCE_LIMIT 16
|
||||
|
||||
void reduce_all_buffers(struct allreduce_workspace** workspace,
|
||||
int start_elements,
|
||||
void reduce_all_buffers(int start_elements,
|
||||
int num_elements,
|
||||
c10::ScalarType scalar_type,
|
||||
int num_buffers,
|
||||
int to_buffer_idx)
|
||||
int to_buffer_idx,
|
||||
char* to_buffer,
|
||||
char** buffers)
|
||||
{
|
||||
switch (scalar_type) {
|
||||
case c10::ScalarType::BFloat16:
|
||||
if (num_buffers > 2 && num_buffers <= N_REDUCE_LIMIT) {
|
||||
reduce_bf16_buffers(
|
||||
start_elements, num_elements, num_buffers, to_buffer_idx, workspace);
|
||||
} else {
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
if (i == to_buffer_idx) continue;
|
||||
reduce_2_bf16_buffers_iio(
|
||||
num_elements,
|
||||
workspace[i]->buffer + start_elements * 2,
|
||||
workspace[to_buffer_idx]->buffer + start_elements * 2,
|
||||
workspace[to_buffer_idx]->buffer + start_elements * 2);
|
||||
}
|
||||
}
|
||||
reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers);
|
||||
break;
|
||||
case c10::ScalarType::Half:
|
||||
reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers);
|
||||
break;
|
||||
case c10::ScalarType::Float:
|
||||
if (num_buffers > 2 && num_buffers <= N_REDUCE_LIMIT) {
|
||||
reduce_fp32_buffers(
|
||||
start_elements, num_elements, num_buffers, to_buffer_idx, workspace);
|
||||
} else {
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
if (i == to_buffer_idx) continue;
|
||||
reduce_2_fp32_buffers_iio(
|
||||
num_elements,
|
||||
workspace[i]->buffer + start_elements * 4,
|
||||
workspace[to_buffer_idx]->buffer + start_elements * 4,
|
||||
workspace[to_buffer_idx]->buffer + start_elements * 4);
|
||||
}
|
||||
}
|
||||
reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers);
|
||||
break;
|
||||
default: assert(!"Should not get here");
|
||||
}
|
||||
}
|
||||
|
||||
#define REPEAT(N, x) REPEAT_##N(x)
|
||||
#define REPEAT_1(x) x(1)
|
||||
#define REPEAT_2(x) \
|
||||
REPEAT_1(x); \
|
||||
x(2)
|
||||
#define REPEAT_3(x) \
|
||||
REPEAT_2(x); \
|
||||
x(3)
|
||||
#define REPEAT_4(x) \
|
||||
REPEAT_3(x); \
|
||||
x(4)
|
||||
#define REPEAT_5(x) \
|
||||
REPEAT_4(x); \
|
||||
x(5)
|
||||
#define REPEAT_6(x) \
|
||||
REPEAT_5(x); \
|
||||
x(6)
|
||||
#define REPEAT_7(x) \
|
||||
REPEAT_6(x); \
|
||||
x(7)
|
||||
#define REPEAT_8(x) \
|
||||
REPEAT_7(x); \
|
||||
x(8)
|
||||
#define REPEAT_9(x) \
|
||||
REPEAT_8(x); \
|
||||
x(9)
|
||||
#define REPEAT_10(x) \
|
||||
REPEAT_9(x); \
|
||||
x(10)
|
||||
#define REPEAT_11(x) \
|
||||
REPEAT_10(x); \
|
||||
x(11)
|
||||
#define REPEAT_12(x) \
|
||||
REPEAT_11(x); \
|
||||
x(12)
|
||||
#define REPEAT_13(x) \
|
||||
REPEAT_12(x); \
|
||||
x(13)
|
||||
#define REPEAT_14(x) \
|
||||
REPEAT_13(x); \
|
||||
x(14)
|
||||
#define REPEAT_15(x) \
|
||||
REPEAT_14(x); \
|
||||
x(15)
|
||||
|
||||
#define CVT_ADD_BF16(x) \
|
||||
do { \
|
||||
auto in##x##_val = \
|
||||
cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(workspace[x]->buffer + i))); \
|
||||
inout_val = _mm512_add_ps(inout_val, in##x##_val); \
|
||||
#define CVT_ADD_BF16(x) \
|
||||
do { \
|
||||
auto in##x##_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
|
||||
inout_val = _mm512_add_ps(inout_val, in##x##_val); \
|
||||
} while (0)
|
||||
|
||||
// Reduce functions down below use vectorized algorithm, the number of bytes processed each
|
||||
// iteration depends on vector length. 256bit vector ==> 32 bytes, 512bit vector ==> 64 bytes
|
||||
// If you change implementation of reduce_2_bf16_buffers_iio or reduce_2_fp32_buffers_iio, check
|
||||
// whether this number needs to be changed
|
||||
// If you change implementation of reduce_bf16_buffers, etc. , check whether this number needs
|
||||
// to be changed
|
||||
#define VECTOR_LENGTH_IN_BYTES 32
|
||||
|
||||
void reduce_bf16_buffers(int start_elements,
|
||||
int num_elements,
|
||||
int num_buffers,
|
||||
int to_buffer_idx,
|
||||
struct allreduce_workspace** workspace)
|
||||
void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
|
||||
{
|
||||
const int element_size = 2;
|
||||
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
|
||||
|
@ -307,77 +205,106 @@ void reduce_bf16_buffers(int start_elements,
|
|||
#pragma omp parallel for
|
||||
for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
|
||||
i += VECTOR_LENGTH_IN_BYTES) {
|
||||
auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(workspace[0]->buffer + i)));
|
||||
switch (num_buffers) {
|
||||
case 16: REPEAT(15, CVT_ADD_BF16); break;
|
||||
case 15: REPEAT(14, CVT_ADD_BF16); break;
|
||||
case 14: REPEAT(13, CVT_ADD_BF16); break;
|
||||
case 13: REPEAT(12, CVT_ADD_BF16); break;
|
||||
case 12: REPEAT(11, CVT_ADD_BF16); break;
|
||||
case 11: REPEAT(10, CVT_ADD_BF16); break;
|
||||
case 10: REPEAT(9, CVT_ADD_BF16); break;
|
||||
case 9: REPEAT(8, CVT_ADD_BF16); break;
|
||||
case 8: REPEAT(7, CVT_ADD_BF16); break;
|
||||
case 7: REPEAT(6, CVT_ADD_BF16); break;
|
||||
case 6: REPEAT(5, CVT_ADD_BF16); break;
|
||||
case 5: REPEAT(4, CVT_ADD_BF16); break;
|
||||
case 4: REPEAT(3, CVT_ADD_BF16); break;
|
||||
case 3: REPEAT(2, CVT_ADD_BF16); break;
|
||||
default: assert(!"Should not get here.");
|
||||
auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
|
||||
switch (world_size) {
|
||||
case 16: CVT_ADD_BF16(15);
|
||||
case 15: CVT_ADD_BF16(14);
|
||||
case 14: CVT_ADD_BF16(13);
|
||||
case 13: CVT_ADD_BF16(12);
|
||||
case 12: CVT_ADD_BF16(11);
|
||||
case 11: CVT_ADD_BF16(10);
|
||||
case 10: CVT_ADD_BF16(9);
|
||||
case 9: CVT_ADD_BF16(8);
|
||||
case 8: CVT_ADD_BF16(7);
|
||||
case 7: CVT_ADD_BF16(6);
|
||||
case 6: CVT_ADD_BF16(5);
|
||||
case 5: CVT_ADD_BF16(4);
|
||||
case 4: CVT_ADD_BF16(3);
|
||||
case 3: CVT_ADD_BF16(2);
|
||||
case 2: CVT_ADD_BF16(1);
|
||||
case 1: break;
|
||||
default:
|
||||
for (int j = 1; j < world_size; j++) {
|
||||
auto in_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
|
||||
inout_val = _mm512_add_ps(inout_val, in_val);
|
||||
}
|
||||
}
|
||||
_mm256_storeu_si256((__m256i*)(workspace[to_buffer_idx]->buffer + i),
|
||||
cvt_fp32_to_bf16(inout_val));
|
||||
_mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_bf16(inout_val));
|
||||
}
|
||||
|
||||
// process remaining part
|
||||
int i = (start_elements + main_elements) * element_size;
|
||||
while (remain_elements > 0) {
|
||||
float val = 0.0f;
|
||||
for (int j = 0; j < num_buffers; j++) { val += *(at::BFloat16*)(workspace[j]->buffer + i); }
|
||||
*(at::BFloat16*)(workspace[to_buffer_idx]->buffer + i) = val;
|
||||
for (int j = 0; j < world_size; j++) { val += *(at::BFloat16*)(buffers[j] + i); }
|
||||
*(at::BFloat16*)(to_buffer + i) = val;
|
||||
remain_elements--;
|
||||
i += element_size;
|
||||
}
|
||||
}
|
||||
|
||||
void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out)
|
||||
{
|
||||
const int element_size = 2;
|
||||
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
|
||||
int main_elements = num_elements - (num_elements % vector_length);
|
||||
int remain_elements = num_elements % vector_length;
|
||||
|
||||
// process aligned part
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < main_elements * element_size; i += VECTOR_LENGTH_IN_BYTES) {
|
||||
auto in0_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)((char*)in0 + i)));
|
||||
auto in1_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)((char*)in1 + i)));
|
||||
auto out_val = _mm512_add_ps(in0_val, in1_val);
|
||||
_mm256_storeu_si256((__m256i*)((char*)out + i), cvt_fp32_to_bf16(out_val));
|
||||
}
|
||||
|
||||
// process remaining part
|
||||
int i = main_elements * element_size;
|
||||
while (remain_elements > 0) {
|
||||
float in0_val = *((at::BFloat16*)((char*)in0 + i));
|
||||
float in1_val = *((at::BFloat16*)((char*)in1 + i));
|
||||
*((at::BFloat16*)((char*)out + i)) = in0_val + in1_val;
|
||||
remain_elements--;
|
||||
i += element_size;
|
||||
}
|
||||
}
|
||||
|
||||
#define CVT_ADD_F32(x) \
|
||||
do { \
|
||||
auto in##x##_val = _mm256_loadu_ps((float*)(workspace[x]->buffer + i)); \
|
||||
inout_val = _mm256_add_ps(inout_val, in##x##_val); \
|
||||
#define CVT_ADD_FP16(x) \
|
||||
do { \
|
||||
auto in##x##_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
|
||||
inout_val = _mm512_add_ps(inout_val, in##x##_val); \
|
||||
} while (0)
|
||||
|
||||
void reduce_fp32_buffers(int start_elements,
|
||||
int num_elements,
|
||||
int num_buffers,
|
||||
int to_buffer_idx,
|
||||
struct allreduce_workspace** workspace)
|
||||
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
|
||||
{
|
||||
const int element_size = 2;
|
||||
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
|
||||
int main_elements = num_elements - (num_elements % vector_length);
|
||||
int remain_elements = num_elements % vector_length;
|
||||
|
||||
// process aligned part
|
||||
#pragma omp parallel for
|
||||
for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
|
||||
i += VECTOR_LENGTH_IN_BYTES) {
|
||||
auto inout_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
|
||||
switch (world_size) {
|
||||
case 16: CVT_ADD_FP16(15);
|
||||
case 15: CVT_ADD_FP16(14);
|
||||
case 14: CVT_ADD_FP16(13);
|
||||
case 13: CVT_ADD_FP16(12);
|
||||
case 12: CVT_ADD_FP16(11);
|
||||
case 11: CVT_ADD_FP16(10);
|
||||
case 10: CVT_ADD_FP16(9);
|
||||
case 9: CVT_ADD_FP16(8);
|
||||
case 8: CVT_ADD_FP16(7);
|
||||
case 7: CVT_ADD_FP16(6);
|
||||
case 6: CVT_ADD_FP16(5);
|
||||
case 5: CVT_ADD_FP16(4);
|
||||
case 4: CVT_ADD_FP16(3);
|
||||
case 3: CVT_ADD_FP16(2);
|
||||
case 2: CVT_ADD_FP16(1);
|
||||
case 1: break;
|
||||
default:
|
||||
for (int j = 1; j < world_size; j++) {
|
||||
auto in_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
|
||||
inout_val = _mm512_add_ps(inout_val, in_val);
|
||||
}
|
||||
}
|
||||
_mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_fp16(inout_val));
|
||||
}
|
||||
|
||||
// process remaining part
|
||||
int i = (start_elements + main_elements) * element_size;
|
||||
while (remain_elements > 0) {
|
||||
float val = 0.0f;
|
||||
for (int j = 0; j < world_size; j++) { val += *(at::Half*)(buffers[j] + i); }
|
||||
*(at::Half*)(to_buffer + i) = val;
|
||||
remain_elements--;
|
||||
i += element_size;
|
||||
}
|
||||
}
|
||||
|
||||
#define CVT_ADD_F32(x) \
|
||||
do { \
|
||||
auto in##x##_val = _mm256_loadu_ps((float*)(buffers[x] + i)); \
|
||||
inout_val = _mm256_add_ps(inout_val, in##x##_val); \
|
||||
} while (0)
|
||||
|
||||
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
|
||||
{
|
||||
const int element_size = 4;
|
||||
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
|
||||
|
@ -388,67 +315,45 @@ void reduce_fp32_buffers(int start_elements,
|
|||
#pragma omp parallel for
|
||||
for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
|
||||
i += VECTOR_LENGTH_IN_BYTES) {
|
||||
auto inout_val = _mm256_loadu_ps((float*)(workspace[0]->buffer + i));
|
||||
switch (num_buffers) {
|
||||
case 16: REPEAT(15, CVT_ADD_F32); break;
|
||||
case 15: REPEAT(14, CVT_ADD_F32); break;
|
||||
case 14: REPEAT(13, CVT_ADD_F32); break;
|
||||
case 13: REPEAT(12, CVT_ADD_F32); break;
|
||||
case 12: REPEAT(11, CVT_ADD_F32); break;
|
||||
case 11: REPEAT(10, CVT_ADD_F32); break;
|
||||
case 10: REPEAT(9, CVT_ADD_F32); break;
|
||||
case 9: REPEAT(8, CVT_ADD_F32); break;
|
||||
case 8: REPEAT(7, CVT_ADD_F32); break;
|
||||
case 7: REPEAT(6, CVT_ADD_F32); break;
|
||||
case 6: REPEAT(5, CVT_ADD_F32); break;
|
||||
case 5: REPEAT(4, CVT_ADD_F32); break;
|
||||
case 4: REPEAT(3, CVT_ADD_F32); break;
|
||||
case 3: REPEAT(2, CVT_ADD_F32); break;
|
||||
default: assert(!"Should not get here.");
|
||||
auto inout_val = _mm256_loadu_ps((float*)(buffers[0] + i));
|
||||
switch (world_size) {
|
||||
case 16: CVT_ADD_F32(15);
|
||||
case 15: CVT_ADD_F32(14);
|
||||
case 14: CVT_ADD_F32(13);
|
||||
case 13: CVT_ADD_F32(12);
|
||||
case 12: CVT_ADD_F32(11);
|
||||
case 11: CVT_ADD_F32(10);
|
||||
case 10: CVT_ADD_F32(9);
|
||||
case 9: CVT_ADD_F32(8);
|
||||
case 8: CVT_ADD_F32(7);
|
||||
case 7: CVT_ADD_F32(6);
|
||||
case 6: CVT_ADD_F32(5);
|
||||
case 5: CVT_ADD_F32(4);
|
||||
case 4: CVT_ADD_F32(3);
|
||||
case 3: CVT_ADD_F32(2);
|
||||
case 2: CVT_ADD_F32(1);
|
||||
case 1: break;
|
||||
default:
|
||||
for (int j = 1; j < world_size; j++) {
|
||||
auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i));
|
||||
inout_val = _mm256_add_ps(inout_val, in_val);
|
||||
}
|
||||
}
|
||||
_mm256_storeu_ps((float*)(workspace[to_buffer_idx]->buffer + i), inout_val);
|
||||
_mm256_storeu_ps((float*)(to_buffer + i), inout_val);
|
||||
}
|
||||
|
||||
// process remaining part
|
||||
int i = (start_elements + main_elements) * element_size;
|
||||
while (remain_elements > 0) {
|
||||
float val = 0.0f;
|
||||
for (int j = 0; j < num_buffers; j++) { val += *(float*)(workspace[j]->buffer + i); }
|
||||
*(float*)(workspace[to_buffer_idx]->buffer + i) = val;
|
||||
remain_elements--;
|
||||
i += element_size;
|
||||
}
|
||||
}
|
||||
|
||||
void reduce_2_fp32_buffers_iio(int num_elements, void* in0, void* in1, void* out)
|
||||
{
|
||||
const int element_size = 4;
|
||||
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
|
||||
int main_elements = num_elements - (num_elements % vector_length);
|
||||
int remain_elements = num_elements % vector_length;
|
||||
|
||||
// process aligned part
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < main_elements * element_size; i += VECTOR_LENGTH_IN_BYTES) {
|
||||
auto in0_val = _mm256_loadu_ps((float*)((char*)in0 + i));
|
||||
auto in1_val = _mm256_loadu_ps((float*)((char*)in1 + i));
|
||||
auto out_val = _mm256_add_ps(in0_val, in1_val);
|
||||
_mm256_storeu_ps((float*)((char*)out + i), out_val);
|
||||
}
|
||||
|
||||
// process remaining part
|
||||
int i = main_elements * element_size;
|
||||
while (remain_elements > 0) {
|
||||
float in0_val = *((float*)((char*)in0 + i));
|
||||
float in1_val = *((float*)((char*)in1 + i));
|
||||
*((float*)((char*)out + i)) = in0_val + in1_val;
|
||||
for (int j = 0; j < world_size; j++) { val += *(float*)(buffers[j] + i); }
|
||||
*(float*)(to_buffer + i) = val;
|
||||
remain_elements--;
|
||||
i += element_size;
|
||||
}
|
||||
}
|
||||
|
||||
static bool is_initialized = 0;
|
||||
static int world_size;
|
||||
static int world_rank;
|
||||
|
||||
void shm_initialize(int size, int rank, char* addr_string, char* port_string)
|
||||
|
@ -477,10 +382,15 @@ void shm_initialize(int size, int rank, char* addr_string, char* port_string)
|
|||
snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank);
|
||||
shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace));
|
||||
workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes;
|
||||
workspace_buf->state = coll_begin;
|
||||
workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done;
|
||||
workspace_buf->states[1] = coll_begin;
|
||||
|
||||
// create the workspace pointer list
|
||||
workspace = (struct allreduce_workspace**)malloc(size * sizeof(struct allreduce_workspace*));
|
||||
symmetric_buffer[0] = (char**)malloc(size * sizeof(char**));
|
||||
symmetric_buffer[1] = (char**)malloc(size * sizeof(char**));
|
||||
distributed_buffer[0] = (char**)malloc(size * sizeof(char**));
|
||||
distributed_buffer[1] = (char**)malloc(size * sizeof(char**));
|
||||
|
||||
// map shm of all ranks
|
||||
for (int i = 0; i < size; i++) {
|
||||
|
@ -494,11 +404,11 @@ void shm_initialize(int size, int rank, char* addr_string, char* port_string)
|
|||
workspace[i] = workspace_buf_other;
|
||||
} else {
|
||||
workspace[i] = workspace_buf;
|
||||
workspace_buf->counter = 0;
|
||||
sem_init(&workspace_buf->mutex, 1, 1);
|
||||
sem_init(&workspace_buf->turnstile1, 1, 0);
|
||||
sem_init(&workspace_buf->turnstile2, 1, 0);
|
||||
}
|
||||
symmetric_buffer[0][i] = workspace[i]->buffer + BUFFER0_OFFSET(0);
|
||||
symmetric_buffer[1][i] = workspace[i]->buffer + BUFFER0_OFFSET(1);
|
||||
distributed_buffer[0][i] = workspace[i]->buffer + BUFFER1_OFFSET(0);
|
||||
distributed_buffer[1][i] = workspace[i]->buffer + BUFFER1_OFFSET(1);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -539,46 +449,122 @@ size_t slice_el_start(size_t chunk_el, int slice_idx)
|
|||
return slice_size * slice_idx;
|
||||
}
|
||||
|
||||
void naive_all_reduce(char* data_ptr,
|
||||
c10::ScalarType scalar_type,
|
||||
size_t chunk_size,
|
||||
size_t chunk_el)
|
||||
/*
|
||||
Symmetrical naive all_reduce
|
||||
step 0: before enter the function ith times, state is copy(i-1)
|
||||
step 1: each rank copy data from input (data_ptr) to SHM buffer[i]
|
||||
step 2: set own state to copy(i)
|
||||
step 3: wait each other rank's state equal or later than copy(i)
|
||||
step 4: reduce across SHM buffer(ith) directly into output (data_ptr)
|
||||
*/
|
||||
void symmetric_naive_all_reduce(char* data_ptr,
|
||||
c10::ScalarType scalar_type,
|
||||
size_t chunk_size,
|
||||
size_t chunk_el)
|
||||
{
|
||||
parallel_memcpy(workspace[world_rank]->buffer, data_ptr, chunk_size);
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
workspace[world_rank]->state = coll_allreduce_naive__copy_in_done;
|
||||
#ifdef DO_PROFILE
|
||||
static double total_t1_t0 = 0.0;
|
||||
static double total_t2_t1 = 0.0;
|
||||
static double total_t3_t2 = 0.0;
|
||||
static int count = -16; // warmup
|
||||
auto t0 = std::chrono::system_clock::now();
|
||||
#endif
|
||||
|
||||
if (world_rank == 0) {
|
||||
// compute allreduce result on rank 0
|
||||
for (int i = 1; i < world_size; i++) {
|
||||
// wait until the other rank copy the buffer
|
||||
wait_buffer_state_until(i, coll_allreduce_naive__copy_in_done);
|
||||
/*
|
||||
We can't have infinite number of buffers and states. 2 sets of buffer
|
||||
and 3 sets of states is just enough. Consider current rank is in step 3,
|
||||
with it's own state set to copy(i), the other rank will them have the
|
||||
following situations:
|
||||
------------------------------------------------
|
||||
my state | can I proceed? | the other rank state
|
||||
================================================
|
||||
| N | copy(i-1)
|
||||
|----------------|---------------------
|
||||
copy(i) | Y | copy(i)
|
||||
|----------------|---------------------
|
||||
| Y | copy(i+1)
|
||||
------------------------------------------------
|
||||
* When I have state as copy(i), the other rank cannot have state
|
||||
copy(i-2) or before. In that case I'll be in state copy(i-1) and cannot
|
||||
proceed to copy(i).
|
||||
* The other rank cannot have state copy(i+2) or beyond because my
|
||||
state is still copy(i), copy(i+1) is as far as the other rank could go.
|
||||
* From a rank's POV, all the other ranks can be divided into three sets:
|
||||
- Lagging ranks: ranks that are still working on previous iteration
|
||||
- Syncing ranks: ranks that are working on current iteration
|
||||
- Leading ranks: ranks that are working on next iteration
|
||||
* We can have 3 sets of states, one set for syncing ranks; one set for
|
||||
lagging ranks; one set of leading ranks. With 3 sets of states, we can
|
||||
distinguish between lagging and leading ranks.
|
||||
* Note from any rank's POV, leading ranks and lagging ranks does not
|
||||
appear at the same time. Either all other ranks are syncing or
|
||||
lagging, or all other ranks are syncing or leading. Otherwise leading
|
||||
and lagging ranks will be 2 iterations apart and this should not happen.
|
||||
* So we have 2 sets of buffers, one buffer is used by current iter;
|
||||
one buffer used by either lagging ranks or leading ranks.
|
||||
*/
|
||||
const int state_group = 0;
|
||||
static int current_buffer = 0;
|
||||
static int state_idx = 0;
|
||||
|
||||
enum coll_state copy_current, copy_next;
|
||||
|
||||
switch (state_idx) {
|
||||
case 0:
|
||||
copy_current = coll_allreduce_naive__copy_in_done;
|
||||
copy_next = coll_alt1_allreduce_naive__copy_in_done;
|
||||
break;
|
||||
case 1:
|
||||
copy_current = coll_alt1_allreduce_naive__copy_in_done;
|
||||
copy_next = coll_alt2_allreduce_naive__copy_in_done;
|
||||
break;
|
||||
case 2:
|
||||
copy_current = coll_alt2_allreduce_naive__copy_in_done;
|
||||
copy_next = coll_allreduce_naive__copy_in_done;
|
||||
break;
|
||||
default: assert(!"Should not get here.");
|
||||
}
|
||||
state_idx = (state_idx + 1) % 3;
|
||||
|
||||
parallel_memcpy(symmetric_buffer[current_buffer][world_rank], data_ptr, chunk_size);
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
workspace[world_rank]->states[state_group] = copy_current;
|
||||
|
||||
#ifdef DO_PROFILE
|
||||
auto t1 = std::chrono::system_clock::now();
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
// wait until the other rank copy the buffer
|
||||
if (i != world_rank) { wait_buffer_state_until_2(i, copy_current, copy_next, state_group); }
|
||||
}
|
||||
#ifdef DO_PROFILE
|
||||
auto t2 = std::chrono::system_clock::now();
|
||||
#endif
|
||||
|
||||
// each rank reduce the buffer independently so therre is no need for synchronization afterward
|
||||
reduce_all_buffers(
|
||||
0, chunk_el, scalar_type, world_rank, data_ptr, symmetric_buffer[current_buffer]);
|
||||
|
||||
// switch buffer
|
||||
current_buffer = 1 - current_buffer;
|
||||
|
||||
#ifdef DO_PROFILE
|
||||
auto t3 = std::chrono::system_clock::now();
|
||||
|
||||
count++;
|
||||
if (count > 0) {
|
||||
total_t1_t0 += std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count();
|
||||
total_t2_t1 += std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();
|
||||
total_t3_t2 += std::chrono::duration_cast<std::chrono::microseconds>(t3 - t2).count();
|
||||
if (world_rank == 0 && count == 1000) {
|
||||
printf("symmetric_naive_all_reduce time breakdown:\n");
|
||||
printf("\tcopy input buffer: %.2f\n", total_t1_t0 / count);
|
||||
printf("\twait for copy: %.2f\n", total_t2_t1 / count);
|
||||
printf("\treduce: %.2f\n", total_t3_t2 / count);
|
||||
}
|
||||
reduce_all_buffers(workspace, 0, chunk_el, scalar_type, world_size, 0);
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
workspace[world_rank]->state = coll_allreduce_naive__reduce_done;
|
||||
parallel_memcpy(data_ptr, workspace[0]->buffer, chunk_size);
|
||||
}
|
||||
if (world_rank != 0) {
|
||||
wait_buffer_state_until(0, coll_allreduce_naive__reduce_done);
|
||||
parallel_memcpy(data_ptr, workspace[0]->buffer, chunk_size);
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
workspace[world_rank]->state = coll_allreduce_naive__copy_out_done;
|
||||
}
|
||||
if (world_rank == 0) {
|
||||
for (int i = 1; i < world_size; i++) {
|
||||
wait_buffer_state_until(i, coll_allreduce_naive__copy_out_done);
|
||||
}
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
workspace[world_rank]->state = coll_begin;
|
||||
}
|
||||
if (world_rank != 0) {
|
||||
// if rank 0 spin too fast it could be in state 1 of next allreduce
|
||||
// in this case wait_buffer_state_until(0, 0) may cause deadlock
|
||||
// what we are certain is when rank 0 finishes the state won't be 2
|
||||
wait_buffer_state_until_not(0, coll_allreduce_naive__reduce_done);
|
||||
workspace[world_rank]->state = coll_begin;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// naive allreduce distributed, each rank do naive reduce on its slice
|
||||
|
@ -597,10 +583,33 @@ void distributed_naive_reduce(char* data_ptr,
|
|||
auto t0 = std::chrono::system_clock::now();
|
||||
#endif
|
||||
|
||||
const int state_group = 1;
|
||||
static int current_buffer = 0;
|
||||
static int state_idx = 0;
|
||||
|
||||
enum coll_state copy_current, copy_next, reduce_current;
|
||||
|
||||
// similar to symmetric_naive_allreduce, but here we only need two sets of
|
||||
// states, because distributed naive reduce has two barriers in the algorithm
|
||||
switch (state_idx) {
|
||||
case 0:
|
||||
copy_current = coll_allreduce_naive__copy_in_done;
|
||||
reduce_current = coll_allreduce_naive__reduce_done;
|
||||
copy_next = coll_alt1_allreduce_naive__copy_in_done;
|
||||
break;
|
||||
case 1:
|
||||
copy_current = coll_alt1_allreduce_naive__copy_in_done;
|
||||
reduce_current = coll_alt1_allreduce_naive__reduce_done;
|
||||
copy_next = coll_allreduce_naive__copy_in_done;
|
||||
break;
|
||||
default: assert(!"Should not get here.");
|
||||
}
|
||||
state_idx = (state_idx + 1) % 2;
|
||||
|
||||
int data_size = chunk_size / chunk_el;
|
||||
parallel_memcpy(workspace[world_rank]->buffer, data_ptr, chunk_size);
|
||||
parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
workspace[world_rank]->state = coll_allreduce_naive__copy_in_done;
|
||||
workspace[world_rank]->states[state_group] = copy_current;
|
||||
|
||||
#ifdef DO_PROFILE
|
||||
auto t1 = std::chrono::system_clock::now();
|
||||
|
@ -608,7 +617,8 @@ void distributed_naive_reduce(char* data_ptr,
|
|||
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
// wait until all the other ranks copy the buffer
|
||||
wait_buffer_state_until_range(i, coll_allreduce_naive__copy_in_done, 2);
|
||||
if (i != world_rank)
|
||||
wait_buffer_state_until_2(i, copy_current, reduce_current, state_group);
|
||||
}
|
||||
|
||||
#ifdef DO_PROFILE
|
||||
|
@ -616,40 +626,36 @@ void distributed_naive_reduce(char* data_ptr,
|
|||
#endif
|
||||
|
||||
// reduce scatter
|
||||
reduce_all_buffers(workspace,
|
||||
slice_el_start(chunk_el, world_rank),
|
||||
reduce_all_buffers(slice_el_start(chunk_el, world_rank),
|
||||
slice_size(chunk_el, world_rank),
|
||||
scalar_type,
|
||||
world_size,
|
||||
world_rank);
|
||||
world_rank,
|
||||
distributed_buffer[current_buffer][world_rank],
|
||||
distributed_buffer[current_buffer]);
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
workspace[world_rank]->state = coll_allreduce_naive__reduce_done;
|
||||
workspace[world_rank]->states[state_group] = reduce_current;
|
||||
|
||||
#ifdef DO_PROFILE
|
||||
auto t3 = std::chrono::system_clock::now();
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
int rank = (i + world_rank) % world_size;
|
||||
// wait until the other rank reduce the buffer
|
||||
wait_buffer_state_until_range(rank, coll_allreduce_naive__reduce_done, 2);
|
||||
parallel_memcpy(slice_data(data_ptr, chunk_el, data_size, rank),
|
||||
slice_data(workspace[rank]->buffer, chunk_el, chunk_size / chunk_el, rank),
|
||||
slice_size(chunk_el, rank) * data_size);
|
||||
// wait until all the other ranks reduce the buffer
|
||||
if (i != world_rank) wait_buffer_state_until_2(i, reduce_current, copy_next, state_group);
|
||||
}
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
workspace[world_rank]->state = coll_allreduce_naive__copy_out_done;
|
||||
|
||||
#ifdef DO_PROFILE
|
||||
auto t4 = std::chrono::system_clock::now();
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
wait_buffer_state_until_not(i, coll_allreduce_naive__reduce_done);
|
||||
int rank = (i + world_rank) % world_size;
|
||||
parallel_memcpy(
|
||||
slice_data(data_ptr, chunk_el, data_size, rank),
|
||||
slice_data(
|
||||
distributed_buffer[current_buffer][rank], chunk_el, chunk_size / chunk_el, rank),
|
||||
slice_size(chunk_el, rank) * data_size);
|
||||
}
|
||||
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
workspace[world_rank]->state = coll_begin;
|
||||
current_buffer = 1 - current_buffer;
|
||||
|
||||
#ifdef DO_PROFILE
|
||||
auto t5 = std::chrono::system_clock::now();
|
||||
|
@ -665,8 +671,8 @@ void distributed_naive_reduce(char* data_ptr,
|
|||
printf("\tcopy input buffer: %.2f\n", total_t1_t0 / count);
|
||||
printf("\twait for copy: %.2f\n", total_t2_t1 / count);
|
||||
printf("\treduce: %.2f\n", total_t3_t2 / count);
|
||||
printf("\tcopy buffer to output: %.2f\n", total_t4_t3 / count);
|
||||
printf("\twait finish: %.2f\n", total_t5_t4 / count);
|
||||
printf("\twait for reduce finish: %.2f\n", total_t4_t3 / count);
|
||||
printf("\tcopy out: %.2f\n", total_t5_t4 / count);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
@ -679,7 +685,7 @@ void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size)
|
|||
size_t chunk_size = data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset;
|
||||
size_t chunk_el = chunk_size / (data_size / numel);
|
||||
if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD)
|
||||
naive_all_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el);
|
||||
symmetric_naive_all_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el);
|
||||
else
|
||||
distributed_naive_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "shm.h"
|
||||
|
||||
// #define DO_PROFILE
|
||||
#ifdef DO_PROFILE
|
||||
#include <cfloat>
|
||||
#include <chrono>
|
||||
#endif
|
||||
|
||||
// Communication settings
|
||||
static int world_rank = -1;
|
||||
static int world_size = -1;
|
||||
|
||||
static bool is_initialized = 0;
|
||||
|
||||
static bool all_ranks_local_p = false;
|
||||
|
||||
void initialize(int size, int rank)
|
||||
{
|
||||
if (is_initialized) return;
|
||||
|
||||
// Check whether all ranks is on the same physical machine.
|
||||
// If true, we will use an SHM based low latency allreduce
|
||||
|
||||
auto ls_string = std::getenv("LOCAL_SIZE");
|
||||
int ls = 0;
|
||||
if (ls_string != NULL) { ls = std::stoi(std::getenv("LOCAL_SIZE")); }
|
||||
|
||||
if (size >= 1 && size == ls) { all_ranks_local_p = true; }
|
||||
|
||||
world_size = size;
|
||||
world_rank = rank;
|
||||
is_initialized = 1;
|
||||
|
||||
auto addr_string = std::getenv("MASTER_ADDR");
|
||||
if (addr_string == NULL) { addr_string = ""; }
|
||||
auto port_string = std::getenv("MASTER_PORT");
|
||||
if (port_string == NULL) { port_string = ""; }
|
||||
|
||||
if (all_ranks_local_p) { shm_initialize(size, rank, addr_string, port_string); }
|
||||
}
|
||||
|
||||
int get_rank(int group = 0) { return world_rank; }
|
||||
|
||||
int get_world_size(int group = 0) { return world_size; }
|
||||
|
||||
// Success - return 0
|
||||
// Fail (cannot hornor the request and need to fall back) - return -1
|
||||
int inference_all_reduce(torch::Tensor& data, py::object op)
|
||||
{
|
||||
if (!all_ranks_local_p) return -1;
|
||||
#ifdef DO_PROFILE
|
||||
static double total_time = 0.0;
|
||||
static double total_time_sq = 0.0;
|
||||
static int count = -16; // warmup
|
||||
static double max_time = 0.0;
|
||||
static double min_time = DBL_MAX;
|
||||
// make sure all rank reach this point before measuring time
|
||||
// turn on this if you suspect each rank didn't reach here at the same time (stragger)
|
||||
// if (all_ranks_local_p) { barrier_wait(0, world_size); }
|
||||
auto start = std::chrono::system_clock::now();
|
||||
#endif
|
||||
|
||||
static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
|
||||
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
|
||||
|
||||
assert(py::int_(op.attr("value")) == ReduceOpSum);
|
||||
|
||||
auto numel = data.numel();
|
||||
|
||||
int data_size = 0;
|
||||
bool data_type_fallback = false;
|
||||
|
||||
switch (data.scalar_type()) {
|
||||
case c10::ScalarType::BFloat16: data_size = numel * 2; break;
|
||||
case c10::ScalarType::Half: data_size = numel * 2; break;
|
||||
case c10::ScalarType::Float: data_size = numel * 4; break;
|
||||
default: data_type_fallback = true;
|
||||
}
|
||||
|
||||
if (data_type_fallback) return -1;
|
||||
|
||||
all_reduce_outer_loop(data, numel, data_size);
|
||||
|
||||
#ifdef DO_PROFILE
|
||||
auto end = std::chrono::system_clock::now();
|
||||
count++;
|
||||
if (count > 0) {
|
||||
double elapsed = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
|
||||
if (elapsed > max_time) { max_time = elapsed; }
|
||||
if (elapsed < min_time) { min_time = elapsed; }
|
||||
total_time += elapsed;
|
||||
total_time_sq += elapsed * elapsed;
|
||||
if (world_rank == 0 && count == 1000) {
|
||||
auto avg = total_time / count;
|
||||
auto sd =
|
||||
sqrt(total_time_sq / count - total_time * total_time / (count * count)) / avg * 100;
|
||||
printf(" C++ kernel\t\t %.2f\t %.2f\t%.2f\t %.2f\n",
|
||||
min_time,
|
||||
max_time,
|
||||
total_time / count,
|
||||
sd);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("initialize", &initialize, "shm initialize");
|
||||
m.def("get_rank", &get_rank, "get rank");
|
||||
m.def("get_world_size", &get_world_size, "get world size");
|
||||
m.def("inference_all_reduce", &inference_all_reduce, "low latency all_reduce implementation");
|
||||
}
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "quantize.h"
|
||||
#include "fp_quantize.h"
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
|
@ -78,8 +78,39 @@ void dequantize(torch::Tensor& val,
|
|||
#endif
|
||||
}
|
||||
|
||||
#define DISPATCH_DEQUANTIZE_INDEX(T_TYPE, C_TYPE, mantisa) \
|
||||
if (val.options().dtype() == torch::T_TYPE) { \
|
||||
launch_selective_dequantization<C_TYPE, mantisa>((uint8_t*)val_q.data_ptr(), \
|
||||
(C_TYPE*)val.data_ptr(), \
|
||||
(int32_t*)indexes.data_ptr(), \
|
||||
num_groups, \
|
||||
group_size, \
|
||||
num_indexes, \
|
||||
q_mantisa_bits, \
|
||||
q_exponent_bits, \
|
||||
at::cuda::getCurrentCUDAStream()); \
|
||||
return; \
|
||||
}
|
||||
void selective_dequantize(torch::Tensor& val,
|
||||
torch::Tensor& val_q,
|
||||
torch::Tensor& indexes,
|
||||
int group_size,
|
||||
int q_mantisa_bits,
|
||||
int q_exponent_bits)
|
||||
{
|
||||
int total_elems = at::numel(val);
|
||||
int num_indexes = indexes.size(0);
|
||||
int num_groups = total_elems / group_size;
|
||||
|
||||
DISPATCH_DEQUANTIZE_INDEX(kHalf, __half, 10);
|
||||
#ifdef BF16_AVAILABLE
|
||||
DISPATCH_DEQUANTIZE_INDEX(kBFloat16, __nv_bfloat16, 7);
|
||||
#endif
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("quantize", &quantize, "quantize function");
|
||||
m.def("dequantize", &dequantize, "dequantize function");
|
||||
m.def("selective_dequantize", &selective_dequantize, "selective dequantize function");
|
||||
}
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
#include <stdexcept>
|
||||
#include "context.h"
|
||||
#include "fp_quantize.h"
|
||||
#include "memory_access_utils.h"
|
||||
#include "quantize.h"
|
||||
#include "reduction_utils.h"
|
||||
|
||||
#include <cuda.h>
|
||||
|
@ -219,119 +219,101 @@ __global__ void apply_quantization(T* val,
|
|||
}
|
||||
|
||||
template <typename T,
|
||||
int unroll,
|
||||
int q_mantisa_bits,
|
||||
int total_q_bits = 16,
|
||||
int _mantisa_bits = 3,
|
||||
int _exponent_bits = 4>
|
||||
__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size)
|
||||
__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int total_num_elements)
|
||||
{
|
||||
int tidx = threadIdx.x;
|
||||
int wid = tidx >> 5;
|
||||
int lane = tidx & 0x1f;
|
||||
int gid = blockIdx.x * quantization::warps + wid;
|
||||
constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
|
||||
int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size;
|
||||
|
||||
constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
|
||||
constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
|
||||
constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
|
||||
constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
|
||||
constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);
|
||||
|
||||
constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
|
||||
constexpr uint32_t load_stride = vector_size * hw_warp_size;
|
||||
const uint32_t thread_offset = lane * vector_size;
|
||||
const uint32_t thread_load_offset = lane * vector_size * quantized_bits / 8;
|
||||
const uint32_t base_load_offset =
|
||||
gid * (group_size * quantized_bits / 8 + 4) + thread_load_offset; // 4-byte scale offset
|
||||
const uint32_t base_store_offset = gid * group_size + thread_offset;
|
||||
const uint8_t* load_base_ptr = val + base_load_offset;
|
||||
const uint32_t g_index = (tidx / group_size);
|
||||
const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
|
||||
const uint8_t* load_base_ptr =
|
||||
val + g_index * (group_size_bytes + 4) + (tidx % group_size) * quantized_bits / 8;
|
||||
|
||||
int mantisa_mask = ((1 << q_mantisa_bits) - 1);
|
||||
mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);
|
||||
|
||||
T* store_base_ptr = q_val + base_store_offset;
|
||||
float scale; //= q_scale[gid];
|
||||
T* store_base_ptr = q_val + tidx;
|
||||
float scale;
|
||||
|
||||
uint8_t* scale_as_int8 = reinterpret_cast<uint8_t*>(&scale);
|
||||
if (quantized_bits == 6) {
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity>(
|
||||
scale_as_int8,
|
||||
val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));
|
||||
scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
|
||||
scale_as_int8 + quantization::quanitzed_access_granularity_6bits,
|
||||
val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8) +
|
||||
val + g_index * (group_size_bytes + 4) + group_size_bytes +
|
||||
quantization::quanitzed_access_granularity_6bits);
|
||||
} else
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity>(
|
||||
scale_as_int8,
|
||||
val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));
|
||||
scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < unroll; i++) {
|
||||
if (i * load_stride + thread_offset < group_size) {
|
||||
uint64_t q_buf_in;
|
||||
uint64_t q_buf_in1;
|
||||
uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
|
||||
uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
|
||||
uint32_t loading_offset = i * load_stride * quantized_bits / 8;
|
||||
if (quantized_bits == 6) {
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
|
||||
int8_data, load_base_ptr + loading_offset);
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
|
||||
int8_data + quantization::quanitzed_access_granularity_6bits,
|
||||
load_base_ptr + loading_offset +
|
||||
quantization::quanitzed_access_granularity_6bits);
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
|
||||
int8_data + quantization::quanitzed_access_granularity_6bits * 2,
|
||||
load_base_ptr + loading_offset +
|
||||
quantization::quanitzed_access_granularity_6bits * 2);
|
||||
} else {
|
||||
if (tidx < total_num_elements) {
|
||||
uint64_t q_buf_in;
|
||||
uint64_t q_buf_in1;
|
||||
uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
|
||||
uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
|
||||
if (quantized_bits == 6) {
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
|
||||
int8_data, load_base_ptr);
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
|
||||
int8_data + quantization::quanitzed_access_granularity_6bits,
|
||||
load_base_ptr + quantization::quanitzed_access_granularity_6bits);
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
|
||||
int8_data + quantization::quanitzed_access_granularity_6bits * 2,
|
||||
load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2);
|
||||
|
||||
} else {
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity>(int8_data,
|
||||
load_base_ptr);
|
||||
if (quantized_bits > 4) {
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity>(
|
||||
int8_data, load_base_ptr + loading_offset);
|
||||
if (quantized_bits > 4) {
|
||||
int8_data + quantization::quanitzed_access_granularity,
|
||||
load_base_ptr + quantization::quanitzed_access_granularity);
|
||||
if (quantized_bits == 12) {
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity>(
|
||||
int8_data + quantization::quanitzed_access_granularity,
|
||||
load_base_ptr + loading_offset +
|
||||
quantization::quanitzed_access_granularity);
|
||||
if (quantized_bits == 12) {
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity>(
|
||||
int8_data1,
|
||||
load_base_ptr + loading_offset +
|
||||
quantization::quanitzed_access_granularity * 2);
|
||||
}
|
||||
int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2);
|
||||
}
|
||||
}
|
||||
T store_buf[vector_size];
|
||||
uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < vector_size; j++) {
|
||||
uint16_t new_data;
|
||||
if (j < 5 || quantized_bits != 12) {
|
||||
new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
|
||||
} else {
|
||||
if (j == 5) {
|
||||
new_data = (uint16_t)(q_buf_in1);
|
||||
new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
|
||||
} else
|
||||
new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
|
||||
}
|
||||
|
||||
uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
|
||||
uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
|
||||
uint16_t dst_mantisa = (new_data & _mantisa_mask);
|
||||
|
||||
if (dst_exponent != (1 << q_exponent_bits) - 1)
|
||||
dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
|
||||
(1 << (q_exponent_bits - 1)) - 1;
|
||||
|
||||
q_buf[j] = ((sign << (q_exponent_bits + q_mantisa_bits)) |
|
||||
(dst_exponent << q_mantisa_bits) |
|
||||
(dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
|
||||
float up_cast = conversion::to<float>(store_buf[j]);
|
||||
store_buf[j] = conversion::to<T>(up_cast * scale);
|
||||
}
|
||||
mem_access::store_global<quantization::access_granularity>(
|
||||
store_base_ptr + i * load_stride, store_buf);
|
||||
}
|
||||
T store_buf[vector_size];
|
||||
uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < vector_size; j++) {
|
||||
uint16_t new_data;
|
||||
if (j < 5 || quantized_bits != 12) {
|
||||
new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
|
||||
} else {
|
||||
if (j == 5) {
|
||||
new_data = (uint16_t)(q_buf_in1);
|
||||
new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
|
||||
} else
|
||||
new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
|
||||
}
|
||||
|
||||
uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
|
||||
uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
|
||||
uint16_t dst_mantisa = (new_data & _mantisa_mask);
|
||||
|
||||
if (dst_exponent != (1 << q_exponent_bits) - 1)
|
||||
dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
|
||||
(1 << (q_exponent_bits - 1)) - 1;
|
||||
|
||||
q_buf[j] =
|
||||
((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) |
|
||||
(dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
|
||||
float up_cast = conversion::to<float>(store_buf[j]);
|
||||
store_buf[j] = conversion::to<T>(up_cast * scale);
|
||||
}
|
||||
mem_access::store_global<quantization::access_granularity>(store_base_ptr, store_buf);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -386,12 +368,6 @@ INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8);
|
|||
#endif
|
||||
INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8);
|
||||
|
||||
#define LAUNCH_FOR_DEQUANTIZATION_UNROLL(COUNT) \
|
||||
case COUNT: \
|
||||
apply_dequantization<T, COUNT, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS> \
|
||||
<<<grid, block, 0, stream>>>(val, q_val, group_size); \
|
||||
break;
|
||||
|
||||
template <typename T, int mantisa>
|
||||
void launch_dequantization(uint8_t* val,
|
||||
T* q_val,
|
||||
|
@ -401,21 +377,14 @@ void launch_dequantization(uint8_t* val,
|
|||
int q_exponent_bits,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps);
|
||||
int blocks = ((num_groups * group_size) - 1) /
|
||||
(quantization::threads * (quantization::access_granularity / sizeof(T))) +
|
||||
1;
|
||||
const dim3 grid(blocks);
|
||||
const dim3 block(quantization::threads);
|
||||
|
||||
constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T);
|
||||
const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll;
|
||||
|
||||
DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] {
|
||||
switch (copy_unroll) {
|
||||
LAUNCH_FOR_DEQUANTIZATION_UNROLL(1)
|
||||
LAUNCH_FOR_DEQUANTIZATION_UNROLL(2)
|
||||
LAUNCH_FOR_DEQUANTIZATION_UNROLL(3)
|
||||
LAUNCH_FOR_DEQUANTIZATION_UNROLL(4)
|
||||
LAUNCH_FOR_DEQUANTIZATION_UNROLL(5)
|
||||
LAUNCH_FOR_DEQUANTIZATION_UNROLL(6)
|
||||
}
|
||||
apply_dequantization<T, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS>
|
||||
<<<grid, block, 0, stream>>>(val, q_val, group_size, (num_groups * group_size));
|
||||
});
|
||||
}
|
||||
#define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \
|
||||
|
@ -425,3 +394,137 @@ void launch_dequantization(uint8_t* val,
|
|||
INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7);
|
||||
#endif
|
||||
INSTANTIATE_LAUNCH_DEQUANTIZATION(__half, 10);
|
||||
|
||||
template <typename T,
|
||||
int q_mantisa_bits,
|
||||
int total_q_bits = 16,
|
||||
int _mantisa_bits = 3,
|
||||
int _exponent_bits = 4>
|
||||
__global__ void apply_selective_dequantization(uint8_t* val,
|
||||
T* q_val,
|
||||
int32_t* indexes,
|
||||
int group_size,
|
||||
int total_num_elements)
|
||||
{
|
||||
int index = indexes[blockIdx.x];
|
||||
constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
|
||||
int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size;
|
||||
int input_index = index * total_num_elements + tidx;
|
||||
constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
|
||||
constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
|
||||
constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
|
||||
constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
|
||||
constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);
|
||||
const uint32_t g_index = (input_index / group_size);
|
||||
const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
|
||||
const uint8_t* load_base_ptr =
|
||||
val + g_index * (group_size_bytes + 4) + (input_index % group_size) * quantized_bits / 8;
|
||||
|
||||
int mantisa_mask = ((1 << q_mantisa_bits) - 1);
|
||||
mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);
|
||||
|
||||
T* store_base_ptr = q_val + tidx + blockIdx.x * total_num_elements;
|
||||
float scale;
|
||||
|
||||
uint8_t* scale_as_int8 = reinterpret_cast<uint8_t*>(&scale);
|
||||
if (quantized_bits == 6) {
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity>(
|
||||
scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
|
||||
scale_as_int8 + quantization::quanitzed_access_granularity_6bits,
|
||||
val + g_index * (group_size_bytes + 4) + group_size_bytes +
|
||||
quantization::quanitzed_access_granularity_6bits);
|
||||
} else
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity>(
|
||||
scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
|
||||
|
||||
if (tidx < total_num_elements) {
|
||||
uint64_t q_buf_in;
|
||||
uint64_t q_buf_in1;
|
||||
uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
|
||||
uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
|
||||
if (quantized_bits == 6) {
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
|
||||
int8_data, load_base_ptr);
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
|
||||
int8_data + quantization::quanitzed_access_granularity_6bits,
|
||||
load_base_ptr + quantization::quanitzed_access_granularity_6bits);
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
|
||||
int8_data + quantization::quanitzed_access_granularity_6bits * 2,
|
||||
load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2);
|
||||
} else {
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity>(int8_data,
|
||||
load_base_ptr);
|
||||
if (quantized_bits > 4) {
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity>(
|
||||
int8_data + quantization::quanitzed_access_granularity,
|
||||
load_base_ptr + quantization::quanitzed_access_granularity);
|
||||
if (quantized_bits == 12) {
|
||||
mem_access::load_global<quantization::quanitzed_access_granularity>(
|
||||
int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
T store_buf[vector_size];
|
||||
uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < vector_size; j++) {
|
||||
uint16_t new_data;
|
||||
if (j < 5 || quantized_bits != 12) {
|
||||
new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
|
||||
} else {
|
||||
if (j == 5) {
|
||||
new_data = (uint16_t)(q_buf_in1);
|
||||
new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
|
||||
} else
|
||||
new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
|
||||
}
|
||||
|
||||
uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
|
||||
uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
|
||||
uint16_t dst_mantisa = (new_data & _mantisa_mask);
|
||||
|
||||
if (dst_exponent != (1 << q_exponent_bits) - 1)
|
||||
dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
|
||||
(1 << (q_exponent_bits - 1)) - 1;
|
||||
|
||||
q_buf[j] =
|
||||
((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) |
|
||||
(dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
|
||||
float up_cast = conversion::to<float>(store_buf[j]);
|
||||
store_buf[j] = conversion::to<T>(up_cast * scale);
|
||||
}
|
||||
mem_access::store_global<quantization::access_granularity>(store_base_ptr, store_buf);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int mantisa>
|
||||
void launch_selective_dequantization(uint8_t* val,
|
||||
T* q_val,
|
||||
int32_t* indexes,
|
||||
int num_groups,
|
||||
int group_size,
|
||||
int num_indexes,
|
||||
int q_mantisa_bits,
|
||||
int q_exponent_bits,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
int total_elements_per_index = (num_groups / num_indexes) * group_size;
|
||||
int blocks = (total_elements_per_index - 1) /
|
||||
(quantization::threads * (quantization::access_granularity / sizeof(T))) +
|
||||
1;
|
||||
const dim3 grid(num_indexes, blocks);
|
||||
const dim3 block(quantization::threads);
|
||||
DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] {
|
||||
apply_selective_dequantization<T, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS>
|
||||
<<<grid, block, 0, stream>>>(val, q_val, indexes, group_size, total_elements_per_index);
|
||||
});
|
||||
}
|
||||
#define INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(T, mantisa) \
|
||||
template void launch_selective_dequantization<T, mantisa>( \
|
||||
uint8_t*, T*, int32_t*, int, int, int, int, int, cudaStream_t);
|
||||
// fp8(E4M3)
|
||||
#ifdef BF16_AVAILABLE
|
||||
INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__nv_bfloat16, 7);
|
||||
#endif
|
||||
INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__half, 10);
|
|
@ -113,3 +113,14 @@ void launch_dequantization(uint8_t* val,
|
|||
int q_mantisa_bits,
|
||||
int q_exponent_bits,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T, int mantisa>
|
||||
void launch_selective_dequantization(uint8_t* val,
|
||||
T* q_val,
|
||||
int32_t* indexes,
|
||||
int num_groups,
|
||||
int group_size,
|
||||
int num_indexes,
|
||||
int q_mantisa_bits,
|
||||
int q_exponent_bits,
|
||||
cudaStream_t stream);
|
|
@ -9,84 +9,35 @@
|
|||
// https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c
|
||||
|
||||
#include <stdio.h>
|
||||
#include <torch/extension.h>
|
||||
#include <cassert>
|
||||
#include "simd.h"
|
||||
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include "cuda.h"
|
||||
#include "custom_cuda_layers.h"
|
||||
typedef __half ds_half_precision_t;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
#include "acl/acl.h"
|
||||
#include "torch_npu/csrc/core/npu/NPUStream.h"
|
||||
typedef c10::Half ds_half_precision_t;
|
||||
#else
|
||||
typedef unsigned short ds_half_precision_t;
|
||||
#endif
|
||||
|
||||
#define STEP(SPAN) \
|
||||
void Step_##SPAN(float* _params, \
|
||||
float* grads, \
|
||||
float* _exp_avg_sq, \
|
||||
size_t _param_size, \
|
||||
ds_half_precision_t* dev_param = nullptr, \
|
||||
bool half_precision = false);
|
||||
#define STEP(SPAN) \
|
||||
template <typename ds_params_percision_t, typename ds_state_precision_t> \
|
||||
void Step_##SPAN(ds_params_percision_t* _params, \
|
||||
ds_params_percision_t* grads, \
|
||||
ds_state_precision_t* _exp_avg_sq, \
|
||||
size_t _param_size);
|
||||
|
||||
class Adagrad_Optimizer {
|
||||
public:
|
||||
Adagrad_Optimizer(float alpha = 1e-2, float eps = 1e-8, float weight_decay = 0)
|
||||
: _alpha(alpha), _eps(eps), _weight_decay(weight_decay)
|
||||
{
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
|
||||
cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
|
||||
|
||||
_streams[0] = TrainingContext::Instance().GetCurrentStream();
|
||||
_streams[1] = TrainingContext::Instance().GetNewStream();
|
||||
_buf_index = false;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
|
||||
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
|
||||
|
||||
_buf_index = false;
|
||||
#endif
|
||||
}
|
||||
~Adagrad_Optimizer()
|
||||
{
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
cudaFreeHost(_doubled_buffer[0]);
|
||||
cudaFreeHost(_doubled_buffer[1]);
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
aclrtFreeHost(_doubled_buffer[0]);
|
||||
aclrtFreeHost(_doubled_buffer[1]);
|
||||
#endif
|
||||
}
|
||||
~Adagrad_Optimizer() {}
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
template <int span>
|
||||
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void Step_AVX(size_t* rounded_size,
|
||||
float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg_sq,
|
||||
size_t param_size,
|
||||
ds_half_precision_t* dev_param = nullptr,
|
||||
bool half_precision = false);
|
||||
ds_params_percision_t* _params,
|
||||
ds_params_percision_t* grads,
|
||||
ds_state_precision_t* _exp_avg_sq,
|
||||
size_t param_size);
|
||||
#endif
|
||||
STEP(1)
|
||||
STEP(4)
|
||||
STEP(8)
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
inline void SynchronizeStreams()
|
||||
{
|
||||
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
inline void SynchronizeStreams()
|
||||
{
|
||||
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
|
||||
}
|
||||
#endif
|
||||
inline void IncrementStep(size_t step)
|
||||
{
|
||||
_step++;
|
||||
|
@ -107,29 +58,22 @@ private:
|
|||
float _betta1_t;
|
||||
float _betta2_t;
|
||||
size_t _step;
|
||||
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
bool _buf_index;
|
||||
float* _doubled_buffer[2];
|
||||
cudaStream_t _streams[2];
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
float* _doubled_buffer[2];
|
||||
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
|
||||
c10_npu::getNPUStreamFromPool()};
|
||||
bool _buf_index;
|
||||
#endif
|
||||
};
|
||||
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
template <int span>
|
||||
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
|
||||
float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
ds_params_percision_t* _params,
|
||||
ds_params_percision_t* grads,
|
||||
ds_state_precision_t* _exp_avg_sq,
|
||||
size_t _param_size)
|
||||
{
|
||||
#if !defined(__AVX512__)
|
||||
if (std::is_same_v<ds_params_percision_t, c10::BFloat16> ||
|
||||
std::is_same_v<ds_state_precision_t, c10::BFloat16>) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
size_t new_rounded_size = 0;
|
||||
AVX_Data eps_4;
|
||||
eps_4.data = SIMD_SET(_eps);
|
||||
|
@ -145,24 +89,19 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
|
|||
size_t copy_size = TILE;
|
||||
if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
|
||||
size_t offset = copy_size + t;
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
|
||||
#endif
|
||||
#pragma omp parallel for
|
||||
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
|
||||
AVX_Data grad_4[span];
|
||||
simd_load<span>(grad_4, grads + i, half_precision);
|
||||
simd_load<span>(grad_4, grads + i);
|
||||
|
||||
AVX_Data momentum_4[span];
|
||||
simd_load<span>(momentum_4, grads + i, false);
|
||||
simd_load<span>(momentum_4, grads + i);
|
||||
|
||||
AVX_Data variance_4[span];
|
||||
simd_load<span>(variance_4, _exp_avg_sq + i, false);
|
||||
simd_load<span>(variance_4, _exp_avg_sq + i);
|
||||
|
||||
AVX_Data param_4[span];
|
||||
simd_load<span>(param_4, _params + i, half_precision);
|
||||
simd_load<span>(param_4, _params + i);
|
||||
|
||||
if (_weight_decay > 0) { simd_fma<span>(grad_4, param_4, weight_decay4, grad_4); }
|
||||
|
||||
|
@ -172,38 +111,9 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
|
|||
simd_div<span>(grad_4, momentum_4, grad_4);
|
||||
simd_fma<span>(param_4, grad_4, step_size_4, param_4);
|
||||
|
||||
simd_store<span>(_params + i, param_4, half_precision);
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
|
||||
}
|
||||
#endif
|
||||
simd_store<span>(_exp_avg_sq + i, variance_4, false);
|
||||
simd_store<span>(_params + i, param_4);
|
||||
simd_store<span>(_exp_avg_sq + i, variance_4);
|
||||
}
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if (dev_params) {
|
||||
if (half_precision)
|
||||
launch_param_update_half(
|
||||
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
|
||||
else
|
||||
launch_param_update(
|
||||
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
|
||||
if (half_precision) memcpy_size /= 2;
|
||||
aclrtMemcpy(dev_params + t,
|
||||
memcpy_size,
|
||||
_doubled_buffer[_buf_index],
|
||||
memcpy_size,
|
||||
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
*rounded_size = new_rounded_size;
|
||||
}
|
||||
|
|
|
@ -13,29 +13,13 @@
|
|||
#include <cassert>
|
||||
#include "simd.h"
|
||||
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include "cuda.h"
|
||||
#include "custom_cuda_layers.h"
|
||||
typedef __half ds_half_precision_t;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
#include "acl/acl.h"
|
||||
#include "torch_npu/csrc/core/npu/NPUStream.h"
|
||||
typedef c10::Half ds_half_precision_t;
|
||||
#else
|
||||
#include <cmath>
|
||||
typedef unsigned short ds_half_precision_t;
|
||||
#endif
|
||||
|
||||
#define STEP(SPAN) \
|
||||
void Step_##SPAN(float* _params, \
|
||||
float* grads, \
|
||||
float* _exp_avg, \
|
||||
float* _exp_avg_sq, \
|
||||
size_t _param_size, \
|
||||
ds_half_precision_t* dev_param = nullptr, \
|
||||
bool half_precision = false);
|
||||
#define STEP(SPAN) \
|
||||
template <typename ds_params_percision_t, typename ds_state_precision_t> \
|
||||
void Step_##SPAN(ds_params_percision_t* _params, \
|
||||
ds_params_percision_t* grads, \
|
||||
ds_state_precision_t* _exp_avg, \
|
||||
ds_state_precision_t* _exp_avg_sq, \
|
||||
size_t _param_size);
|
||||
|
||||
class Adam_Optimizer {
|
||||
public:
|
||||
|
@ -55,56 +39,21 @@ public:
|
|||
_step(0),
|
||||
_adamw_mode(adamw_mode)
|
||||
{
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
|
||||
cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
|
||||
|
||||
_streams[0] = TrainingContext::Instance().GetCurrentStream();
|
||||
_streams[1] = TrainingContext::Instance().GetNewStream();
|
||||
_buf_index = false;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
|
||||
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
|
||||
|
||||
_buf_index = false;
|
||||
#endif
|
||||
}
|
||||
~Adam_Optimizer()
|
||||
{
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
cudaFreeHost(_doubled_buffer[0]);
|
||||
cudaFreeHost(_doubled_buffer[1]);
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
aclrtFreeHost(_doubled_buffer[0]);
|
||||
aclrtFreeHost(_doubled_buffer[1]);
|
||||
#endif
|
||||
}
|
||||
~Adam_Optimizer() {}
|
||||
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
template <int span>
|
||||
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void Step_AVX(size_t* rounded_size,
|
||||
float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
float* _exp_avg_sq,
|
||||
size_t param_size,
|
||||
ds_half_precision_t* dev_param = nullptr,
|
||||
bool half_precision = false);
|
||||
ds_params_percision_t* _params,
|
||||
ds_params_percision_t* grads,
|
||||
ds_state_precision_t* _exp_avg,
|
||||
ds_state_precision_t* _exp_avg_sq,
|
||||
size_t param_size);
|
||||
#endif
|
||||
STEP(1)
|
||||
STEP(4)
|
||||
STEP(8)
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
inline void SynchronizeStreams()
|
||||
{
|
||||
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
inline void SynchronizeStreams()
|
||||
{
|
||||
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
|
||||
}
|
||||
#endif
|
||||
inline void IncrementStep(size_t step, float beta1, float beta2)
|
||||
{
|
||||
if (beta1 != _betta1 || beta2 != _betta2) {
|
||||
|
@ -154,32 +103,24 @@ private:
|
|||
float _bias_correction2;
|
||||
|
||||
bool _adamw_mode;
|
||||
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
float* _doubled_buffer[2];
|
||||
cudaStream_t _streams[2];
|
||||
bool _buf_index;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
float* _doubled_buffer[2];
|
||||
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
|
||||
c10_npu::getNPUStreamFromPool()};
|
||||
bool _buf_index;
|
||||
#endif
|
||||
};
|
||||
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
template <int span>
|
||||
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void Adam_Optimizer::Step_AVX(size_t* rounded_size,
|
||||
float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
ds_params_percision_t* _params,
|
||||
ds_params_percision_t* grads,
|
||||
ds_state_precision_t* _exp_avg,
|
||||
ds_state_precision_t* _exp_avg_sq,
|
||||
size_t _param_size)
|
||||
{
|
||||
#if !defined(__AVX512__)
|
||||
if (std::is_same_v<ds_params_percision_t, c10::BFloat16> ||
|
||||
std::is_same_v<ds_state_precision_t, c10::BFloat16>) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
size_t new_rounded_size = 0;
|
||||
int rshft = half_precision ? 1 : 0;
|
||||
|
||||
AVX_Data betta1_4;
|
||||
betta1_4.data = SIMD_SET(_betta1);
|
||||
|
@ -212,24 +153,19 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
|
|||
size_t copy_size = TILE;
|
||||
if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
|
||||
size_t offset = copy_size + t;
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
|
||||
#endif
|
||||
#pragma omp parallel for
|
||||
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
|
||||
AVX_Data grad_4[span];
|
||||
simd_load<span>(grad_4, grads + (i >> rshft), half_precision);
|
||||
simd_load<span>(grad_4, grads + i);
|
||||
|
||||
AVX_Data momentum_4[span];
|
||||
simd_load<span>(momentum_4, _exp_avg + i, false);
|
||||
simd_load<span>(momentum_4, _exp_avg + i);
|
||||
|
||||
AVX_Data variance_4[span];
|
||||
simd_load<span>(variance_4, _exp_avg_sq + i, false);
|
||||
simd_load<span>(variance_4, _exp_avg_sq + i);
|
||||
|
||||
AVX_Data param_4[span];
|
||||
simd_load<span>(param_4, _params + (i >> rshft), half_precision);
|
||||
simd_load<span>(param_4, _params + i);
|
||||
|
||||
if (_weight_decay > 0 && !_adamw_mode) {
|
||||
simd_fma<span>(grad_4, param_4, weight_decay4, grad_4);
|
||||
|
@ -250,39 +186,10 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
|
|||
|
||||
simd_fma<span>(param_4, grad_4, step_size_4, param_4);
|
||||
|
||||
simd_store<span>(_params + (i >> rshft), param_4, half_precision);
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
|
||||
}
|
||||
#endif
|
||||
simd_store<span>(_exp_avg + i, momentum_4, false);
|
||||
simd_store<span>(_exp_avg_sq + i, variance_4, false);
|
||||
simd_store<span>(_params + i, param_4);
|
||||
simd_store<span>(_exp_avg + i, momentum_4);
|
||||
simd_store<span>(_exp_avg_sq + i, variance_4);
|
||||
}
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if (dev_params) {
|
||||
if (half_precision)
|
||||
launch_param_update_half(
|
||||
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
|
||||
else
|
||||
launch_param_update(
|
||||
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
|
||||
if (half_precision) memcpy_size /= 2;
|
||||
aclrtMemcpy(dev_params + t,
|
||||
memcpy_size,
|
||||
_doubled_buffer[_buf_index],
|
||||
memcpy_size,
|
||||
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
*rounded_size = new_rounded_size;
|
||||
}
|
||||
|
@ -310,18 +217,4 @@ int ds_adam_step(int optimizer_id,
|
|||
torch::Tensor& exp_avg,
|
||||
torch::Tensor& exp_avg_sq);
|
||||
|
||||
int ds_adam_step_plus_copy(int optimizer_id,
|
||||
size_t step,
|
||||
float lr,
|
||||
float beta1,
|
||||
float beta2,
|
||||
float epsilon,
|
||||
float weight_decay,
|
||||
bool bias_correction,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg,
|
||||
torch::Tensor& exp_avg_sq,
|
||||
torch::Tensor& gpu_params);
|
||||
|
||||
int destroy_adam_optimizer(int optimizer_id);
|
||||
|
|
|
@ -13,28 +13,12 @@
|
|||
#include <cassert>
|
||||
#include "simd.h"
|
||||
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include "cuda.h"
|
||||
#include "custom_cuda_layers.h"
|
||||
typedef __half ds_half_precision_t;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
#include "acl/acl.h"
|
||||
#include "torch_npu/csrc/core/npu/NPUStream.h"
|
||||
typedef c10::Half ds_half_precision_t;
|
||||
#else
|
||||
#include <cmath>
|
||||
typedef unsigned short ds_half_precision_t;
|
||||
#endif
|
||||
|
||||
#define STEP(SPAN) \
|
||||
void Step_##SPAN(float* _params, \
|
||||
float* grads, \
|
||||
float* _exp_avg, \
|
||||
size_t _param_size, \
|
||||
ds_half_precision_t* dev_param = nullptr, \
|
||||
bool half_precision = false);
|
||||
#define STEP(SPAN) \
|
||||
template <typename ds_params_percision_t, typename ds_state_precision_t> \
|
||||
void Step_##SPAN(ds_params_percision_t* _params, \
|
||||
ds_params_percision_t* grads, \
|
||||
ds_state_precision_t* _exp_avg, \
|
||||
size_t _param_size);
|
||||
|
||||
class Lion_Optimizer {
|
||||
public:
|
||||
|
@ -44,55 +28,21 @@ public:
|
|||
float weight_decay = 0)
|
||||
: _alpha(alpha), _betta1(betta1), _betta2(betta2), _weight_decay(weight_decay), _step(0)
|
||||
{
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
|
||||
cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
|
||||
|
||||
_streams[0] = TrainingContext::Instance().GetCurrentStream();
|
||||
_streams[1] = TrainingContext::Instance().GetNewStream();
|
||||
_buf_index = false;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
|
||||
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
|
||||
|
||||
_buf_index = false;
|
||||
#endif
|
||||
}
|
||||
~Lion_Optimizer()
|
||||
{
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
cudaFreeHost(_doubled_buffer[0]);
|
||||
cudaFreeHost(_doubled_buffer[1]);
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
aclrtFreeHost(_doubled_buffer[0]);
|
||||
aclrtFreeHost(_doubled_buffer[1]);
|
||||
#endif
|
||||
}
|
||||
~Lion_Optimizer() {}
|
||||
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
template <int span>
|
||||
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void Step_AVX(size_t* rounded_size,
|
||||
float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
size_t param_size,
|
||||
ds_half_precision_t* dev_param = nullptr,
|
||||
bool half_precision = false);
|
||||
ds_params_percision_t* _params,
|
||||
ds_params_percision_t* grads,
|
||||
ds_state_precision_t* _exp_avg,
|
||||
size_t param_size);
|
||||
#endif
|
||||
STEP(1)
|
||||
STEP(4)
|
||||
STEP(8)
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
inline void SynchronizeStreams()
|
||||
{
|
||||
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
inline void SynchronizeStreams()
|
||||
{
|
||||
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
|
||||
}
|
||||
#endif
|
||||
|
||||
inline void IncrementStep(size_t step, float beta1, float beta2)
|
||||
{
|
||||
_step++;
|
||||
|
@ -114,31 +64,23 @@ private:
|
|||
float _betta2;
|
||||
float _weight_decay;
|
||||
size_t _step;
|
||||
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
float* _doubled_buffer[2];
|
||||
cudaStream_t _streams[2];
|
||||
bool _buf_index;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
float* _doubled_buffer[2];
|
||||
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
|
||||
c10_npu::getNPUStreamFromPool()};
|
||||
bool _buf_index;
|
||||
#endif
|
||||
};
|
||||
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
template <int span>
|
||||
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void Lion_Optimizer::Step_AVX(size_t* rounded_size,
|
||||
float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
ds_params_percision_t* _params,
|
||||
ds_params_percision_t* grads,
|
||||
ds_state_precision_t* _exp_avg,
|
||||
size_t _param_size)
|
||||
{
|
||||
#if !defined(__AVX512__)
|
||||
if (std::is_same_v<ds_params_percision_t, c10::BFloat16> ||
|
||||
std::is_same_v<ds_state_precision_t, c10::BFloat16>) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
size_t new_rounded_size = 0;
|
||||
int rshft = half_precision ? 1 : 0;
|
||||
|
||||
constexpr float neg1 = -1.0f;
|
||||
AVX_Data neg1_4;
|
||||
|
@ -169,21 +111,17 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
|
|||
size_t copy_size = TILE;
|
||||
if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
|
||||
size_t offset = copy_size + t;
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
|
||||
#endif
|
||||
|
||||
#pragma omp parallel for
|
||||
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
|
||||
AVX_Data grad_4[span];
|
||||
simd_load<span>(grad_4, grads + (i >> rshft), half_precision);
|
||||
simd_load<span>(grad_4, grads + i);
|
||||
|
||||
AVX_Data momentum_4[span];
|
||||
simd_load<span>(momentum_4, _exp_avg + i, false);
|
||||
simd_load<span>(momentum_4, _exp_avg + i);
|
||||
|
||||
AVX_Data param_4[span];
|
||||
simd_load<span>(param_4, _params + (i >> rshft), half_precision);
|
||||
simd_load<span>(param_4, _params + i);
|
||||
|
||||
AVX_Data tmp_4[span];
|
||||
|
||||
|
@ -201,38 +139,9 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
|
|||
simd_mul<span>(momentum_4, momentum_4, betta2_4);
|
||||
simd_fma<span>(momentum_4, grad_4, betta2_minus1_4, momentum_4);
|
||||
|
||||
simd_store<span>(_params + (i >> rshft), param_4, half_precision);
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
|
||||
}
|
||||
#endif
|
||||
simd_store<span>(_exp_avg + i, momentum_4, false);
|
||||
simd_store<span>(_params + i, param_4);
|
||||
simd_store<span>(_exp_avg + i, momentum_4);
|
||||
}
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if (dev_params) {
|
||||
if (half_precision)
|
||||
launch_param_update_half(
|
||||
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
|
||||
else
|
||||
launch_param_update(
|
||||
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
|
||||
if (half_precision) memcpy_size /= 2;
|
||||
aclrtMemcpy(dev_params + t,
|
||||
memcpy_size,
|
||||
_doubled_buffer[_buf_index],
|
||||
memcpy_size,
|
||||
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
*rounded_size = new_rounded_size;
|
||||
}
|
||||
|
@ -255,15 +164,4 @@ int ds_lion_step(int optimizer_id,
|
|||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg);
|
||||
|
||||
int ds_lion_step_plus_copy(int optimizer_id,
|
||||
size_t step,
|
||||
float lr,
|
||||
float beta1,
|
||||
float beta2,
|
||||
float weight_decay,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg,
|
||||
torch::Tensor& gpu_params);
|
||||
|
||||
int destroy_lion_optimizer(int optimizer_id);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include <rocblas/rocblas.h>
|
||||
#endif
|
||||
#include <stdio.h>
|
||||
#include <torch/version.h>
|
||||
|
||||
int cublas_gemm_ex(cublasHandle_t handle,
|
||||
cublasOperation_t transa,
|
||||
|
@ -29,7 +30,9 @@ int cublas_gemm_ex(cublasHandle_t handle,
|
|||
const float* A,
|
||||
const float* B,
|
||||
float* C,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// TODO HIP: Remove backward compatibility for torch<=2.0 in future
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
|
||||
#else
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
|
||||
|
@ -46,7 +49,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
|
|||
const __half* A,
|
||||
const __half* B,
|
||||
__half* C,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
|
||||
#else
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -67,7 +71,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
|
|||
int stride_B,
|
||||
int stride_C,
|
||||
int batch,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
|
||||
#else
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
|
||||
|
@ -88,7 +93,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
|
|||
int stride_B,
|
||||
int stride_C,
|
||||
int batch,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
|
||||
#else
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
|
|
@ -272,9 +272,6 @@ void launch_fuse_transpose_bias_kernel(const T* inp,
|
|||
int cols,
|
||||
cudaStream_t stream);
|
||||
|
||||
void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream);
|
||||
void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream);
|
||||
|
||||
void launch_token_sort(int32_t* indices,
|
||||
int layers,
|
||||
int batch_size,
|
||||
|
|
|
@ -23,7 +23,7 @@ used throughout the codebase.
|
|||
#ifdef __HIP_PLATFORM_AMD__
|
||||
|
||||
// constexpr variant of warpSize for templating
|
||||
constexpr int hw_warp_size = 64;
|
||||
constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE;
|
||||
#define HALF_PRECISION_AVAILABLE = 1
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
|
|
@ -48,7 +48,9 @@ public:
|
|||
weights,
|
||||
input_ptr,
|
||||
out,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// TODO HIP: Remove backward compatibility for torch<=2.0 in future
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo(config_.gemm_algos[0]));
|
||||
#else
|
||||
cublasGemmAlgo_t(config_.gemm_algos[0]));
|
||||
|
@ -77,7 +79,8 @@ public:
|
|||
input_ptr,
|
||||
out_grad,
|
||||
weights_grad,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo(config_.gemm_algos[1]));
|
||||
#else
|
||||
cublasGemmAlgo_t(config_.gemm_algos[1]));
|
||||
|
@ -94,7 +97,8 @@ public:
|
|||
weights,
|
||||
out_grad,
|
||||
inp_grad_out,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo(config_.gemm_algos[2]));
|
||||
#else
|
||||
cublasGemmAlgo_t(config_.gemm_algos[2]));
|
||||
|
|
|
@ -67,7 +67,9 @@ public:
|
|||
B,
|
||||
A,
|
||||
C,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// TODO HIP: Remove backward compatibility for torch<=2.0 in future
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
static_cast<rocblas_gemm_algo>(algo));
|
||||
#else
|
||||
static_cast<cublasGemmAlgo_t>(algo));
|
||||
|
@ -86,7 +88,8 @@ public:
|
|||
A,
|
||||
C,
|
||||
B,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
static_cast<rocblas_gemm_algo>(algo));
|
||||
#else
|
||||
static_cast<cublasGemmAlgo_t>(algo));
|
||||
|
@ -105,7 +108,8 @@ public:
|
|||
B,
|
||||
C,
|
||||
A,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
static_cast<rocblas_gemm_algo>(algo));
|
||||
#else
|
||||
static_cast<cublasGemmAlgo_t>(algo));
|
||||
|
@ -121,8 +125,11 @@ public:
|
|||
float fast_latency = (std::numeric_limits<float>::max)();
|
||||
int fast_algo = 0;
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard;
|
||||
#elif defined(__HIP_PLATFORM_AMD__)
|
||||
for (int algo = (int)HIPBLAS_GEMM_DEFAULT; algo <= (int)HIPBLAS_GEMM_DEFAULT;
|
||||
#else
|
||||
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
|
||||
algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
|
||||
|
@ -211,7 +218,8 @@ public:
|
|||
stride_b,
|
||||
stride_c,
|
||||
bsz,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
static_cast<rocblas_gemm_algo>(algo));
|
||||
#else
|
||||
static_cast<cublasGemmAlgo_t>(algo));
|
||||
|
@ -245,7 +253,8 @@ public:
|
|||
stride_b,
|
||||
stride_c,
|
||||
bsz,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
static_cast<rocblas_gemm_algo>(algo));
|
||||
#else
|
||||
static_cast<cublasGemmAlgo_t>(algo));
|
||||
|
@ -276,7 +285,8 @@ public:
|
|||
stride_b,
|
||||
stride_c,
|
||||
bsz,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
static_cast<rocblas_gemm_algo>(algo));
|
||||
#else
|
||||
static_cast<cublasGemmAlgo_t>(algo));
|
||||
|
@ -292,11 +302,17 @@ public:
|
|||
float fast_latency = (std::numeric_limits<float>::max)();
|
||||
int fast_algo = 0;
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
for (int algo = (int)rocblas_gemm_algo_standard; algo <= (int)rocblas_gemm_algo_standard;
|
||||
#else
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
|
||||
algo <= (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
|
||||
#else
|
||||
for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
|
||||
algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
|
||||
#endif
|
||||
#endif
|
||||
algo++) {
|
||||
int warm_up = 5;
|
||||
|
|
|
@ -12,6 +12,20 @@
|
|||
|
||||
#define TILE (128 * 1024 * 1024)
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
#include <immintrin.h>
|
||||
|
||||
template <typename T>
|
||||
inline T readAs(const void* src)
|
||||
{
|
||||
T res;
|
||||
std::memcpy(&res, src, sizeof(T));
|
||||
return res;
|
||||
}
|
||||
template <typename T>
|
||||
inline void writeAs(void* dst, const T& val)
|
||||
{
|
||||
std::memcpy(dst, &val, sizeof(T));
|
||||
}
|
||||
|
||||
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
|
||||
|
||||
|
@ -30,11 +44,52 @@
|
|||
#define SIMD_XOR(x, y) _mm512_xor_ps(x, y)
|
||||
#define SIMD_WIDTH 16
|
||||
|
||||
#define SIMD_LOAD2(x, h) \
|
||||
((h) ? _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x))) : _mm512_loadu_ps(x))
|
||||
#define SIMD_STORE2(x, d, h) \
|
||||
((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
|
||||
: _mm512_storeu_ps(x, d))
|
||||
static __m512 load_16_bf16_as_f32(const void* data)
|
||||
{
|
||||
__m256i a = readAs<__m256i>(data); // use memcpy to avoid aliasing
|
||||
__m512i b = _mm512_cvtepu16_epi32(a); // convert 8 u16 to 8 u32
|
||||
__m512i c = _mm512_slli_epi32(b, 16); // logical shift left of all u32 by
|
||||
// 16 bits (representing bf16->f32)
|
||||
return readAs<__m512>(&c); // use memcpy to avoid aliasing
|
||||
}
|
||||
|
||||
static void store_16_f32_as_bf16_nearest(__m512 v, void* data)
|
||||
{
|
||||
__m512i u32 = readAs<__m512i>(&v);
|
||||
|
||||
// flow assuming non-nan:
|
||||
|
||||
// uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
|
||||
__m512i b = _mm512_srli_epi32(u32, 16);
|
||||
__m512i lsb_mask = _mm512_set1_epi32(0x00000001);
|
||||
__m512i c = _mm512_and_si512(b, lsb_mask);
|
||||
__m512i bias_constant = _mm512_set1_epi32(0x00007fff);
|
||||
__m512i rounding_bias = _mm512_add_epi32(c, bias_constant);
|
||||
|
||||
// uint16_t res = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
|
||||
__m512i d = _mm512_add_epi32(u32, rounding_bias);
|
||||
__m512i e = _mm512_srli_epi32(d, 16);
|
||||
__m256i non_nan_res = _mm512_cvtusepi32_epi16(e);
|
||||
|
||||
// handle nan (exp is all 1s and mantissa != 0)
|
||||
// if ((x & 0x7fffffffU) > 0x7f800000U)
|
||||
__m512i mask_out_sign = _mm512_set1_epi32(0x7fffffff);
|
||||
__m512i non_sign_bits = _mm512_and_si512(u32, mask_out_sign);
|
||||
__m512i nan_threshold = _mm512_set1_epi32(0x7f800000);
|
||||
__mmask16 nan_mask = _mm512_cmp_epi32_mask(non_sign_bits, nan_threshold, _MM_CMPINT_GT);
|
||||
|
||||
// mix in results with nans as needed
|
||||
__m256i nans = _mm256_set1_epi16(0x7fc0);
|
||||
__m256i res = _mm256_mask_mov_epi16(non_nan_res, nan_mask, nans);
|
||||
|
||||
writeAs(data, res);
|
||||
}
|
||||
#define SIMD_LOAD_BF16(x) load_16_bf16_as_f32(x)
|
||||
#define SIMD_STORE_BF16(x, d) store_16_f32_as_bf16_nearest(d, x)
|
||||
|
||||
#define SIMD_LOAD_FP16(x) _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x)))
|
||||
#define SIMD_STORE_FP16(x, d) \
|
||||
_mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
|
||||
|
||||
#define INTV __m256i
|
||||
#elif defined(__AVX256__)
|
||||
|
@ -52,11 +107,11 @@
|
|||
#define SIMD_XOR(x, y) _mm256_xor_ps(x, y)
|
||||
#define SIMD_WIDTH 8
|
||||
|
||||
#define SIMD_LOAD2(x, h) \
|
||||
((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x))) : _mm256_loadu_ps(x))
|
||||
#define SIMD_STORE2(x, d, h) \
|
||||
((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
|
||||
: _mm256_storeu_ps(x, d))
|
||||
#define SIMD_LOAD_BF16(x) static_assert(false && "AVX256 does not support BFloat16")
|
||||
#define SIMD_STORE_BF16(x, d) static_assert(false && "AVX256 does not support BFloat16")
|
||||
#define SIMD_LOAD_FP16(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x))
|
||||
#define SIMD_STORE_FP16(x, d) \
|
||||
_mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
|
||||
|
||||
#define INTV __m128i
|
||||
#endif
|
||||
|
@ -70,20 +125,66 @@ union AVX_Data {
|
|||
// float data_f[16];
|
||||
};
|
||||
|
||||
template <int span>
|
||||
inline void simd_store(float* dst, AVX_Data* src, bool half_precision)
|
||||
template <int span, typename T>
|
||||
inline typename std::enable_if_t<std::is_same_v<T, c10::Half>, void> simd_store(T* dst,
|
||||
AVX_Data* src)
|
||||
{
|
||||
size_t width = (half_precision ? SIMD_WIDTH / 2 : SIMD_WIDTH);
|
||||
size_t width = SIMD_WIDTH;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { SIMD_STORE2(dst + width * i, src[i].data, half_precision); }
|
||||
for (size_t i = 0; i < span; ++i) { SIMD_STORE_FP16((float*)(dst + width * i), src[i].data); }
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_load(AVX_Data* dst, float* src, bool half_precision)
|
||||
|
||||
template <int span, typename T>
|
||||
inline typename std::enable_if_t<std::is_same_v<T, c10::BFloat16>, void> simd_store(T* dst,
|
||||
AVX_Data* src)
|
||||
{
|
||||
size_t width = (half_precision ? SIMD_WIDTH / 2 : SIMD_WIDTH);
|
||||
#ifdef __AVX512__
|
||||
size_t width = SIMD_WIDTH;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD2(src + width * i, half_precision); }
|
||||
for (size_t i = 0; i < span; ++i) { SIMD_STORE_BF16((float*)(dst + width * i), src[i].data); }
|
||||
#else
|
||||
throw std::runtime_error("AVX512 required for BFloat16");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int span, typename T>
|
||||
inline typename std::enable_if_t<std::is_same_v<T, float>, void> simd_store(T* dst, AVX_Data* src)
|
||||
{
|
||||
size_t width = SIMD_WIDTH;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { SIMD_STORE(dst + width * i, src[i].data); }
|
||||
}
|
||||
|
||||
template <int span, typename T>
|
||||
inline typename std::enable_if_t<std::is_same_v<T, c10::Half>, void> simd_load(AVX_Data* dst,
|
||||
T* src)
|
||||
{
|
||||
size_t width = SIMD_WIDTH;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD_FP16((float*)(src + width * i)); }
|
||||
}
|
||||
|
||||
template <int span, typename T>
|
||||
inline typename std::enable_if_t<std::is_same_v<T, c10::BFloat16>, void> simd_load(AVX_Data* dst,
|
||||
T* src)
|
||||
{
|
||||
#ifdef __AVX512__
|
||||
size_t width = SIMD_WIDTH;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD_BF16((float*)(src + width * i)); }
|
||||
#else
|
||||
throw std::runtime_error("AVX512 required for BFloat16");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int span, typename T>
|
||||
inline typename std::enable_if_t<std::is_same_v<T, float>, void> simd_load(AVX_Data* dst, T* src)
|
||||
{
|
||||
size_t width = SIMD_WIDTH;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD(src + width * i); }
|
||||
}
|
||||
|
||||
template <int span>
|
||||
inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a)
|
||||
{
|
||||
|
|
|
@ -77,7 +77,9 @@ public:
|
|||
stride_b,
|
||||
stride_c,
|
||||
bsz,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// TODO HIP: Remove backward compatibility for torch<=2.0 in future
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo(_config.gemm_algos[0]));
|
||||
#else
|
||||
cublasGemmAlgo_t(_config.gemm_algos[0]));
|
||||
|
@ -105,7 +107,8 @@ public:
|
|||
stride_b,
|
||||
stride_c,
|
||||
_config.batch_size,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo(_config.gemm_algos[0]));
|
||||
#else
|
||||
cublasGemmAlgo_t(_config.gemm_algos[0]));
|
||||
|
@ -149,7 +152,8 @@ public:
|
|||
stride_b,
|
||||
stride_c,
|
||||
bsz,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo(_config.gemm_algos[1]));
|
||||
#else
|
||||
cublasGemmAlgo_t(_config.gemm_algos[1]));
|
||||
|
@ -178,7 +182,8 @@ public:
|
|||
stride_b,
|
||||
stride_c,
|
||||
bsz,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo(_config.gemm_algos[2]));
|
||||
#else
|
||||
cublasGemmAlgo_t(_config.gemm_algos[2]));
|
||||
|
|
|
@ -8,9 +8,6 @@
|
|||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("lion_update", &ds_lion_step, "DeepSpeed CPU Lion update (C++)");
|
||||
m.def("lion_update_copy",
|
||||
&ds_lion_step_plus_copy,
|
||||
"DeepSpeed CPU Lion update and param copy (C++)");
|
||||
m.def("create_lion", &create_lion_optimizer, "DeepSpeed CPU Lion (C++)");
|
||||
m.def("destroy_lion", &destroy_lion_optimizer, "DeepSpeed CPU Lion destroy (C++)");
|
||||
}
|
||||
|
|
|
@ -6,34 +6,28 @@
|
|||
#include <torch/extension.h>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include "cpu_lion.h"
|
||||
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#include <cuda_runtime_api.h>
|
||||
#include "cublas_v2.h"
|
||||
#include "cuda.h"
|
||||
#include "curand.h"
|
||||
#include "custom_cuda_layers.h"
|
||||
#endif
|
||||
|
||||
using namespace std::string_literals;
|
||||
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
|
||||
|
||||
// C++ interface
|
||||
|
||||
void Lion_Optimizer::Step_1(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void Lion_Optimizer::Step_1(ds_params_percision_t* _params,
|
||||
ds_params_percision_t* grads,
|
||||
ds_state_precision_t* _exp_avg,
|
||||
size_t _param_size)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision);
|
||||
Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _param_size);
|
||||
#endif
|
||||
if (_param_size > rounded_size) {
|
||||
float betta1_minus1 = 1 - _betta1;
|
||||
|
@ -41,26 +35,15 @@ void Lion_Optimizer::Step_1(float* _params,
|
|||
|
||||
float alpha = _alpha;
|
||||
float after_decay = 1 - alpha * _weight_decay;
|
||||
ds_half_precision_t* grads_cast_h;
|
||||
ds_half_precision_t* params_cast_h;
|
||||
if (half_precision) {
|
||||
grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
|
||||
params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
|
||||
}
|
||||
|
||||
for (size_t t = rounded_size; t < _param_size; t += TILE) {
|
||||
size_t copy_size = TILE;
|
||||
if ((t + TILE) > _param_size) copy_size = _param_size - t;
|
||||
size_t offset = copy_size + t;
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
|
||||
#endif
|
||||
#pragma omp parallel for
|
||||
for (size_t k = t; k < offset; k++) {
|
||||
float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
|
||||
float param = half_precision ? (float)params_cast_h[k] : _params[k];
|
||||
float grad = (float)grads[k];
|
||||
float param = (float)_params[k];
|
||||
float momentum = _exp_avg[k];
|
||||
float tmp = momentum * _betta1;
|
||||
tmp = grad * betta1_minus1 + tmp;
|
||||
|
@ -74,56 +57,28 @@ void Lion_Optimizer::Step_1(float* _params,
|
|||
}
|
||||
momentum = momentum * _betta2;
|
||||
momentum = grad * betta2_minus1 + momentum;
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
|
||||
#endif
|
||||
if (half_precision)
|
||||
params_cast_h[k] = (ds_half_precision_t)param;
|
||||
else
|
||||
_params[k] = param;
|
||||
_params[k] = param;
|
||||
_exp_avg[k] = momentum;
|
||||
}
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if (dev_params) {
|
||||
launch_param_update(
|
||||
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
|
||||
aclrtMemcpy(dev_params + t,
|
||||
memcpy_size,
|
||||
_doubled_buffer[_buf_index],
|
||||
memcpy_size,
|
||||
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Lion_Optimizer::Step_4(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void Lion_Optimizer::Step_4(ds_params_percision_t* _params,
|
||||
ds_params_percision_t* grads,
|
||||
ds_state_precision_t* _exp_avg,
|
||||
size_t _param_size)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision);
|
||||
Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _param_size);
|
||||
#endif
|
||||
if (_param_size > rounded_size)
|
||||
Step_1((_params + rounded_size),
|
||||
(grads + rounded_size),
|
||||
(_exp_avg + rounded_size),
|
||||
(_param_size - rounded_size),
|
||||
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
|
||||
half_precision);
|
||||
(_param_size - rounded_size));
|
||||
}
|
||||
|
||||
int create_lion_optimizer(int optimizer_id,
|
||||
|
@ -162,24 +117,76 @@ int create_lion_optimizer(int optimizer_id,
|
|||
return 0;
|
||||
}
|
||||
|
||||
void Lion_Optimizer::Step_8(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void Lion_Optimizer::Step_8(ds_params_percision_t* _params,
|
||||
ds_params_percision_t* grads,
|
||||
ds_state_precision_t* _exp_avg,
|
||||
size_t _param_size)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision);
|
||||
Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _param_size);
|
||||
#endif
|
||||
if (_param_size > rounded_size)
|
||||
Step_4((_params + rounded_size),
|
||||
(grads + rounded_size),
|
||||
(_exp_avg + rounded_size),
|
||||
(_param_size - rounded_size),
|
||||
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
|
||||
half_precision);
|
||||
(_param_size - rounded_size));
|
||||
}
|
||||
|
||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
||||
void step_invoker(std::shared_ptr<Lion_Optimizer> opt,
|
||||
void* _params,
|
||||
void* grads,
|
||||
void* _exp_avg,
|
||||
size_t _param_size)
|
||||
{
|
||||
opt->Step_8((ds_params_percision_t*)(_params),
|
||||
(ds_params_percision_t*)(grads),
|
||||
(ds_state_precision_t*)(_exp_avg),
|
||||
_param_size);
|
||||
}
|
||||
|
||||
std::map<std::tuple<c10::ScalarType, c10::ScalarType>,
|
||||
std::function<void(std::shared_ptr<Lion_Optimizer>, void*, void*, void*, size_t)>>
|
||||
invokers;
|
||||
|
||||
// Fill map with template functions for each type
|
||||
template <class ds_params_percision_t, class ds_state_precision_t>
|
||||
void create_invoker()
|
||||
{
|
||||
invokers[std::tuple(c10::CppTypeToScalarType<ds_params_percision_t>(),
|
||||
c10::CppTypeToScalarType<ds_state_precision_t>())] =
|
||||
step_invoker<ds_params_percision_t, ds_state_precision_t>;
|
||||
}
|
||||
struct InvokerInitializer {
|
||||
InvokerInitializer()
|
||||
{
|
||||
create_invoker<c10::Half, float>();
|
||||
create_invoker<c10::Half, c10::Half>();
|
||||
create_invoker<c10::BFloat16, float>();
|
||||
create_invoker<c10::BFloat16, c10::BFloat16>();
|
||||
create_invoker<float, float>();
|
||||
}
|
||||
} _invoker_initializer;
|
||||
|
||||
void invoke(std::shared_ptr<Lion_Optimizer> opt,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg,
|
||||
size_t param_size)
|
||||
{
|
||||
c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype());
|
||||
c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg.options().dtype());
|
||||
|
||||
auto it = invokers.find(std::tuple(params_type, state_type));
|
||||
if (it == invokers.end()) {
|
||||
throw std::runtime_error("Lion optimizer with param type "s + c10::toString(params_type) +
|
||||
" and state type "s + c10::toString(state_type) +
|
||||
" is not supported on current hardware"s);
|
||||
}
|
||||
|
||||
it->second(opt, params.data_ptr(), grads.data_ptr(), exp_avg.data_ptr(), param_size);
|
||||
}
|
||||
|
||||
int ds_lion_step(int optimizer_id,
|
||||
|
@ -196,67 +203,13 @@ int ds_lion_step(int optimizer_id,
|
|||
auto grads_c = grads.contiguous();
|
||||
auto exp_avg_c = exp_avg.contiguous();
|
||||
|
||||
// assert(params.options().dtype() == grads.options().dtype());
|
||||
|
||||
float* params_ptr = (float*)params_c.data_ptr();
|
||||
float* grads_ptr = (float*)grads_c.data_ptr();
|
||||
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
|
||||
|
||||
std::shared_ptr<Lion_Optimizer> opt =
|
||||
std::static_pointer_cast<Lion_Optimizer>(s_optimizers[optimizer_id]);
|
||||
opt->IncrementStep(step, beta1, beta2);
|
||||
opt->update_state(lr, weight_decay);
|
||||
|
||||
opt->Step_8(params_ptr,
|
||||
grads_ptr,
|
||||
exp_avg_ptr,
|
||||
params_c.numel(),
|
||||
nullptr,
|
||||
(params.options().dtype() == at::kHalf));
|
||||
invoke(opt, params_c, grads_c, exp_avg_c, params_c.numel());
|
||||
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
opt->SynchronizeStreams();
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ds_lion_step_plus_copy(int optimizer_id,
|
||||
size_t step,
|
||||
float lr,
|
||||
float beta1,
|
||||
float beta2,
|
||||
float weight_decay,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg,
|
||||
torch::Tensor& gpu_params)
|
||||
{
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
auto params_c = params.contiguous();
|
||||
auto gpu_params_c = gpu_params.contiguous();
|
||||
auto exp_avg_c = exp_avg.contiguous();
|
||||
auto grads_c = grads.contiguous();
|
||||
|
||||
float* params_ptr = (float*)params_c.data_ptr();
|
||||
float* grads_ptr = (float*)grads_c.data_ptr();
|
||||
ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
|
||||
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
|
||||
|
||||
std::shared_ptr<Lion_Optimizer> opt =
|
||||
std::static_pointer_cast<Lion_Optimizer>(s_optimizers[optimizer_id]);
|
||||
opt->IncrementStep(step, beta1, beta2);
|
||||
opt->update_state(lr, weight_decay);
|
||||
opt->Step_8(params_ptr,
|
||||
grads_ptr,
|
||||
exp_avg_ptr,
|
||||
params_c.numel(),
|
||||
gpu_params_ptr,
|
||||
(params.options().dtype() == at::kHalf));
|
||||
|
||||
opt->SynchronizeStreams();
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -241,7 +241,7 @@ std::vector<at::Tensor> quantized_reduction(at::Tensor& input_vals,
|
|||
.device(at::kCUDA)
|
||||
.requires_grad(false);
|
||||
|
||||
std::vector<long int> sz(input_vals.sizes().begin(), input_vals.sizes().end());
|
||||
std::vector<int64_t> sz(input_vals.sizes().begin(), input_vals.sizes().end());
|
||||
sz[sz.size() - 1] = sz.back() / devices_per_node; // num of GPU per nodes
|
||||
const int elems_per_in_tensor = at::numel(input_vals) / devices_per_node;
|
||||
auto output = torch::empty(sz, output_options);
|
||||
|
|
|
@ -16,7 +16,7 @@ constexpr int mem_vals = granularity / sizeof(int32_t);
|
|||
constexpr int max_buffer_size = (threads + 1) * mem_vals;
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
constexpr int warp_size = 64;
|
||||
constexpr int warp_size = ROCM_WAVEFRONT_SIZE;
|
||||
#else
|
||||
constexpr int warp_size = 32;
|
||||
#endif
|
||||
|
|
|
@ -5,7 +5,9 @@
|
|||
|
||||
#include "cublas_wrappers.h"
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// TODO HIP: Remove backward compatibility for torch<=2.0 in future
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
int cublas_gemm_ex(rocblas_handle handle,
|
||||
rocblas_operation transa,
|
||||
rocblas_operation transb,
|
||||
|
@ -33,7 +35,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
|
|||
cublasGemmAlgo_t algo)
|
||||
#endif
|
||||
{
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_status status = rocblas_gemm_ex(handle,
|
||||
transa,
|
||||
transb,
|
||||
|
@ -67,20 +70,39 @@ int cublas_gemm_ex(cublasHandle_t handle,
|
|||
k,
|
||||
(const void*)alpha,
|
||||
(const void*)A,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
(transa == CUBLAS_OP_N) ? m : k,
|
||||
(const void*)B,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
(transb == CUBLAS_OP_N) ? k : n,
|
||||
(const void*)beta,
|
||||
C,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
m,
|
||||
#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
|
||||
HIPBLAS_COMPUTE_32F,
|
||||
#elif defined(__HIP_PLATFORM_AMD__)
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
algo);
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
if (status != rocblas_status_success) {
|
||||
#else
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
|
@ -96,7 +118,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
|
|||
return 0;
|
||||
}
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
int cublas_gemm_ex(rocblas_handle handle,
|
||||
rocblas_operation transa,
|
||||
rocblas_operation transb,
|
||||
|
@ -124,7 +147,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
|
|||
cublasGemmAlgo_t algo)
|
||||
#endif
|
||||
{
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_status status = rocblas_gemm_ex(handle,
|
||||
transa,
|
||||
transb,
|
||||
|
@ -158,20 +182,39 @@ int cublas_gemm_ex(cublasHandle_t handle,
|
|||
k,
|
||||
(const void*)alpha,
|
||||
(const void*)A,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_16F,
|
||||
#else
|
||||
CUDA_R_16F,
|
||||
#endif
|
||||
(transa == CUBLAS_OP_N) ? m : k,
|
||||
(const void*)B,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_16F,
|
||||
#else
|
||||
CUDA_R_16F,
|
||||
#endif
|
||||
(transb == CUBLAS_OP_N) ? k : n,
|
||||
(const void*)beta,
|
||||
(void*)C,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_16F,
|
||||
#else
|
||||
CUDA_R_16F,
|
||||
#endif
|
||||
m,
|
||||
#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
|
||||
HIPBLAS_COMPUTE_32F,
|
||||
#elif defined(__HIP_PLATFORM_AMD__)
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
algo);
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
if (status != rocblas_status_success) {
|
||||
#else
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
|
@ -187,7 +230,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
|
|||
return 0;
|
||||
}
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
int cublas_strided_batched_gemm(rocblas_handle handle,
|
||||
int m,
|
||||
int n,
|
||||
|
@ -223,7 +267,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
|
|||
cublasGemmAlgo_t algo)
|
||||
#endif
|
||||
{
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_status status =
|
||||
rocblas_gemm_strided_batched_ex(handle,
|
||||
op_A,
|
||||
|
@ -263,24 +308,43 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
|
|||
k,
|
||||
alpha,
|
||||
A,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
(op_A == CUBLAS_OP_N) ? m : k,
|
||||
stride_A,
|
||||
B,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
(op_B == CUBLAS_OP_N) ? k : n,
|
||||
stride_B,
|
||||
beta,
|
||||
C,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
m,
|
||||
stride_C,
|
||||
batch,
|
||||
#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
|
||||
HIPBLAS_COMPUTE_32F,
|
||||
#elif defined(__HIP_PLATFORM_AMD__)
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
algo);
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
if (status != rocblas_status_success) {
|
||||
#else
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
|
@ -297,7 +361,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
|
|||
return 0;
|
||||
}
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
int cublas_strided_batched_gemm(rocblas_handle handle,
|
||||
int m,
|
||||
int n,
|
||||
|
@ -333,7 +398,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
|
|||
cublasGemmAlgo_t algo)
|
||||
#endif
|
||||
{
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_status status =
|
||||
rocblas_gemm_strided_batched_ex(handle,
|
||||
op_A,
|
||||
|
@ -373,24 +439,43 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
|
|||
k,
|
||||
alpha,
|
||||
A,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_16F,
|
||||
#else
|
||||
CUDA_R_16F,
|
||||
#endif
|
||||
(op_A == CUBLAS_OP_N) ? m : k,
|
||||
stride_A,
|
||||
B,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_16F,
|
||||
#else
|
||||
CUDA_R_16F,
|
||||
#endif
|
||||
(op_B == CUBLAS_OP_N) ? k : n,
|
||||
stride_B,
|
||||
beta,
|
||||
C,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_16F,
|
||||
#else
|
||||
CUDA_R_16F,
|
||||
#endif
|
||||
m,
|
||||
stride_C,
|
||||
batch,
|
||||
#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
|
||||
HIPBLAS_COMPUTE_32F,
|
||||
#elif defined(__HIP_PLATFORM_AMD__)
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
algo);
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
if (status != rocblas_status_success) {
|
||||
#else
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
|
|
|
@ -99,17 +99,9 @@ __global__ void apply_rotary_pos_half(T* mixed_query,
|
|||
rope_theta, \
|
||||
max_out_tokens);
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) and ROCM_WAVEFRONT_SIZE == 64
|
||||
#define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \
|
||||
if (threads_per_head == 4) { \
|
||||
LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \
|
||||
} else if (threads_per_head == 8) { \
|
||||
LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \
|
||||
} else if (threads_per_head == 16) { \
|
||||
LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \
|
||||
} else if (threads_per_head == 32) { \
|
||||
LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \
|
||||
} else if (threads_per_head == 64) { \
|
||||
if (threads_per_head == 64) { \
|
||||
LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \
|
||||
} else { \
|
||||
assert(false); \
|
||||
|
|
|
@ -163,7 +163,9 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
|
|||
(T*)W.data_ptr(),
|
||||
(T*)Q.data_ptr(),
|
||||
(T*)O.data_ptr(),
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// TODO HIP: Remove backward compatibility for torch<=2.0 in future
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -216,7 +218,8 @@ void attention_unfused(at::Tensor& prev_key_cont,
|
|||
seq_len * k,
|
||||
seq_len * soft_len,
|
||||
bsz * heads,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -253,7 +256,8 @@ void attention_unfused(at::Tensor& prev_key_cont,
|
|||
seq_len * soft_len,
|
||||
seq_len * k,
|
||||
bsz * heads,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -388,7 +392,8 @@ void attention_unfused(T* prev_key_cont,
|
|||
seq_len * k,
|
||||
seq_len * soft_len,
|
||||
bsz * heads,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -421,7 +426,8 @@ void attention_unfused(T* prev_key_cont,
|
|||
seq_len * soft_len,
|
||||
seq_len * k,
|
||||
bsz * heads,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -536,22 +542,23 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
|
|||
1);
|
||||
|
||||
if (layer_id == num_layers - 1) InferenceContext::Instance().advance_tokens();
|
||||
auto prev_key = torch::from_blob(workspace + offset,
|
||||
{bsz, heads, all_tokens, k},
|
||||
{hidden_dim * InferenceContext::Instance().GetMaxTokenLength(),
|
||||
k * InferenceContext::Instance().GetMaxTokenLength(),
|
||||
k,
|
||||
1},
|
||||
options);
|
||||
auto prev_key = torch::from_blob(
|
||||
workspace + offset,
|
||||
{bsz, heads, all_tokens, k},
|
||||
{hidden_dim * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
|
||||
k * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
|
||||
k,
|
||||
1},
|
||||
options);
|
||||
|
||||
auto prev_value =
|
||||
torch::from_blob(workspace + offset + value_offset,
|
||||
{bsz, heads, all_tokens, k},
|
||||
{hidden_dim * InferenceContext::Instance().GetMaxTokenLength(),
|
||||
k * InferenceContext::Instance().GetMaxTokenLength(),
|
||||
k,
|
||||
1},
|
||||
options);
|
||||
auto prev_value = torch::from_blob(
|
||||
workspace + offset + value_offset,
|
||||
{bsz, heads, all_tokens, k},
|
||||
{hidden_dim * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
|
||||
k * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
|
||||
k,
|
||||
1},
|
||||
options);
|
||||
|
||||
return {output, prev_key, prev_value};
|
||||
}
|
||||
|
@ -886,7 +893,8 @@ void quantized_gemm(void* output,
|
|||
weight16,
|
||||
(T*)input,
|
||||
(T*)output,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -931,7 +939,8 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
|
|||
(T*)weight.data_ptr(),
|
||||
workspace,
|
||||
(T*)output.data_ptr(),
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -1003,7 +1012,8 @@ std::vector<at::Tensor> ds_rms_qkv(at::Tensor& input,
|
|||
(T*)weight.data_ptr(),
|
||||
(T*)rms_norm.data_ptr(),
|
||||
(T*)output.data_ptr(),
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -1089,7 +1099,8 @@ void quantized_gemm(at::Tensor& output,
|
|||
(T*)weight16.data_ptr(),
|
||||
(T*)input.data_ptr(),
|
||||
(T*)output.data_ptr(),
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -1135,7 +1146,8 @@ at::Tensor ds_linear_layer(at::Tensor& input,
|
|||
(T*)weight.data_ptr(),
|
||||
(T*)input_cont.data_ptr(),
|
||||
(T*)output.data_ptr(),
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -1353,7 +1365,8 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
|
|||
(T*)weight.data_ptr(),
|
||||
(T*)input.data_ptr(),
|
||||
(T*)output.data_ptr(),
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -1439,7 +1452,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
|
|||
(T*)weight.data_ptr(),
|
||||
inp_norm,
|
||||
intermediate,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -1483,7 +1497,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
|
|||
(T*)weight1.data_ptr(),
|
||||
intermediate,
|
||||
(T*)output.data_ptr(),
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -1578,7 +1593,9 @@ std::vector<at::Tensor> ds_rms_mlp_gemm(at::Tensor& input,
|
|||
auto output = at::from_blob(output_ptr, input.sizes(), options);
|
||||
auto inp_norm = at::from_blob(inp_norm_ptr, input.sizes(), options);
|
||||
auto intermediate_gemm =
|
||||
at::from_blob(intermediate_ptr, {input.size(0), input.size(1), mlp_1_out_neurons}, options);
|
||||
at::from_blob(intermediate_ptr,
|
||||
{input.size(0), input.size(1), static_cast<int64_t>(mlp_1_out_neurons)},
|
||||
options);
|
||||
|
||||
auto act_func_type = static_cast<ActivationFuncType>(activation_type);
|
||||
|
||||
|
@ -1617,7 +1634,8 @@ std::vector<at::Tensor> ds_rms_mlp_gemm(at::Tensor& input,
|
|||
(T*)weight_interm.data_ptr(),
|
||||
(T*)inp_norm.data_ptr(),
|
||||
intermediate_ptr,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -1680,7 +1698,8 @@ std::vector<at::Tensor> ds_rms_mlp_gemm(at::Tensor& input,
|
|||
(T*)weight_out.data_ptr(),
|
||||
intermediate_ptr,
|
||||
(T*)output.data_ptr(),
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard,
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP,
|
||||
|
@ -1742,7 +1761,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
|
|||
(T*)weight.data_ptr(),
|
||||
(T*)input.data_ptr(),
|
||||
(T*)intermediate.data_ptr(),
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
@ -1776,7 +1796,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
|
|||
(T*)weight_out.data_ptr(),
|
||||
(T*)intermediate.data_ptr(),
|
||||
(T*)output.data_ptr(),
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_gemm_algo_standard);
|
||||
#else
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
|
|
@ -18,7 +18,9 @@
|
|||
#endif
|
||||
#include <stdio.h>
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// TODO HIP: Remove backward compatibility for torch<=2.0 in future
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
int cublas_gemm_ex(rocblas_handle handle,
|
||||
rocblas_operation transa,
|
||||
rocblas_operation transb,
|
||||
|
@ -49,7 +51,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
|
|||
#endif
|
||||
{
|
||||
const int ldb = (b_stride == -1) ? ((transb == CUBLAS_OP_N) ? k : n) : b_stride;
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_status status = rocblas_gemm_ex(handle,
|
||||
transa,
|
||||
transb,
|
||||
|
@ -83,20 +86,39 @@ int cublas_gemm_ex(cublasHandle_t handle,
|
|||
k,
|
||||
(const void*)alpha,
|
||||
(const void*)A,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
(transa == CUBLAS_OP_N) ? m : k,
|
||||
(const void*)B,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
ldb,
|
||||
(const void*)beta,
|
||||
C,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
m,
|
||||
#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
|
||||
HIPBLAS_COMPUTE_32F,
|
||||
#elif defined(__HIP_PLATFORM_AMD__)
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
algo);
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
if (status != rocblas_status_success) {
|
||||
#else
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
|
@ -113,7 +135,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
int cublas_gemm_ex(rocblas_handle handle,
|
||||
rocblas_operation transa,
|
||||
rocblas_operation transb,
|
||||
|
@ -144,7 +167,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
|
|||
#endif
|
||||
{
|
||||
const int ldb = (b_stride == -1) ? ((transb == CUBLAS_OP_N) ? k : n) : b_stride;
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
constexpr auto rocblas_dtype_16 = std::is_same<T, __half>::value ? rocblas_datatype_f16_r
|
||||
: rocblas_datatype_bf16_r;
|
||||
rocblas_status status = rocblas_gemm_ex(handle,
|
||||
|
@ -171,8 +195,12 @@ int cublas_gemm_ex(cublasHandle_t handle,
|
|||
algo,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
constexpr auto cublas_dtype_16 = std::is_same<T, __half>::value ? HIPBLAS_R_16F : HIPBLAS_R_16B;
|
||||
#else
|
||||
constexpr auto cublas_dtype_16 = std::is_same<T, __half>::value ? CUDA_R_16F : CUDA_R_16BF;
|
||||
#endif
|
||||
cublasStatus_t status = cublasGemmEx(handle,
|
||||
transa,
|
||||
transb,
|
||||
|
@ -190,11 +218,18 @@ int cublas_gemm_ex(cublasHandle_t handle,
|
|||
(void*)C,
|
||||
cublas_dtype_16,
|
||||
m,
|
||||
#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
|
||||
HIPBLAS_COMPUTE_32F,
|
||||
#elif defined(__HIP_PLATFORM_AMD__)
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
algo);
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
if (status != rocblas_status_success) {
|
||||
#else
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
|
@ -210,7 +245,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
|
|||
return 0;
|
||||
}
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
int cublas_strided_batched_gemm(rocblas_handle handle,
|
||||
int m,
|
||||
int n,
|
||||
|
@ -246,7 +282,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
|
|||
cublasGemmAlgo_t algo)
|
||||
#endif
|
||||
{
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_status status =
|
||||
rocblas_gemm_strided_batched_ex(handle,
|
||||
op_A,
|
||||
|
@ -286,24 +323,43 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
|
|||
k,
|
||||
alpha,
|
||||
A,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
(op_A == CUBLAS_OP_N) ? m : k,
|
||||
stride_A,
|
||||
B,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
(op_B == CUBLAS_OP_N) ? k : n,
|
||||
stride_B,
|
||||
beta,
|
||||
C,
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
m,
|
||||
stride_C,
|
||||
batch,
|
||||
#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
|
||||
HIPBLAS_COMPUTE_32F,
|
||||
#elif defined(__HIP_PLATFORM_AMD__)
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
algo);
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
if (status != rocblas_status_success) {
|
||||
#else
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
|
@ -321,7 +377,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
int cublas_strided_batched_gemm(rocblas_handle handle,
|
||||
int m,
|
||||
int n,
|
||||
|
@ -357,7 +414,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
|
|||
cublasGemmAlgo_t algo)
|
||||
#endif
|
||||
{
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
constexpr auto rocblas_dtype_16 = std::is_same<T, __half>::value ? rocblas_datatype_f16_r
|
||||
: rocblas_datatype_bf16_r;
|
||||
rocblas_status status =
|
||||
|
@ -390,8 +448,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
|
|||
algo,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
constexpr auto cublas_dtype_16 = std::is_same<T, __half>::value ? HIPBLAS_R_16F : HIPBLAS_R_16B;
|
||||
#else
|
||||
constexpr auto cublas_dtype_16 = std::is_same<T, __half>::value ? CUDA_R_16F : CUDA_R_16BF;
|
||||
#endif
|
||||
cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
|
||||
op_A,
|
||||
op_B,
|
||||
|
@ -413,11 +475,18 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
|
|||
m,
|
||||
stride_C,
|
||||
batch,
|
||||
#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
|
||||
HIPBLAS_COMPUTE_32F,
|
||||
#elif defined(__HIP_PLATFORM_AMD__)
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
algo);
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
if (status != rocblas_status_success) {
|
||||
#else
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "cpu_adam.h"
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
|
||||
m.def("adam_update_copy",
|
||||
&ds_adam_step_plus_copy,
|
||||
"DeepSpeed CPU Adam update and param copy (C++)");
|
||||
m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
|
||||
m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
|
||||
}
|
|
@ -1,247 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include "cpu_adam.h"
|
||||
|
||||
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
|
||||
|
||||
// C++ interface
|
||||
|
||||
void Adam_Optimizer::Step_1(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<1>(&rounded_size,
|
||||
_params,
|
||||
grads,
|
||||
_exp_avg,
|
||||
_exp_avg_sq,
|
||||
_param_size,
|
||||
dev_params,
|
||||
half_precision);
|
||||
#endif
|
||||
if (_param_size > rounded_size) {
|
||||
float betta1_minus1 = 1 - _betta1;
|
||||
float betta2_minus1 = 1 - _betta2;
|
||||
|
||||
float step_size = -1 * _alpha / _bias_correction1;
|
||||
float w_decay = -1 * _alpha * _weight_decay;
|
||||
ds_half_precision_t* grads_cast_h;
|
||||
ds_half_precision_t* params_cast_h;
|
||||
if (half_precision) {
|
||||
grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
|
||||
params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
|
||||
}
|
||||
|
||||
for (size_t t = rounded_size; t < _param_size; t += TILE) {
|
||||
size_t copy_size = TILE;
|
||||
if ((t + TILE) > _param_size) copy_size = _param_size - t;
|
||||
size_t offset = copy_size + t;
|
||||
#pragma omp parallel for
|
||||
for (size_t k = t; k < offset; k++) {
|
||||
float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
|
||||
float param = half_precision ? (float)params_cast_h[k] : _params[k];
|
||||
float momentum = _exp_avg[k];
|
||||
float variance = _exp_avg_sq[k];
|
||||
if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
|
||||
momentum = momentum * _betta1;
|
||||
momentum = grad * betta1_minus1 + momentum;
|
||||
|
||||
variance = variance * _betta2;
|
||||
grad = grad * grad;
|
||||
variance = grad * betta2_minus1 + variance;
|
||||
|
||||
grad = sqrt(variance);
|
||||
grad = grad * _bias_correction2 + _eps;
|
||||
grad = momentum / grad;
|
||||
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
|
||||
param = grad * step_size + param;
|
||||
if (half_precision)
|
||||
params_cast_h[k] = (ds_half_precision_t)param;
|
||||
else
|
||||
_params[k] = param;
|
||||
_exp_avg[k] = momentum;
|
||||
_exp_avg_sq[k] = variance;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Adam_Optimizer::Step_4(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<4>(&rounded_size,
|
||||
_params,
|
||||
grads,
|
||||
_exp_avg,
|
||||
_exp_avg_sq,
|
||||
_param_size,
|
||||
dev_params,
|
||||
half_precision);
|
||||
#endif
|
||||
if (_param_size > rounded_size)
|
||||
Step_1((_params + rounded_size),
|
||||
(grads + rounded_size),
|
||||
(_exp_avg + rounded_size),
|
||||
(_exp_avg_sq + rounded_size),
|
||||
(_param_size - rounded_size),
|
||||
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
|
||||
half_precision);
|
||||
}
|
||||
|
||||
int create_adam_optimizer(int optimizer_id,
|
||||
float alpha,
|
||||
float betta1,
|
||||
float betta2,
|
||||
float eps,
|
||||
float weight_decay,
|
||||
bool adamw_mode,
|
||||
bool should_log)
|
||||
{
|
||||
auto opt =
|
||||
std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay, adamw_mode);
|
||||
|
||||
s_optimizers[optimizer_id] = opt;
|
||||
|
||||
if (should_log) {
|
||||
std::string avx_type = "";
|
||||
#if defined(__AVX512__)
|
||||
avx_type = "AVX512";
|
||||
#else
|
||||
#if defined(__AVX256__)
|
||||
avx_type = "AVX2";
|
||||
#else
|
||||
avx_type = "scalar";
|
||||
#endif
|
||||
#endif
|
||||
|
||||
printf("Adam Optimizer #%d is created with %s arithmetic capability.\n",
|
||||
optimizer_id,
|
||||
avx_type.c_str());
|
||||
printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
|
||||
alpha,
|
||||
betta1,
|
||||
betta2,
|
||||
weight_decay,
|
||||
(int)adamw_mode);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void Adam_Optimizer::Step_8(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<8>(&rounded_size,
|
||||
_params,
|
||||
grads,
|
||||
_exp_avg,
|
||||
_exp_avg_sq,
|
||||
_param_size,
|
||||
dev_params,
|
||||
half_precision);
|
||||
#endif
|
||||
if (_param_size > rounded_size)
|
||||
Step_4((_params + rounded_size),
|
||||
(grads + rounded_size),
|
||||
(_exp_avg + rounded_size),
|
||||
(_exp_avg_sq + rounded_size),
|
||||
(_param_size - rounded_size),
|
||||
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
|
||||
half_precision);
|
||||
}
|
||||
|
||||
int ds_adam_step(int optimizer_id,
|
||||
size_t step,
|
||||
float lr,
|
||||
float beta1,
|
||||
float beta2,
|
||||
float epsilon,
|
||||
float weight_decay,
|
||||
bool bias_correction,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg,
|
||||
torch::Tensor& exp_avg_sq)
|
||||
{
|
||||
auto params_c = params.contiguous();
|
||||
auto grads_c = grads.contiguous();
|
||||
auto exp_avg_c = exp_avg.contiguous();
|
||||
auto exp_avg_sq_c = exp_avg_sq.contiguous();
|
||||
|
||||
// assert(params.options().dtype() == grads.options().dtype());
|
||||
|
||||
float* params_ptr = (float*)params_c.data_ptr();
|
||||
float* grads_ptr = (float*)grads_c.data_ptr();
|
||||
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
|
||||
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
|
||||
|
||||
std::shared_ptr<Adam_Optimizer> opt =
|
||||
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
|
||||
opt->IncrementStep(step, beta1, beta2);
|
||||
opt->update_state(lr, epsilon, weight_decay, bias_correction);
|
||||
|
||||
opt->Step_8(params_ptr,
|
||||
grads_ptr,
|
||||
exp_avg_ptr,
|
||||
exp_avg_sq_ptr,
|
||||
params_c.numel(),
|
||||
nullptr,
|
||||
(params.options().dtype() == at::kHalf));
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ds_adam_step_plus_copy(int optimizer_id,
|
||||
size_t step,
|
||||
float lr,
|
||||
float beta1,
|
||||
float beta2,
|
||||
float epsilon,
|
||||
float weight_decay,
|
||||
bool bias_correction,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg,
|
||||
torch::Tensor& exp_avg_sq,
|
||||
torch::Tensor& gpu_params)
|
||||
{
|
||||
assert(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int destroy_adam_optimizer(int optimizer_id)
|
||||
{
|
||||
s_optimizers.erase(optimizer_id);
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -10,6 +10,7 @@ This file is adapted from fused adam in NVIDIA/apex, commit a109f85
|
|||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <c10/xpu/XPUStream.h>
|
||||
#include <ipex.h>
|
||||
#include <sycl/sycl.hpp>
|
||||
#include "compat.h"
|
||||
|
@ -22,10 +23,8 @@ namespace at {
|
|||
namespace cuda {
|
||||
sycl::queue* getCurrentCUDAStream()
|
||||
{
|
||||
auto device_type = c10::DeviceType::XPU;
|
||||
c10::impl::VirtualGuardImpl impl(device_type);
|
||||
c10::Stream c10_stream = impl.getStream(c10::Device(device_type));
|
||||
auto& queue = xpu::get_queue_from_stream(c10_stream);
|
||||
c10::xpu::XPUStream stream = c10::xpu::getCurrentXPUStream();
|
||||
auto& queue = stream.queue();
|
||||
return &queue;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include <ipex.h>
|
||||
#include <torch/extension.h>
|
||||
#include <iostream>
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
using namespace sycl;
|
||||
using namespace xpu;
|
||||
|
||||
void packbitskernel(const float* input, uint8_t* output, const int input_size, id<1> item_ct1)
|
||||
{
|
||||
// get the sign bit of each float and pack them into byte
|
||||
int i = item_ct1;
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
int k = i * 8 + j;
|
||||
int bit = k < input_size && (!sycl::signbit(input[k]));
|
||||
output[i] |= bit << (7 - j);
|
||||
}
|
||||
}
|
||||
|
||||
void unpackbitskernel(const uint8_t* input, float* output, id<1> item_ct1)
|
||||
{
|
||||
// use the bit value to set float, bit 0 -> float -1, bit 1 -> float 1
|
||||
int i = item_ct1;
|
||||
output[i] = (float((input[i / 8] >> (7 - i % 8)) & 1) - 0.5) * 2;
|
||||
}
|
||||
|
||||
sycl::queue get_current_queue(at::Device device)
|
||||
{
|
||||
c10::impl::VirtualGuardImpl impl(device.type());
|
||||
c10::Stream _stream = impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false);
|
||||
sycl::queue queue = xpu::get_queue_from_stream(_stream);
|
||||
return queue;
|
||||
}
|
||||
|
||||
/*
|
||||
pack float tensor into uint8 tensor. Every eight float elements get packed into one uint8
|
||||
if float x >= 0, will be packed as a '1' bit, or will be packed as '0'
|
||||
Arguments:
|
||||
tensor: A bool tensor that get packed.
|
||||
input_size: numel of input tensor
|
||||
rank: device id in order to get corresponding stream
|
||||
*/
|
||||
at::Tensor packbits(at::Tensor tensor, int input_size, int rank)
|
||||
{
|
||||
at::Device device = "xpu:" + std::to_string(rank);
|
||||
sycl::queue q = get_current_queue(device);
|
||||
|
||||
int packed_size = (input_size + 7) / 8;
|
||||
auto unit8_options = at::TensorOptions().dtype(at::kByte).device(at::kXPU);
|
||||
at::Tensor packed = torch::zeros({packed_size}, unit8_options);
|
||||
|
||||
float* input = (float*)tensor.data_ptr();
|
||||
uint8_t* output = (uint8_t*)packed.data_ptr();
|
||||
|
||||
auto event = q.submit([&](sycl::handler& cgh) {
|
||||
cgh.parallel_for<>(range(packed_size), [=](id<1> item_ct1) {
|
||||
packbitskernel(input, output, input_size, item_ct1);
|
||||
});
|
||||
});
|
||||
|
||||
return packed;
|
||||
}
|
||||
|
||||
/*
|
||||
unpack uint8 tensor into float tensor. Every uint8 element get unpacked into eight float
|
||||
a '1' bit will be converted to a float(1), a '0' bit will be converted to a float(-1).
|
||||
Arguments:
|
||||
tensor: A uint8 tensor that get unpacked.
|
||||
input_size: numel of input tensor
|
||||
rank: device id in order to get corresponding stream
|
||||
*/
|
||||
at::Tensor unpackbits(at::Tensor tensor, int input_size, int rank)
|
||||
{
|
||||
at::Device device = "xpu:" + std::to_string(rank);
|
||||
sycl::queue q = get_current_queue(device);
|
||||
|
||||
auto float_options = at::TensorOptions().dtype(at::kFloat).device(at::kXPU);
|
||||
at::Tensor unpacked = torch::empty({input_size * 8}, float_options);
|
||||
|
||||
uint8_t* input = (uint8_t*)tensor.data_ptr();
|
||||
float* output = (float*)unpacked.data_ptr();
|
||||
|
||||
auto event = q.submit([&](sycl::handler& cgh) {
|
||||
cgh.parallel_for<>(range(input_size * 8),
|
||||
[=](id<1> item_ct1) { unpackbitskernel(input, output, item_ct1); });
|
||||
});
|
||||
|
||||
return unpacked;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("packbits", &packbits, "DeepSpeed XPU packbits (C++)");
|
||||
m.def("unpackbits", &unpackbits, "DeepSpeed XPU unpackbits (C++)");
|
||||
}
|
|
@ -5,7 +5,6 @@
|
|||
|
||||
import copy
|
||||
|
||||
from numpy import BUFSIZE
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
|
@ -18,7 +17,7 @@ import hjson
|
|||
from tqdm import tqdm
|
||||
|
||||
from ..utils import logger
|
||||
from .constants import AUTOTUNING, AUTOTUNING_METRIC_PATH
|
||||
from .constants import AUTOTUNING, AUTOTUNING_METRIC_PATH, BUFSIZE
|
||||
from .utils import get_val_by_key, search_error, was_interruptted
|
||||
"""
|
||||
thread-0: loop over experiment queue dispatching experiments if they become available
|
||||
|
|
|
@ -15,6 +15,6 @@ from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor)
|
|||
|
||||
from .zero_checkpoint import ZeROCheckpoint
|
||||
|
||||
from .universal_checkpoint import enable_universal_checkpoint
|
||||
from .universal_checkpoint import enable_universal_checkpoint, SubparamShape
|
||||
|
||||
from .constants import *
|
||||
|
|
|
@ -74,6 +74,8 @@ CAT_DIM = "cat_dim"
|
|||
# Similarly, load_hp_checkpoint_state has to take the needed actions when loading from universal.
|
||||
PARAM_N_SUB_PARAMS = "param_n_sub_params"
|
||||
|
||||
SUB_PARAM_SHAPE = "sub_param_shape"
|
||||
|
||||
# Regex list of parameters that require special handling
|
||||
VOCABULARY_PARAMETER_PATTERNS = 'vocabulary_parameter_patterns'
|
||||
PIPELINE_REPLICATED_PARAMETER_PATTERNS = 'pipeline_replicated_parameter_patterns'
|
||||
|
@ -81,3 +83,5 @@ PARAMETER_TO_AVERAGE_PATTERNS = 'parameter_to_average_patterns'
|
|||
PARAMETER_WITH_ROW_PARALLELISM_PATTERNS = 'parameter_with_row_parallelism_patterns'
|
||||
TP_REPLICATED_PARAMETER_PATTERNS = 'tp_replicated_parameter_patterns'
|
||||
PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0 = 'parameter_with_2_sub_params_cat_dim_0'
|
||||
PARAMETER_WITH_SUB_PARAMS = 'parameter_with_sub_params'
|
||||
SUB_PARAMS_SHAPE = 'sub_params_shape'
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# DeepSpeed Team
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Dict
|
||||
import torch
|
||||
|
||||
|
@ -21,6 +22,7 @@ FINAL_LAYER_NORM_INDEX = -1
|
|||
ARGS_KEY = 'args'
|
||||
CHECKPOINT_INFO_KEY = 'checkpoint_info'
|
||||
ITERATION_KEY = 'iteration'
|
||||
LAYER_FILE_PREFIX_PATTERN = r'layer_(\d+)-model_.*'
|
||||
|
||||
SEQUENTIAL_LAYERS = [
|
||||
'input_layernorm.weight', 'input_layernorm.bias', 'self_attention.dense.bias', 'post_attention_layernorm.weight',
|
||||
|
@ -32,7 +34,13 @@ LAYER_CONCAT_DIM = {'self_attention.dense.weight': 1, 'mlp.dense_4h_to_h.weight'
|
|||
|
||||
class DeepSpeedCheckpoint(object):
|
||||
|
||||
def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
|
||||
def __init__(self,
|
||||
dir,
|
||||
tp_degree=None,
|
||||
pp_degree=None,
|
||||
dp_degree=None,
|
||||
final_layer_norm_idx=FINAL_LAYER_NORM_INDEX):
|
||||
self.final_layer_norm_idx = final_layer_norm_idx
|
||||
self.dir = dir
|
||||
|
||||
pipeline_parallel = len(get_files_with_prefix(get_files(dir), LAYER_FILE_PREFIX)) > 0
|
||||
|
@ -73,7 +81,7 @@ class DeepSpeedCheckpoint(object):
|
|||
self.pp_to_transformer_map = self._build_pp_transformer_map()
|
||||
self.transformer_file_map = self._build_transformer_file_map()
|
||||
self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
|
||||
self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX)
|
||||
self.tp_to_final_norm_map = self._build_tp_other_layer_map(self.final_layer_norm_idx)
|
||||
self._build_global_state()
|
||||
|
||||
def is_change_tp_degree(self):
|
||||
|
@ -125,7 +133,7 @@ class DeepSpeedCheckpoint(object):
|
|||
return self.layer_keys[EMBEDDING_LAYER_INDEX]
|
||||
|
||||
def get_final_norm_layer_id(self):
|
||||
return self.layer_keys[FINAL_LAYER_NORM_INDEX]
|
||||
return self.layer_keys[self.final_layer_norm_idx]
|
||||
|
||||
def get_iteration(self):
|
||||
if not ITERATION_KEY in self.global_state:
|
||||
|
@ -214,7 +222,7 @@ class DeepSpeedCheckpoint(object):
|
|||
def _build_pp_transformer_map(self):
|
||||
data_map = {}
|
||||
if self.pp_degree > 0:
|
||||
transformer_layers = self.layer_keys[1:-1]
|
||||
transformer_layers = self.layer_keys[1:self.final_layer_norm_idx]
|
||||
layers_per_pp = len(transformer_layers) // self.pp_degree
|
||||
data_map = {
|
||||
i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp]
|
||||
|
@ -229,7 +237,7 @@ class DeepSpeedCheckpoint(object):
|
|||
print(f'{k} = {v}')
|
||||
|
||||
def _build_transformer_file_map(self):
|
||||
transformer_layer_keys = self.layer_keys[1:-1]
|
||||
transformer_layer_keys = self.layer_keys[1:self.final_layer_norm_idx]
|
||||
file_map = {}
|
||||
# XXX: this is not guaranteed
|
||||
layers_per_pp = 1
|
||||
|
@ -238,7 +246,7 @@ class DeepSpeedCheckpoint(object):
|
|||
#print(f"{transformer_layer_keys} {layers_per_pp}")
|
||||
for key_index, layer_key in enumerate(transformer_layer_keys):
|
||||
pp_index = key_index // layers_per_pp
|
||||
layer_files = get_files_with_prefix(self.layer_files, layer_key)
|
||||
layer_files = get_files_with_prefix(self.layer_files, layer_key + '-')
|
||||
layer_file_partitions = partition_data(layer_files, self.tp_degree)
|
||||
for tp_index in range(self.tp_degree):
|
||||
map_key = (tp_index, pp_index)
|
||||
|
@ -263,11 +271,13 @@ class DeepSpeedCheckpoint(object):
|
|||
|
||||
def _get_layer_keys(self):
|
||||
key_set = set()
|
||||
key_len = len(LAYER_FILE_PREFIX) + 2
|
||||
for file_path in self.layer_files:
|
||||
_, fname = os.path.split(file_path)
|
||||
key_set.add(fname[:key_len])
|
||||
return sorted(list(key_set))
|
||||
layer_id = re.search(LAYER_FILE_PREFIX_PATTERN, fname).group(1)
|
||||
key_set.add(layer_id)
|
||||
sorted_ids = sorted(list(key_set), key=int)
|
||||
layer_keys = [LAYER_FILE_PREFIX + str(layer_id) for layer_id in sorted_ids]
|
||||
return layer_keys
|
||||
|
||||
def _merge_state_dicts(self, sd_list):
|
||||
merged_sd = {}
|
||||
|
|
|
@ -6,10 +6,12 @@
|
|||
# DeepSpeed Team
|
||||
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
import argparse
|
||||
import glob
|
||||
import itertools
|
||||
import multiprocessing
|
||||
import math
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
@ -20,6 +22,7 @@ import tqdm
|
|||
from deepspeed.checkpoint import DeepSpeedCheckpoint
|
||||
from deepspeed.checkpoint import (
|
||||
OPTIMIZER_STATE_DICT,
|
||||
ZERO_STAGE,
|
||||
BASE_OPTIMIZER_STATE,
|
||||
SINGLE_PARTITION_OF_FP32_GROUPS,
|
||||
PARAM_GROUPS,
|
||||
|
@ -28,14 +31,19 @@ from deepspeed.checkpoint import (
|
|||
PARAM,
|
||||
CAT_DIM,
|
||||
PARAM_N_SUB_PARAMS,
|
||||
SUB_PARAM_SHAPE,
|
||||
VOCAB_TENSOR,
|
||||
UNIVERSAL_CHECKPOINT_INFO,
|
||||
UNIVERSAL_CHECKPOINT_VERSION_KEY,
|
||||
UNIVERSAL_CHECKPOINT_VERSION_VALUE,
|
||||
VOCABULARY_PARAMETER_PATTERNS,
|
||||
PIPELINE_REPLICATED_PARAMETER_PATTERNS,
|
||||
TP_REPLICATED_PARAMETER_PATTERNS,
|
||||
PARAMETER_TO_AVERAGE_PATTERNS,
|
||||
PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
|
||||
PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0,
|
||||
PARAMETER_WITH_SUB_PARAMS,
|
||||
SubparamShape,
|
||||
)
|
||||
|
||||
|
||||
|
@ -61,11 +69,27 @@ def parse_arguments():
|
|||
dest='strict',
|
||||
action='store_false',
|
||||
help='Do not perform validity checks on converted checkpoint.')
|
||||
parser.add_argument('--inject_missing_state',
|
||||
action='store_true',
|
||||
help='Inject missing checkpoint state into the checkpoint if it is absent.')
|
||||
args = parser.parse_args()
|
||||
print(f'args = {args}')
|
||||
return args
|
||||
|
||||
|
||||
def atoi(text):
|
||||
return int(text) if text.isdigit() else text
|
||||
|
||||
|
||||
def natural_keys(text):
|
||||
'''
|
||||
alist.sort(key=natural_keys) sorts in human order
|
||||
http://nedbatchelder.com/blog/200712/human_sorting.html
|
||||
(See Toothy's implementation in the comments)
|
||||
'''
|
||||
return [atoi(c) for c in re.split(r'(\d+)', text)]
|
||||
|
||||
|
||||
def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
|
||||
path_list = []
|
||||
iter_folder = f'iter_{iteration:07d}'
|
||||
|
@ -125,9 +149,33 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
|
|||
fragment_mapping.start, fragment_mapping.numel)
|
||||
|
||||
|
||||
def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
|
||||
state_dict = torch.load(optim_files[dp_index], map_location='cpu')
|
||||
|
||||
flat_state = dict(
|
||||
exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
|
||||
exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg_sq"],
|
||||
fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0],
|
||||
)
|
||||
|
||||
offset = 0
|
||||
for name, shape in param_shapes.items():
|
||||
unpartitioned_numel = shape.numel()
|
||||
partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree)
|
||||
padding_free_numel = min(partitioned_numel, abs(unpartitioned_numel - dp_index * partitioned_numel))
|
||||
for state_key in flat_state.keys():
|
||||
dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
|
||||
padding_free_numel)
|
||||
offset += partitioned_numel
|
||||
|
||||
|
||||
cnt = 0
|
||||
|
||||
|
||||
def dp_index_to_str(dp_index):
|
||||
return f"{dp_index:0>2d}"
|
||||
|
||||
|
||||
def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel):
|
||||
|
||||
global cnt # temp hack
|
||||
|
@ -136,9 +184,8 @@ def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor,
|
|||
os.makedirs(param_base_path, exist_ok=True)
|
||||
|
||||
cnt += 1
|
||||
counter = f"{dp_index:0>2d}"
|
||||
|
||||
path = os.path.join(param_base_path, f"{state_name}.{counter}")
|
||||
path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}")
|
||||
|
||||
#print(f"{param_name}: {offset}: {numel} => {path}")
|
||||
|
||||
|
@ -148,21 +195,35 @@ def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor,
|
|||
_save_checkpoint(path, state_flat_tensor)
|
||||
|
||||
|
||||
def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
|
||||
def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None):
|
||||
slices = []
|
||||
for tp_index in range(tp_degree):
|
||||
prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
|
||||
paths = sorted(list(glob.glob(f"{prefix_path}.*")))
|
||||
paths = glob.glob(f"{prefix_path}.*")
|
||||
|
||||
if len(paths) == 0:
|
||||
continue
|
||||
|
||||
pattern = re.compile(f"{prefix_path}\\.([0-9]+)")
|
||||
dp_indices = set()
|
||||
for p in paths:
|
||||
m = pattern.match(p)
|
||||
if m:
|
||||
dp_indices.add(int(m.group(1)))
|
||||
else:
|
||||
raise ValueError(f"Cannot parse dp_rank from {p}")
|
||||
|
||||
paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))]
|
||||
shards = [torch.load(p) for p in paths]
|
||||
|
||||
if state == "step":
|
||||
assert all(v == shards[0] for v in shards), "All shards must have the same step value"
|
||||
slice = shards[0]
|
||||
else:
|
||||
slice = torch.cat(shards, dim=0).reshape(slice_shape)
|
||||
if slice_shape is None:
|
||||
slice = torch.cat(shards, dim=0)
|
||||
else:
|
||||
slice = torch.cat(shards, dim=0).reshape(slice_shape)
|
||||
|
||||
slices.append(slice)
|
||||
return slices
|
||||
|
@ -180,8 +241,11 @@ def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
|
|||
parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, [])
|
||||
vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, [])
|
||||
parameters_with_2_sub_params_cat_dim_0 = universal_checkpoint_info.get(PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0, [])
|
||||
parameter_with_sub_params = universal_checkpoint_info.get(PARAMETER_WITH_SUB_PARAMS, [])
|
||||
|
||||
unmatched_patterns = set(replicated_parameters + parameters_to_average + parameters_with_row_parallelism +
|
||||
vocabulary_parameters + parameters_with_2_sub_params_cat_dim_0)
|
||||
unmatched_patterns.update(chain.from_iterable(SubparamShape(**s).patterns for s in parameter_with_sub_params))
|
||||
|
||||
def get_matched_pattern(patterns_, name_):
|
||||
matched_ = [pattern_ for pattern_ in patterns_ if re.match(pattern_, name_)]
|
||||
|
@ -192,6 +256,17 @@ def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
|
|||
return pattern_
|
||||
return None
|
||||
|
||||
def get_matched_sub_params_pattern(name_):
|
||||
for subparam_shape_dict in parameter_with_sub_params:
|
||||
subparam_shape = SubparamShape(**subparam_shape_dict)
|
||||
for pattern_ in subparam_shape.patterns:
|
||||
if re.match(pattern_, name_):
|
||||
unmatched_patterns.discard(pattern_)
|
||||
return subparam_shape
|
||||
return None
|
||||
|
||||
matched_sub_params_shape = get_matched_sub_params_pattern(name)
|
||||
|
||||
step_merged = _merge_zero_shards(slice_base_path, "step", tp_degree, shape)
|
||||
if step_merged:
|
||||
_save_checkpoint(os.path.join(param_base_path, f"step.pt"), step_merged[0])
|
||||
|
@ -219,6 +294,26 @@ def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
|
|||
param = torch.cat([merged_chunks_0, merged_chunks_1], dim=cat_dim)
|
||||
ckpt_dict[CAT_DIM] = cat_dim
|
||||
ckpt_dict[PARAM_N_SUB_PARAMS] = 2
|
||||
elif matched_sub_params_shape:
|
||||
merged_chunks = []
|
||||
partition_dim = matched_sub_params_shape.partition_dim
|
||||
|
||||
sub_dim_sizes = matched_sub_params_shape.shape[partition_dim]
|
||||
if not isinstance(sub_dim_sizes, tuple):
|
||||
sub_dim_sizes = (sub_dim_sizes, )
|
||||
|
||||
partition_shape = [sum(d) if isinstance(d, tuple) else d for d in matched_sub_params_shape.shape]
|
||||
partition_shape = [d // tp_degree if i == partition_dim else d for i, d in enumerate(partition_shape)]
|
||||
slices = [s.view(partition_shape) for s in slices]
|
||||
|
||||
offset = 0
|
||||
for sub_dim_size in sub_dim_sizes:
|
||||
part_sub_dim_size = sub_dim_size // tp_degree
|
||||
merged_chunks.append(
|
||||
torch.cat([s.narrow(partition_dim, offset, part_sub_dim_size) for s in slices], dim=partition_dim))
|
||||
offset += part_sub_dim_size
|
||||
param = torch.cat(merged_chunks, dim=partition_dim)
|
||||
ckpt_dict[SUB_PARAM_SHAPE] = matched_sub_params_shape
|
||||
else:
|
||||
cat_dim = 1 if get_matched_pattern(parameters_with_row_parallelism, name) else 0
|
||||
# print(f"merge {name} with CAT DIM: {cat_dim}")
|
||||
|
@ -240,27 +335,28 @@ def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
|
|||
return unmatched_patterns
|
||||
|
||||
|
||||
def _get_chunks(l, n):
|
||||
for i in range(0, len(l), n):
|
||||
yield l[i:i + n]
|
||||
def merge_zero3_slices(dp_degree, dir, slice_dir, name):
|
||||
slice_base_path = os.path.join(slice_dir, name)
|
||||
param_base_path = os.path.join(dir, name)
|
||||
|
||||
for state in ("fp32", "exp_avg", "exp_avg_sq"):
|
||||
slices = _merge_zero_shards(slice_base_path, state, 1)
|
||||
final_path = os.path.join(param_base_path, f"{state}.pt")
|
||||
_save_checkpoint(final_path, slices[0])
|
||||
|
||||
|
||||
def _do_parallel_work(do_work, work_chunks, num_workers):
|
||||
results = []
|
||||
if num_workers > 1:
|
||||
pool = multiprocessing.Pool(num_workers)
|
||||
results = []
|
||||
for batch in tqdm.tqdm(work_chunks):
|
||||
res = pool.map(do_work, batch)
|
||||
results.extend(res)
|
||||
pool.close()
|
||||
pool.join()
|
||||
with ProcessPoolExecutor(max_workers=num_workers) as executor:
|
||||
future_list = [executor.submit(do_work, work) for work in work_chunks]
|
||||
for f in tqdm.tqdm(future_list):
|
||||
results.append(f.result())
|
||||
else:
|
||||
# No parallel pass for unit testing
|
||||
# We can't create child processes in tests
|
||||
results = []
|
||||
for batch in tqdm.tqdm(work_chunks):
|
||||
res = [do_work(x) for x in batch]
|
||||
results.extend(res)
|
||||
for work in tqdm.tqdm(work_chunks):
|
||||
results.append(do_work(work))
|
||||
return results
|
||||
|
||||
|
||||
|
@ -269,20 +365,20 @@ def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
|
|||
itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
|
||||
range(ds_checkpoint.dp_degree)))
|
||||
#pprint(f'{_3d_range_list=}')
|
||||
work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers))
|
||||
#pprint(f'{work_chunks=}')
|
||||
|
||||
# extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0])
|
||||
do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
|
||||
_do_parallel_work(do_work, work_chunks, args.num_extract_workers)
|
||||
_do_parallel_work(do_work, _3d_range_list, args.num_extract_workers)
|
||||
|
||||
|
||||
def _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir):
|
||||
do_work = partial(extract_zero_shards_stage3, optim_files, param_shapes, dp_degree, temp_dir)
|
||||
_do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers)
|
||||
|
||||
|
||||
def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
|
||||
work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers))
|
||||
#pprint(work_chunks)
|
||||
zero_output_folder = os.path.join(args.output_folder, "zero")
|
||||
do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
|
||||
unmatched_patterns_lists = _do_parallel_work(do_work, work_chunks, args.num_merge_workers)
|
||||
unmatched_patterns_lists = _do_parallel_work(do_work, list(slice_shapes.items()), args.num_merge_workers)
|
||||
|
||||
# verify that all patterns were used
|
||||
# if a pattern was not used by any of the workers, then it was not used at all -> assert/alert
|
||||
|
@ -294,6 +390,23 @@ def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
|
|||
print(f'Warning: Unused patterns={unmatched_patterns} while merging tp slices')
|
||||
|
||||
|
||||
def _merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir):
|
||||
zero_output_folder = os.path.join(args.output_folder, "zero")
|
||||
do_work = partial(merge_zero3_slices, dp_degree, zero_output_folder, temp_dir)
|
||||
_do_parallel_work(do_work, param_shapes.keys(), args.num_merge_workers)
|
||||
|
||||
|
||||
def _zero_partitioned_param_info(unpartitioned_numel, world_size):
|
||||
remainder = unpartitioned_numel % world_size
|
||||
padding_numel = (world_size - remainder) if remainder else 0
|
||||
partitioned_numel = math.ceil(unpartitioned_numel / world_size)
|
||||
return partitioned_numel, padding_numel
|
||||
|
||||
|
||||
def _parse_model_states_stage3(files):
|
||||
return torch.load(files[0], map_location=torch.device('cpu'))[PARAM_SHAPES]
|
||||
|
||||
|
||||
def _save_optimizer_state(args, ds_checkpoint):
|
||||
sharded_states = [BASE_OPTIMIZER_STATE, PARAM_SLICE_MAPPINGS, SINGLE_PARTITION_OF_FP32_GROUPS]
|
||||
sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=0, tp_index=0, dp_index=0)
|
||||
|
@ -306,6 +419,48 @@ def _save_optimizer_state(args, ds_checkpoint):
|
|||
_save_checkpoint(output_file_path, output_sd)
|
||||
|
||||
|
||||
def _save_optimizer_state_stage3(args, optim_files):
|
||||
sd = torch.load(optim_files[0], map_location=torch.device('cpu'))
|
||||
output_sd = sd[OPTIMIZER_STATE_DICT]
|
||||
output_sd[PARAM_GROUPS] = output_sd[OPTIMIZER_STATE_DICT][PARAM_GROUPS]
|
||||
zero_output_folder = os.path.join(args.output_folder, "zero")
|
||||
output_file_path = os.path.join(zero_output_folder, f"optimizer_state.pt")
|
||||
_save_checkpoint(output_file_path, output_sd)
|
||||
|
||||
|
||||
def _get_optim_files(checkpoint_dir):
|
||||
return _get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
|
||||
|
||||
|
||||
def _get_model_state_files(checkpoint_dir):
|
||||
return _get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
|
||||
|
||||
|
||||
def _get_checkpoint_files(checkpoint_dir, glob_pattern):
|
||||
ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
|
||||
|
||||
if len(ckpt_files) == 0:
|
||||
raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
|
||||
|
||||
return ckpt_files
|
||||
|
||||
|
||||
def _get_zero_stage(optim_files):
|
||||
state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'))
|
||||
optimizer_state = state_dict[OPTIMIZER_STATE_DICT]
|
||||
zero_stage = optimizer_state.get(ZERO_STAGE, 1)
|
||||
return zero_stage
|
||||
|
||||
|
||||
def _inject_missing_state(ds_checkpoint):
|
||||
if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
|
||||
sd = torch.load(ds_checkpoint.mp_rank_files[0], map_location=torch.device('cpu'))
|
||||
if UNIVERSAL_CHECKPOINT_INFO not in sd:
|
||||
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
|
||||
ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO][
|
||||
UNIVERSAL_CHECKPOINT_VERSION_KEY] = UNIVERSAL_CHECKPOINT_VERSION_VALUE
|
||||
|
||||
|
||||
def _check_for_required_state(ds_checkpoint):
|
||||
universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
|
||||
assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.'
|
||||
|
@ -316,38 +471,69 @@ def main(args):
|
|||
|
||||
print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}')
|
||||
|
||||
ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)
|
||||
_check_for_required_state(ds_checkpoint)
|
||||
optim_files = _get_optim_files(args.input_folder)
|
||||
zero_stage = _get_zero_stage(optim_files)
|
||||
|
||||
iteration = ds_checkpoint.get_iteration()
|
||||
#_create_latest_file(args.output_folder, iteration)
|
||||
checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree,
|
||||
ds_checkpoint.pp_degree)
|
||||
if zero_stage <= 2:
|
||||
ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)
|
||||
if args.inject_missing_state:
|
||||
_inject_missing_state(ds_checkpoint)
|
||||
else:
|
||||
_check_for_required_state(ds_checkpoint)
|
||||
|
||||
slice_shapes = []
|
||||
for mp_rank_file in ds_checkpoint.mp_rank_files:
|
||||
mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
|
||||
slice_shapes += mp_sd[PARAM_SHAPES]
|
||||
iteration = ds_checkpoint.get_iteration()
|
||||
#_create_latest_file(args.output_folder, iteration)
|
||||
checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree,
|
||||
ds_checkpoint.pp_degree)
|
||||
|
||||
# fix back to normal flat dict, merge duplicates for tp>1
|
||||
slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items())
|
||||
temp_dir = os.path.join(args.output_folder, 'tmp')
|
||||
slice_shapes = []
|
||||
for mp_rank_file in ds_checkpoint.mp_rank_files:
|
||||
mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
|
||||
slice_shapes += mp_sd[PARAM_SHAPES]
|
||||
|
||||
print('*** 1. Extracting ZeRO fragments')
|
||||
_extract_zero_shard_files(args, ds_checkpoint, temp_dir)
|
||||
# fix back to normal flat dict, merge duplicates for tp>1
|
||||
slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items())
|
||||
temp_dir = os.path.join(args.output_folder, 'tmp')
|
||||
|
||||
print('*** 2. Merging slices .....')
|
||||
_merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)
|
||||
print('*** 1. Extracting ZeRO fragments')
|
||||
_extract_zero_shard_files(args, ds_checkpoint, temp_dir)
|
||||
|
||||
print('*** 3. Saving common optimizer states')
|
||||
_save_optimizer_state(args, ds_checkpoint)
|
||||
print('*** 2. Merging slices .....')
|
||||
_merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)
|
||||
|
||||
if not args.keep_temp_folder:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
print('*** 3. Saving common optimizer states')
|
||||
_save_optimizer_state(args, ds_checkpoint)
|
||||
|
||||
# Copy mp* files into output folder
|
||||
for f in glob.glob(os.path.join(args.input_folder, 'mp*')):
|
||||
shutil.copy2(f, args.output_folder)
|
||||
if not args.keep_temp_folder:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
# Copy mp* files into output folder
|
||||
for f in glob.glob(os.path.join(args.input_folder, 'mp*')):
|
||||
shutil.copy2(f, args.output_folder)
|
||||
|
||||
else:
|
||||
model_files = _get_model_state_files(args.input_folder)
|
||||
param_shapes = _parse_model_states_stage3(model_files)
|
||||
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
|
||||
dp_degree = len(model_files)
|
||||
|
||||
temp_dir = os.path.join(args.output_folder, 'tmp')
|
||||
|
||||
print('*** 1. Extracting ZeRO fragments')
|
||||
_extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir)
|
||||
|
||||
print('*** 2. Merging slices .....')
|
||||
_merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir)
|
||||
|
||||
print('*** 3. Saving common optimizer states')
|
||||
_save_optimizer_state_stage3(args, optim_files)
|
||||
|
||||
if not args.keep_temp_folder:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
# Copy *model_states files into output folder
|
||||
for f in glob.glob(os.path.join(args.input_folder, '*model_states.pt')):
|
||||
shutil.copy2(f, args.output_folder)
|
||||
|
||||
# Update latest to output folder
|
||||
checkpoint_root_folder, step_folder = os.path.split(args.output_folder)
|
||||
|
|
|
@ -7,7 +7,16 @@ import os
|
|||
import re
|
||||
import torch
|
||||
import types
|
||||
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS)
|
||||
from typing import List, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS, SUB_PARAM_SHAPE)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubparamShape:
|
||||
patterns: List[str]
|
||||
shape: Tuple[Union[Tuple[int], int]]
|
||||
partition_dim: int
|
||||
|
||||
|
||||
def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
|
||||
|
@ -76,12 +85,32 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
|
|||
# print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
|
||||
# print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")
|
||||
|
||||
sub_param_shape = ckpt_dict.get(SUB_PARAM_SHAPE, None)
|
||||
# since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1 we have to do exactly the same in reverse
|
||||
# special case is when a single parameter is effectively a container for multiple sub parameters
|
||||
# (more details at PARAM_N_SUB_PARAMS definition)
|
||||
chunk_dim = ckpt_dict.get(CAT_DIM, 0)
|
||||
n_sub_params = ckpt_dict.get(PARAM_N_SUB_PARAMS, 1)
|
||||
if n_sub_params > 1:
|
||||
if sub_param_shape:
|
||||
partition_dim = sub_param_shape.partition_dim
|
||||
sub_dim_sizes = sub_param_shape.shape[partition_dim]
|
||||
if not isinstance(sub_dim_sizes, tuple):
|
||||
sub_dim_sizes = (sub_dim_sizes, )
|
||||
|
||||
partition_shape = [sum(d) if isinstance(d, tuple) else d for d in sub_param_shape.shape]
|
||||
full_hp_param = full_hp_param.view(partition_shape)
|
||||
|
||||
offset = 0
|
||||
merged_chunks = []
|
||||
for sub_dim_size in sub_dim_sizes:
|
||||
sub_params_tp_slice = full_hp_param.narrow(partition_dim,
|
||||
offset, sub_dim_size).chunk(tp_world_size,
|
||||
dim=partition_dim)[tp_rank]
|
||||
merged_chunks.append(sub_params_tp_slice)
|
||||
offset += sub_dim_size
|
||||
tp_hp_slice = torch.cat(merged_chunks, dim=partition_dim)
|
||||
|
||||
elif n_sub_params > 1:
|
||||
sub_params = full_hp_param.chunk(n_sub_params, dim=chunk_dim)
|
||||
sub_params_tp_slice = [p.chunk(tp_world_size, dim=chunk_dim)[tp_rank] for p in sub_params]
|
||||
tp_hp_slice = torch.cat(sub_params_tp_slice, dim=chunk_dim)
|
||||
|
|
|
@ -8,13 +8,14 @@ Copyright 2021 The Microsoft DeepSpeed Team
|
|||
|
||||
import torch
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from deepspeed.ops.op_builder import NotImplementedBuilder
|
||||
from .reduce_op import ReduceOp
|
||||
from .torch import TorchBackend
|
||||
|
||||
|
||||
def build_ccl_op():
|
||||
builder = get_accelerator().create_op_builder("CCLCommBuilder")
|
||||
if builder is None:
|
||||
if builder is None or NotImplementedBuilder:
|
||||
return None
|
||||
ccl_cpp_module = builder.load()
|
||||
print(f'DeepSpeed {builder.absolute_name()} built successfully')
|
||||
|
@ -98,12 +99,12 @@ class CCLBackend(TorchBackend):
|
|||
else:
|
||||
return self.run_collective(name=name, tensor=tensor, op=op, group=group, async_op=async_op)
|
||||
|
||||
def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
|
||||
def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None):
|
||||
name = "inference_all_reduce"
|
||||
if name in self.available_coll:
|
||||
return self.ccl_comm_op.inference_all_reduce(tensor, op, async_op)
|
||||
return self.ccl_comm_op.inference_all_reduce(tensor, op)
|
||||
else:
|
||||
return self.run_collective(name=name, tensor=tensor, op=op, group=None, async_op=async_op)
|
||||
return self.run_collective(name=name, tensor=tensor, op=op, group=None, async_op=False)
|
||||
|
||||
def broadcast(self, tensor, src, group=None, async_op=False):
|
||||
return self.run_collective(name="broadcast", tensor=tensor, src=src, group=group, async_op=async_op)
|
||||
|
|
|
@ -519,7 +519,7 @@ def inference_all_reduce(tensor,
|
|||
log_name='all_reduce',
|
||||
debug=get_caller_func()):
|
||||
global cdb
|
||||
return cdb.inference_all_reduce(tensor, op, group, async_op)
|
||||
return cdb.inference_all_reduce(tensor, op, group)
|
||||
|
||||
|
||||
@timed_op
|
||||
|
@ -632,7 +632,7 @@ def init_distributed(dist_backend=None,
|
|||
auto_mpi_discovery Optional (bool). if distributed environment variables are not set, attempt to discover them from MPI
|
||||
distributed_port: Optional (int). torch distributed backend port
|
||||
verbose: Optional (bool). verbose logging
|
||||
timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
|
||||
timeout: Optional (timedelta). Timeout for operations executed against the process group. The default value of 30 minutes can be overridden by the environment variable `DEEPSPEED_TIMEOUT`.
|
||||
init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified.
|
||||
config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling)
|
||||
rank: Optional (int). The current manually specified rank. Some init_method like “tcp://” need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization)
|
||||
|
|
|
@ -3,12 +3,14 @@
|
|||
|
||||
# DeepSpeed Team
|
||||
|
||||
import deepspeed
|
||||
from deepspeed import utils
|
||||
|
||||
from .utils import *
|
||||
from .backend import *
|
||||
from .comm import *
|
||||
from ..runtime import compiler
|
||||
from deepspeed.utils.torch import required_torch_version
|
||||
import os
|
||||
|
||||
DS_COMM_ALL_GATHER_OFF = False
|
||||
|
@ -18,25 +20,13 @@ DS_COMM_ALL_REDUCE_OFF = False
|
|||
DS_COMM_REDUCE_OFF = False
|
||||
|
||||
|
||||
def is_torch_ver_eq_2_0():
|
||||
TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[:2])
|
||||
if TORCH_MAJOR == 2 and TORCH_MINOR == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_torch_ver_ge_2_1():
|
||||
TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[:2])
|
||||
if TORCH_MAJOR >= 2 and TORCH_MINOR >= 1:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def torch_ver_ge_1_13():
|
||||
TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[:2])
|
||||
if TORCH_MAJOR >= 1 and TORCH_MINOR >= 13:
|
||||
return True
|
||||
return False
|
||||
def build_shm_op():
|
||||
builder = get_accelerator().create_op_builder("ShareMemCommBuilder")
|
||||
if builder is None or not deepspeed.ops.__compatible_ops__[builder.NAME]:
|
||||
return None
|
||||
shm_cpp_module = builder.load()
|
||||
print(f'DeepSpeed {builder.absolute_name()} built successfully')
|
||||
return shm_cpp_module
|
||||
|
||||
|
||||
def has_coalescing_manager():
|
||||
|
@ -45,13 +35,13 @@ def has_coalescing_manager():
|
|||
|
||||
|
||||
def has_all_reduce_coalesced():
|
||||
return hasattr(torch.distributed, "all_reduce_coalesced") and torch_ver_ge_1_13()
|
||||
return hasattr(torch.distributed, "all_reduce_coalesced") and required_torch_version(min_version=1.13)
|
||||
|
||||
|
||||
def get_coalescing_manager(group, device, reqs, async_op):
|
||||
if is_torch_ver_eq_2_0():
|
||||
if required_torch_version(min_version=2.0, max_version=2.0):
|
||||
return torch.distributed.distributed_c10d._coalescing_manager(group, device=device, reqs=reqs)
|
||||
elif is_torch_ver_ge_2_1():
|
||||
elif required_torch_version(min_version=2.1):
|
||||
return torch.distributed.distributed_c10d._coalescing_manager(group, device=device, async_ops=async_op)
|
||||
else:
|
||||
return torch.distributed.distributed_c10d._coalescing_manager(group, reqs)
|
||||
|
@ -108,6 +98,7 @@ class TorchBackend(Backend):
|
|||
|
||||
def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
|
||||
super(TorchBackend, self).__init__()
|
||||
self.shm_comm_op = build_shm_op()
|
||||
self.has_all_reduce_coalesced = has_all_reduce_coalesced()
|
||||
self.has_coalescing_manager = has_coalescing_manager()
|
||||
self.all_gather_function = self.get_all_gather_function()
|
||||
|
@ -119,6 +110,8 @@ class TorchBackend(Backend):
|
|||
# it is not so we can run on a single GPU without doing any init_process_group
|
||||
self.single_gpu_mode = True
|
||||
self.init_process_group(backend, timeout, init_method, rank, world_size)
|
||||
if self.shm_comm_op != None:
|
||||
self.shm_comm_op.initialize(self.get_world_size(), self.get_rank())
|
||||
|
||||
@classmethod
|
||||
@compiler.disable
|
||||
|
@ -162,9 +155,10 @@ class TorchBackend(Backend):
|
|||
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
|
||||
|
||||
@compiler.disable
|
||||
def inference_all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
|
||||
op = self._reduce_op(op)
|
||||
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
|
||||
def inference_all_reduce(self, tensor, op, group=None):
|
||||
if self.shm_comm_op == None or self.shm_comm_op.inference_all_reduce(tensor, op) == -1:
|
||||
op = self._reduce_op(op)
|
||||
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False)
|
||||
|
||||
@compiler.disable
|
||||
def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from torch.distributed.elastic.agent.server.api import _get_socket_with_port
|
||||
from torch.distributed.elastic.utils.distributed import get_free_port
|
||||
from torch.distributed.elastic.metrics import put_metric
|
||||
from torch.distributed.elastic.agent.server.api import (
|
||||
RunResult,
|
||||
|
@ -48,7 +48,7 @@ class DSElasticAgent(LocalElasticAgent):
|
|||
master_port: Optional[int],
|
||||
local_addr: Optional[str] = None):
|
||||
if master_port is None:
|
||||
sock = _get_socket_with_port()
|
||||
sock = get_free_port()
|
||||
with closing(sock):
|
||||
master_port = sock.getsockname()[1]
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
# DeepSpeed Team
|
||||
|
||||
from deepspeed.runtime.utils import required_torch_version
|
||||
from deepspeed.utils.torch import required_torch_version
|
||||
|
||||
|
||||
def is_torch_elastic_compatible():
|
||||
|
|
|
@ -110,7 +110,7 @@ def installed_cann_version():
|
|||
def get_shm_size():
|
||||
try:
|
||||
shm_stats = os.statvfs('/dev/shm')
|
||||
except (OSError, FileNotFoundError, ValueError):
|
||||
except (OSError, FileNotFoundError, ValueError, AttributeError):
|
||||
return "UNKNOWN", None
|
||||
|
||||
shm_size = shm_stats.f_frsize * shm_stats.f_blocks
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
import torch
|
||||
import time
|
||||
import os
|
||||
import deepspeed
|
||||
from deepspeed import comm as dist
|
||||
from deepspeed.utils.logging import log_dist
|
||||
|
||||
|
@ -13,6 +14,7 @@ from torch.nn.modules import Module
|
|||
from packaging import version as pkg_version
|
||||
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
|
||||
from deepspeed.utils.timer import SynchronizedWallClockTimer
|
||||
from deepspeed.runtime.compiler import is_compile_supported
|
||||
|
||||
from ..runtime.state_dict_factory import SDLoaderFactory
|
||||
from ..runtime.weight_quantizer import WeightQuantization
|
||||
|
@ -185,6 +187,7 @@ class InferenceEngine(Module):
|
|||
|
||||
# Check if local CUDA graphs can be created in replacement modules
|
||||
self.local_cuda_graph = self._local_cuda_graph_used(self.module)
|
||||
self._is_compiled = False
|
||||
|
||||
def destroy(self):
|
||||
# Have to import here because inference_module is a global, but python
|
||||
|
@ -634,3 +637,22 @@ class InferenceEngine(Module):
|
|||
)
|
||||
|
||||
return self.module.generate(*inputs, **kwargs)
|
||||
|
||||
def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}) -> None:
|
||||
"""
|
||||
Compile the module using the specified backend and kwargs.
|
||||
"""
|
||||
if not is_compile_supported():
|
||||
raise RuntimeError("compile is not supported in your version of PyTorch.")
|
||||
|
||||
if self._is_compiled:
|
||||
return
|
||||
|
||||
# Avoid graph breaks
|
||||
deepspeed.utils.nvtx.enable_nvtx = False
|
||||
self.module.compile(backend=backend, **compile_kwargs)
|
||||
self._is_compiled = True
|
||||
|
||||
@property
|
||||
def is_compiled(self) -> bool:
|
||||
return self._is_compiled
|
||||
|
|
|
@ -14,14 +14,14 @@ import functools
|
|||
|
||||
device = get_accelerator().device_name() if get_accelerator().is_available() else 'cpu'
|
||||
|
||||
quantizer_cuda_module = None
|
||||
quantizer_module = None
|
||||
|
||||
|
||||
def get_quantizer_cuda_module():
|
||||
global quantizer_cuda_module
|
||||
if quantizer_cuda_module is None:
|
||||
quantizer_cuda_module = deepspeed.ops.op_builder.QuantizerBuilder().load()
|
||||
return quantizer_cuda_module
|
||||
def get_quantizer_module():
|
||||
global quantizer_module
|
||||
if quantizer_module is None:
|
||||
quantizer_module = deepspeed.ops.op_builder.QuantizerBuilder().load()
|
||||
return quantizer_module
|
||||
|
||||
|
||||
def tensor_clamp(tensor: Tensor, min, max) -> Tensor:
|
||||
|
@ -107,19 +107,19 @@ class DeQuantizer:
|
|||
if self.config['group_size'] % 8 == 0 and \
|
||||
(self.config['num_bits'] == 4 or self.config['num_bits'] == 8) and \
|
||||
self.config['group_dim'] == len(tensor.shape) - 1 and \
|
||||
self.dtype == torch.float16 and device == 'cuda':
|
||||
self.dtype == torch.float16 and device == get_accelerator().device_name():
|
||||
|
||||
last_dimension_size = self.config['group_size']
|
||||
if self.config['num_bits'] == 4:
|
||||
last_dimension_size = last_dimension_size // 2
|
||||
quantized_tensor = get_quantizer_cuda_module().dequantize_int4_to_half_experimental(
|
||||
quantized_tensor = get_quantizer_module().dequantize_int4_to_half_experimental(
|
||||
tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
|
||||
tensor.numel() // last_dimension_size, self.config['group_size'])
|
||||
shape = list(tensor.shape)
|
||||
shape[-1] = shape[-1] * 2
|
||||
elif self.config['num_bits'] == 8:
|
||||
# last_dimension_size = last_dimension_size // 2
|
||||
quantized_tensor = get_quantizer_cuda_module().dequantize_int8_to_half_experimental(
|
||||
quantized_tensor = get_quantizer_module().dequantize_int8_to_half_experimental(
|
||||
tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
|
||||
tensor.numel() // last_dimension_size, self.config['group_size'])
|
||||
shape = list(tensor.shape)
|
||||
|
|
|
@ -55,7 +55,9 @@ private:
|
|||
|
||||
enum class BlasType { FP32, FP16, BF16 };
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// TODO HIP: Remove backward compatibility for torch<=2.0 in future
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_operation get_trans_op(bool do_trans)
|
||||
{
|
||||
return (do_trans) ? rocblas_operation_transpose : rocblas_operation_none;
|
||||
|
@ -76,9 +78,15 @@ cublasOperation_t get_trans_op(bool do_trans) { return (do_trans) ? CUBLAS_OP_T
|
|||
cublasDataType_t get_datatype(BlasType type)
|
||||
{
|
||||
switch (type) {
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
case BlasType::FP32: return HIPBLAS_R_32F;
|
||||
case BlasType::FP16: return HIPBLAS_R_16F;
|
||||
case BlasType::BF16: return HIPBLAS_R_16B;
|
||||
#else
|
||||
case BlasType::FP32: return CUDA_R_32F;
|
||||
case BlasType::FP16: return CUDA_R_16F;
|
||||
case BlasType::BF16: return CUDA_R_16BF;
|
||||
#endif
|
||||
default: throw std::runtime_error("Unsupported BlasType");
|
||||
}
|
||||
}
|
||||
|
@ -99,7 +107,8 @@ int blas_gemm_ex(void* C,
|
|||
const float* beta,
|
||||
BlasType type)
|
||||
{
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_operation_t transa_op = get_trans_op(transa);
|
||||
rocblas_operation_t transb_op = get_trans_op(transb);
|
||||
|
||||
|
@ -151,11 +160,18 @@ int blas_gemm_ex(void* C,
|
|||
C,
|
||||
abc_type,
|
||||
ldc,
|
||||
#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
|
||||
HIPBLAS_COMPUTE_32F,
|
||||
#elif defined(__HIP_PLATFORM_AMD__)
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
if (status != rocblas_status_success) {
|
||||
#else
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
|
@ -190,7 +206,8 @@ int blas_strided_batched_gemm(void* C,
|
|||
int batch,
|
||||
BlasType type)
|
||||
{
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
rocblas_operation_t transa_op = get_trans_op(transa);
|
||||
rocblas_operation_t transb_op = get_trans_op(transb);
|
||||
|
||||
|
@ -253,11 +270,18 @@ int blas_strided_batched_gemm(void* C,
|
|||
ldc,
|
||||
stride_C,
|
||||
batch,
|
||||
#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
|
||||
HIPBLAS_COMPUTE_32F,
|
||||
#elif defined(__HIP_PLATFORM_AMD__)
|
||||
HIPBLAS_R_32F,
|
||||
#else
|
||||
CUDA_R_32F,
|
||||
#endif
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#if defined(__HIP_PLATFORM_AMD__) && \
|
||||
((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
|
||||
if (status != rocblas_status_success) {
|
||||
#else
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
|
|
|
@ -8,9 +8,9 @@
|
|||
|
||||
#include "bias_activation.h"
|
||||
#include "blas.h"
|
||||
#include "cuda_linear_kernels.h"
|
||||
#include "gated_activation_kernels.h"
|
||||
#include "layer_norm.h"
|
||||
#include "linear_kernels.h"
|
||||
#include "rms_norm.h"
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
|
@ -35,7 +35,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
|||
m.def("rms_norm", &rms_norm, "DeepSpeed rms norm in CUDA");
|
||||
m.def("rms_pre_norm", &rms_pre_norm, "DeepSpeed rms pre norm in CUDA");
|
||||
|
||||
// cuda_linear_kernels.h
|
||||
// linear_kernels.h
|
||||
m.def("cuda_wf6af16_linear", &cuda_wf6af16_linear, "DeepSpeed Wf6Af16 linear in CUDA");
|
||||
m.def(
|
||||
"preprocess_weight", &preprocess_weight, "preprocess the FP16 weight to be 2bit and 4 bit");
|
||||
|
|
|
@ -252,7 +252,6 @@ __global__ void fused_residual_ln(T* output,
|
|||
for (int i = 0; i < unRoll; i++) {
|
||||
T* iteration_buffer = local_buffer + i * T_per_load;
|
||||
T residual_buffer[T_per_load];
|
||||
T bias_buffer[T_per_load];
|
||||
|
||||
mem_access::load_global<ln::granularity>(
|
||||
iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row);
|
||||
|
|