Updates CUDA compilation to use device compute capability (#57)

This commit is contained in:
Kern Handa 2022-06-01 13:51:19 -07:00 коммит произвёл GitHub
Родитель 476ecffaac
Коммит aea0a8d3fc
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 19 добавлений и 5 удалений

Просмотреть файл

@ -23,7 +23,7 @@ def ASSERT_DRV(err):
def _find_cuda_incl_path() -> pathlib.Path:
"Tries to find the CUDA include path."
cuda_path = os.getenv("CUDA_PATH")
if not cuda_path:
if cuda_path is None:
if sys.platform == 'linux':
cuda_path = pathlib.Path("/usr/local/cuda/include")
if not (cuda_path.exists() and cuda_path.is_dir()):
@ -33,19 +33,33 @@ def _find_cuda_incl_path() -> pathlib.Path:
elif sys.platform == 'darwin':
...
else:
cuda_path = pathlib.Path(cuda_path)
cuda_path /= "include"
return cuda_path
def compile_cuda_program(cuda_src_path: pathlib.Path, func_name):
def _get_compute_capability(gpu_id) -> int:
err, major = cuda.cuDeviceGetAttribute(cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, gpu_id)
ASSERT_DRV(err)
err, minor = cuda.cuDeviceGetAttribute(cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, gpu_id)
ASSERT_DRV(err)
return (major * 10) + minor
def compile_cuda_program(cuda_src_path: pathlib.Path, func_name, gpu_id):
src = cuda_src_path.read_text()
cuda_incl_path = _find_cuda_incl_path()
if not cuda_incl_path:
raise RuntimeError("Unable to determine CUDA include path. Please set CUDA_PATH environment variable.")
opts = [
# https://docs.nvidia.com/cuda/nvrtc/index.html#group__options
b'--gpu-architecture=compute_86',
f'--gpu-architecture=compute_{_get_compute_capability(gpu_id)}'.encode(),
b'--ptxas-options=--warn-on-spills', # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-passing-specific-phase-options-ptxas-options
b'-use_fast_math',
b'--include-path=/usr/local/cuda-11.6/include/',
b'--include-path=' + str(cuda_incl_path).encode(),
b'-std=c++17',
b'-default-device',
#b'--restrict',
@ -184,7 +198,7 @@ class CudaCallableFunc(CallableFunc):
ptx = _PTX_CACHE.get(self.cuda_src_path)
if not ptx:
_PTX_CACHE[self.cuda_src_path] = ptx = compile_cuda_program(self.cuda_src_path, self.func_name)
_PTX_CACHE[self.cuda_src_path] = ptx = compile_cuda_program(self.cuda_src_path, self.func_name, gpu_id)
self.kernel = get_func_from_ptx(ptx, self.func_name)