# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import CUDAOpBuilder
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerBuilder(CUDAOpBuilder):
    """Op builder describing how the ``transformer`` CUDA extension is built.

    The class only supplies build metadata: the op name, the dotted module
    name of the compiled op, the source files, the include directories and
    any extra linker flags. Everything else comes from ``CUDAOpBuilder``.
    """

    # Identifier associated with this op's build switch.
    BUILD_VAR = "DS_BUILD_TRANSFORMER"
    # Default op name; also used to derive the compiled module name.
    NAME = "transformer"

    def __init__(self, name=None):
        """Initialise the builder, defaulting ``name`` to the class ``NAME``.

        Args:
            name: Optional op name. When ``None``, ``self.NAME`` is used.
        """
        if name is None:
            name = self.NAME
        super().__init__(name=name)

    def absolute_name(self):
        """Return the dotted module path of the compiled op.

        The path is derived from the class-level ``NAME`` rather than from
        the ``name`` argument given to ``__init__``.
        """
        op_module = f'{self.NAME}_op'
        return f'deepspeed.ops.transformer.{op_module}'

    def extra_ldflags(self):
        """Return the extra linker flags for this op.

        ``-lcurand`` is requested unless ``is_rocm_pytorch()`` reports a ROCm
        build of PyTorch, in which case no extra flags are added.
        """
        return [] if self.is_rocm_pytorch() else ['-lcurand']

    def sources(self):
        """Return the list of C++/CUDA source files that make up the op."""
        transformer_sources = [
            'csrc/transformer/ds_transformer_cuda.cpp',
            'csrc/transformer/cublas_wrappers.cu',
            'csrc/transformer/transform_kernels.cu',
            'csrc/transformer/gelu_kernels.cu',
            'csrc/transformer/dropout_kernels.cu',
            'csrc/transformer/normalize_kernels.cu',
            'csrc/transformer/softmax_kernels.cu',
            'csrc/transformer/general_kernels.cu',
        ]
        return transformer_sources

    def include_paths(self):
        """Return the include directories needed to compile the sources."""
        return ['csrc/includes']
|