Mirror of https://github.com/microsoft/DeepSpeed.git
Fixed Windows inference build. (#5609)
Fix #2427

---------

Co-authored-by: Costin Eseanu <costineseanu@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent: d89e8cdfe5
Commit: b3767d01d4

@@ -7,6 +7,7 @@ import functools
 import os
 import pkgutil
 import importlib
+import sys
 
 from .abstract_accelerator import DeepSpeedAccelerator
 # During setup stage torch may not be installed, pass on no torch will
@@ -24,7 +25,7 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
 
     def __init__(self):
         self._name = 'cuda'
-        self._communication_backend_name = 'nccl'
+        self._communication_backend_name = 'nccl' if sys.platform != 'win32' else 'gloo'
         self._compile_backend = "inductor"
         if pynvml is None:
             self._init_pynvml()

@@ -6,10 +6,8 @@ set DS_BUILD_AIO=0
 set DS_BUILD_CUTLASS_OPS=0
 set DS_BUILD_EVOFORMER_ATTN=0
 set DS_BUILD_FP_QUANTIZER=0
-set DS_BUILD_INFERENCE_CORE_OPS=0
 set DS_BUILD_RAGGED_DEVICE_OPS=0
 set DS_BUILD_SPARSE_ATTN=0
-set DS_BUILD_TRANSFORMER_INFERENCE=0
 
 python setup.py bdist_wheel
 
@@ -542,22 +542,23 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
                               1);
 
     if (layer_id == num_layers - 1) InferenceContext::Instance().advance_tokens();
-    auto prev_key = torch::from_blob(workspace + offset,
-                                     {bsz, heads, all_tokens, k},
-                                     {hidden_dim * InferenceContext::Instance().GetMaxTokenLength(),
-                                      k * InferenceContext::Instance().GetMaxTokenLength(),
-                                      k,
-                                      1},
-                                     options);
+    auto prev_key = torch::from_blob(
+        workspace + offset,
+        {bsz, heads, all_tokens, k},
+        {hidden_dim * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
+         k * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
+         k,
+         1},
+        options);
 
-    auto prev_value =
-        torch::from_blob(workspace + offset + value_offset,
-                         {bsz, heads, all_tokens, k},
-                         {hidden_dim * InferenceContext::Instance().GetMaxTokenLength(),
-                          k * InferenceContext::Instance().GetMaxTokenLength(),
-                          k,
-                          1},
-                         options);
+    auto prev_value = torch::from_blob(
+        workspace + offset + value_offset,
+        {bsz, heads, all_tokens, k},
+        {hidden_dim * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
+         k * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
+         k,
+         1},
+        options);
 
     return {output, prev_key, prev_value};
 }
@@ -1592,7 +1593,9 @@ std::vector<at::Tensor> ds_rms_mlp_gemm(at::Tensor& input,
     auto output = at::from_blob(output_ptr, input.sizes(), options);
     auto inp_norm = at::from_blob(inp_norm_ptr, input.sizes(), options);
     auto intermediate_gemm =
-        at::from_blob(intermediate_ptr, {input.size(0), input.size(1), mlp_1_out_neurons}, options);
+        at::from_blob(intermediate_ptr,
+                      {input.size(0), input.size(1), static_cast<int64_t>(mlp_1_out_neurons)},
+                      options);
 
     auto act_func_type = static_cast<ActivationFuncType>(activation_type);
 
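Note on the static_cast<int64_t> changes above: torch::from_blob takes its sizes and strides as lists of int64_t, while GetMaxTokenLength() presumably returns an unsigned size type, so the original braced stride lists mixed signedness and apparently tripped MSVC's stricter handling of narrowing in list-initialization. A minimal stand-alone sketch of the same pattern (not DeepSpeed code; the buffer, sizes, and max_len below are made up for illustration):

// Minimal sketch (not DeepSpeed code): wrapping a raw buffer with torch::from_blob
// while keeping every stride element an int64_t, mirroring the cast added above.
// Requires a libtorch (PyTorch C++) build environment.
#include <torch/torch.h>

#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    std::vector<float> buffer(2 * 3 * 4, 0.0f);

    // Stand-in for a value that comes back from an API as size_t/unsigned.
    std::size_t max_len = 4;

    // Casting each element to int64_t keeps the braced list homogeneous, so it
    // deduces std::initializer_list<int64_t> with no signed/unsigned narrowing.
    auto t = torch::from_blob(buffer.data(),
                              {2, 3, 4},
                              {3 * static_cast<int64_t>(max_len),
                               static_cast<int64_t>(max_len),
                               1},
                              torch::TensorOptions().dtype(torch::kFloat));

    std::cout << t.sizes() << std::endl;  // prints [2, 3, 4]
    return 0;
}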
@@ -252,7 +252,6 @@ __global__ void fused_residual_ln(T* output,
     for (int i = 0; i < unRoll; i++) {
         T* iteration_buffer = local_buffer + i * T_per_load;
         T residual_buffer[T_per_load];
-        T bias_buffer[T_per_load];
 
         mem_access::load_global<ln::granularity>(
             iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row);

@@ -179,13 +179,13 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight1,
                                       SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4 *
                                       4;  // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
     // Trible-Buffer for B Tile
-    half __restrict__(*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
+    half(*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
         smem_array + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
 #ifdef PIPELINE_LEVEL_SMEM
-    half __restrict__(*read2_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
+    half(*__restrict__ read2_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
         smem_array + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
 #endif
-    half __restrict__(*write_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
+    half(*__restrict__ write_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
         smem_array +
         ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
     //
@@ -265,7 +265,7 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight1,
     }
 
 #else
-#warning "The FP6 functions are only available on Ampere GPUs."
+    assert(("The FP6 functions are only available on Ampere GPUs.", false));
 #endif
 }
 
@@ -30,7 +30,8 @@ __device__ __forceinline__ void cp_async(half* smem_ptr,
                  "l"(global_ptr),
                  "n"(SizeInBytes));
 #else
-#warning "The async copy functions are only supported on Ampere and newer architectures"
+    assert(
+        ("The async copy functions are only supported on Ampere and newer architectures", false));
 #endif
 }
 
@@ -40,7 +41,8 @@ __device__ __forceinline__ void cp_async_group_commit()
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     asm volatile("cp.async.commit_group;\n" ::);
 #else
-#warning "The async copy functions are only supported on Ampere and newer architectures"
+    assert(
+        ("The async copy functions are only supported on Ampere and newer architectures", false));
 #endif
 }
 
@@ -51,7 +53,8 @@ __device__ __forceinline__ void cp_async_wait_group()
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
 #else
-#warning "The async copy functions are only supported on Ampere and newer architectures"
+    assert(
+        ("The async copy functions are only supported on Ampere and newer architectures", false));
 #endif
 }
 
@@ -64,7 +67,8 @@ __device__ __forceinline__ void cp_async_wait_all()
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     asm volatile("cp.async.wait_all;\n" ::);
 #else
-#warning "The async copy functions are only supported on Ampere and newer architectures"
+    assert(
+        ("The async copy functions are only supported on Ampere and newer architectures", false));
 #endif
 }
 
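Note on the #warning-to-assert swaps above: MSVC has traditionally rejected #warning outright (its equivalent is #pragma message), so the unsupported-architecture branches now fail through an assert that carries the old warning text instead. A stand-alone sketch of the guard pattern (hypothetical names, not DeepSpeed code):

// Minimal sketch (assumed names, not DeepSpeed code): an architecture-guarded stub
// that still compiles on toolchains without #warning support.
#include <cassert>

void fast_path_stub()
{
#if defined(MY_FAST_PATH_AVAILABLE)  // hypothetical feature macro
    // ... real fast-path implementation would go here ...
#else
    // The comma operator makes the whole expression false while keeping the
    // message visible in the assertion failure output.
    assert(("the fast path requires MY_FAST_PATH_AVAILABLE", false));
#endif
}

int main()
{
    // Calling fast_path_stub() without the feature macro defined would trip the
    // assert at runtime (and is a no-op under NDEBUG): the pattern builds
    // everywhere and fails loudly only if the missing path is actually used.
    return 0;
}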
@@ -18,8 +18,8 @@
 #ifdef PIPELINE_LEVEL_SMEM
 template <typename TilingConfig>
 __device__ __forceinline__ void B_FromSharedToReg(
-    uint32_t __restrict__ Reg[][4],
-    half __restrict__ (*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
+    uint32_t (*__restrict__ Reg)[4],
+    half (*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
     int slice_id)
 {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
@@ -56,7 +56,8 @@ __device__ __forceinline__ void B_FromSharedToReg(
         }
     }
 #else
-#warning "The matrix load functions are only supported on Ampere and newer architectures"
+    assert(
+        ("The matrix load functions are only supported on Ampere and newer architectures", false));
 #endif
 }
 #else
@@ -64,8 +65,8 @@ __device__ __forceinline__ void B_FromSharedToReg(
 // B is in column-major
 template <typename TilingConfig>
 __device__ __forceinline__ void B_FromSharedToReg(
-    uint32_t __restrict__ Reg[][4],
-    half __restrict__ (*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
+    uint32_t (*__restrict__ Reg)[4],
+    half (*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
     int k_offset)
 {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
@@ -102,14 +103,15 @@ __device__ __forceinline__ void B_FromSharedToReg(
         }
     }
 #else
-#warning "The matrix load functions are only supported on Ampere and newer architectures"
+    assert(
+        ("The matrix load functions are only supported on Ampere and newer architectures", false));
 #endif
 }
 #endif
 
-__device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t __restrict__ c[],
-                                                   uint32_t __restrict__* a,
-                                                   uint32_t __restrict__* b)
+__device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t* __restrict__ c,
+                                                   uint32_t* __restrict__ a,
+                                                   uint32_t* __restrict__ b)
 {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     asm volatile(
@@ -130,7 +132,7 @@ __device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t __restrict__ c[],
                  "r"(c[2]),
                  "r"(c[3]));
 #else
-#warning "The mma functions are only implemented for Ampere and newer architectures"
+    assert(("The mma functions are only implemented for Ampere and newer architectures", false));
 #endif
 }
 
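Note on the signature changes above: they all move __restrict__ off the element type and onto the pointer declarator. half __restrict__ (*p)[N] tries to restrict-qualify the array element type, which the MSVC front end apparently does not accept for these pointer-to-array parameters, while half (*__restrict__ p)[N] qualifies the pointer itself. The uint32_t Reg[][4] to uint32_t (*Reg)[4] change spells the same parameter type explicitly as a pointer to array, which gives the qualifier a valid position. A small stand-alone illustration (plain C++, with float standing in for CUDA's half; RESTRICT is a portability macro used only in this sketch):

// Minimal sketch (not DeepSpeed code) of the declarator fix: the restrict
// qualifier binds to the pointer, not to the array element type.
#include <cstdio>

#if defined(_MSC_VER)
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif

constexpr int kCols = 8;

// Portable form: 'p' is a restrict-qualified pointer to float[kCols].
void fill_rows(float (*RESTRICT p)[kCols], int rows, float value)
{
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < kCols; ++c) { p[r][c] = value; }
    }
}

// The replaced form placed the qualifier on the element type instead, roughly:
//     void fill_rows(float RESTRICT (*p)[kCols], int rows, float value);

int main()
{
    float data[4][kCols] = {};
    fill_rows(data, 4, 1.5f);
    std::printf("%f\n", static_cast<double>(data[3][kCols - 1]));  // 1.500000
    return 0;
}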
@@ -32,7 +32,7 @@ __device__ __forceinline__ void initialize_mma_slice(
     uint32_t (*b)[4],
     uint32_t* __restrict__ A1_SPTR_read,
     uint32_t* __restrict__ A2_SPTR_read,
-    half __restrict__ (*B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
+    half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
     uint32_t* RPTR_Scales)
 {
     // Writing registers
@@ -54,7 +54,7 @@ __device__ __forceinline__ void core_mma_slice(
     uint32_t (*b)[4],
     uint32_t* __restrict__ A1_SPTR_read,
     uint32_t* __restrict__ A2_SPTR_read,
-    half __restrict__ (*B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
+    half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
     uint32_t* RPTR_Scales,
     int slice_id)  // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching
 {
 
@@ -57,7 +57,7 @@ __device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantSc
  */
 template <int MaxNumOfLinesToCopy, int BLOCK_WARPS>
 __device__ __forceinline__ void CopyFromGlobalToShared(
-    half __restrict__ (*SharedPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
+    half (*__restrict__ SharedPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
     const half* GlobalPTR,
     const int GlobalStride,
     const int NumOfLinesLeft,  // To support arbitrary N dimensions.
 
@@ -17,7 +17,7 @@
  * Outputs: R1, R2
  * Note: Simplified Exponent calculation is applied.
  */
-__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t* R1, u_int32_t* R2)
+__device__ __forceinline__ void FP6_FP16_Cast_4Way(uint32_t* R1, uint32_t* R2)
 {
     *R2 = *R1 & 0x80808080;
     *R1 = *R1 >> 2;
@@ -33,7 +33,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t* R1, u_int32_t* R2)
  * Outputs: R1, R2
  * Note: Simplified Exponent calculation is NOT applied.
  */
-__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t* R1, u_int32_t* R2)
+__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(uint32_t* R1, uint32_t* R2)
 {
     //*R2 = *R1 & 0x80808080;
     *R2 = *R1 & 0xc0c0c0c0;
@@ -56,7 +56,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t* R1, u_int32_
     //*R2 = 0x3c003c00;
 }
 
-__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale)
+__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale)
 {
     half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair);
     half* FP16_2 = FP16_1 + 1;
@@ -67,17 +67,17 @@ __device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Sc
     return output;
 }
 
-__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4],
-                                                    u_int32_t __restrict__* read_RPTR_Frag1,
-                                                    u_int32_t __restrict__* read_RPTR_Frag2,
-                                                    u_int32_t* Scales)
+__device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (*__restrict__ Reg)[4],
+                                                    uint32_t* __restrict__ read_RPTR_Frag1,
+                                                    uint32_t* __restrict__ read_RPTR_Frag2,
+                                                    uint32_t* Scales)
 {
-    u_int32_t* OutputRegs = reinterpret_cast<u_int32_t*>(Reg);
-    u_int32_t* Frag1_PTR = read_RPTR_Frag1;
-    u_int32_t* Frag2_PTR = read_RPTR_Frag2;
+    uint32_t* OutputRegs = reinterpret_cast<uint32_t*>(Reg);
+    uint32_t* Frag1_PTR = read_RPTR_Frag1;
+    uint32_t* Frag2_PTR = read_RPTR_Frag2;
     half* Scale_RPTR = reinterpret_cast<half*>(Scales);
-    u_int32_t Packed_FP6 = 0;
-    u_int32_t tmp = 0;
+    uint32_t Packed_FP6 = 0;
+    uint32_t tmp = 0;
     // Dequantizing 32 FP6, each Loop dequantizing 4 FP6
 #pragma unroll(8)
     for (int i = 0; i < 8; i++) {
 
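Note on the u_int32_t renames above: u_int32_t is a BSD/glibc typedef from <sys/types.h> that MSVC's headers do not provide, while uint32_t from <cstdint> is the standard fixed-width spelling and is available on every toolchain. A trivial stand-alone example of the portable type (not DeepSpeed code):

// Minimal sketch (not DeepSpeed code): standard fixed-width types from <cstdint>
// for the kind of packed 32-bit bit manipulation the kernels above perform.
#include <cinttypes>
#include <cstdint>
#include <cstdio>

// Pack two 16-bit patterns into one 32-bit word.
static std::uint32_t pack_pair(std::uint16_t lo, std::uint16_t hi)
{
    return static_cast<std::uint32_t>(lo) | (static_cast<std::uint32_t>(hi) << 16);
}

int main()
{
    std::uint32_t packed = pack_pair(0x3C00u, 0xBC00u);  // arbitrary example bits
    std::printf("0x%08" PRIX32 "\n", packed);            // prints 0xBC003C00
    return 0;
}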
@@ -18,6 +18,7 @@ BACKWARD_REDUCE_MICRO_TIMER = 'bwd_allreduce_microstep'
 BACKWARD_REDUCE_GLOBAL_TIMER = 'bwd_allreduce'
 STEP_MICRO_TIMER = 'step_microstep'
 STEP_GLOBAL_TIMER = 'step'
+TIME_EPSILON = 1e-6
 
 try:
     import psutil
@@ -262,7 +263,7 @@ class ThroughputTimer:
                     self.micro_step_count,
                     self.global_step_count,
                     self.avg_samples_per_sec(),
-                    self.batch_size / self.step_elapsed_time,
+                    self.batch_size / (self.step_elapsed_time + TIME_EPSILON),
                     round(get_accelerator().memory_allocated() / 1024**3, 2),
                     round(get_accelerator().max_memory_allocated() / 1024**3, 2),
                 ))
 
@@ -678,6 +678,7 @@ class CUDAOpBuilder(OpBuilder):
 
         if not self.build_for_cpu and self.enable_bf16:
             compile_args['cxx'].append("-DBF16_AVAILABLE")
+            compile_args['nvcc'].append("-DBF16_AVAILABLE")
 
         if self.is_rocm_pytorch():
             compile_args['cxx'].append("-D__HIP_PLATFORM_AMD__=1")
 
setup.py: 13 changes

@@ -18,6 +18,7 @@ build_win.bat
 The wheel will be located at: dist/*.whl
 """
 
+import pathlib
 import os
 import shutil
 import sys
@@ -209,9 +210,15 @@ else:
     git_branch = "unknown"
 
 if sys.platform == "win32":
-    shutil.copytree('.\\csrc', '.\\deepspeed\\ops')
-    shutil.copytree('.\\op_builder', '.\\deepspeed\\ops')
-    shutil.copytree('.\\accelerator', '.\\deepspeed\\accelerator')
+    shutil.rmtree('.\\deepspeed\\ops\\csrc', ignore_errors=True)
+    pathlib.Path('.\\deepspeed\\ops\\csrc').unlink(missing_ok=True)
+    shutil.copytree('.\\csrc', '.\\deepspeed\\ops\\csrc', dirs_exist_ok=True)
+    shutil.rmtree('.\\deepspeed\\ops\\op_builder', ignore_errors=True)
+    pathlib.Path('.\\deepspeed\\ops\\op_builder').unlink(missing_ok=True)
+    shutil.copytree('.\\op_builder', '.\\deepspeed\\ops\\op_builder', dirs_exist_ok=True)
+    shutil.rmtree('.\\deepspeed\\accelerator', ignore_errors=True)
+    pathlib.Path('.\\deepspeed\\accelerator').unlink(missing_ok=True)
+    shutil.copytree('.\\accelerator', '.\\deepspeed\\accelerator', dirs_exist_ok=True)
     egg_info.manifest_maker.template = 'MANIFEST_win.in'
 
 # Parse the DeepSpeed version string from version.txt.