зеркало из https://github.com/microsoft/hat.git
Updates CUDA compilation to use device compute capability (#57)
This commit is contained in:
Родитель
476ecffaac
Коммит
aea0a8d3fc
|
@@ -23,7 +23,7 @@ def ASSERT_DRV(err):
|
||||||
def _find_cuda_incl_path() -> pathlib.Path:
|
def _find_cuda_incl_path() -> pathlib.Path:
|
||||||
"Tries to find the CUDA include path."
|
"Tries to find the CUDA include path."
|
||||||
cuda_path = os.getenv("CUDA_PATH")
|
cuda_path = os.getenv("CUDA_PATH")
|
||||||
if not cuda_path:
|
if cuda_path is None:
|
||||||
if sys.platform == 'linux':
|
if sys.platform == 'linux':
|
||||||
cuda_path = pathlib.Path("/usr/local/cuda/include")
|
cuda_path = pathlib.Path("/usr/local/cuda/include")
|
||||||
if not (cuda_path.exists() and cuda_path.is_dir()):
|
if not (cuda_path.exists() and cuda_path.is_dir()):
|
||||||
|
@@ -33,19 +33,33 @@ def _find_cuda_incl_path() -> pathlib.Path:
|
||||||
elif sys.platform == 'darwin':
|
elif sys.platform == 'darwin':
|
||||||
...
|
...
|
||||||
else:
|
else:
|
||||||
|
cuda_path = pathlib.Path(cuda_path)
|
||||||
cuda_path /= "include"
|
cuda_path /= "include"
|
||||||
|
|
||||||
return cuda_path
|
return cuda_path
|
||||||
|
|
||||||
def compile_cuda_program(cuda_src_path: pathlib.Path, func_name):
|
def _get_compute_capability(gpu_id) -> int:
|
||||||
|
err, major = cuda.cuDeviceGetAttribute(cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, gpu_id)
|
||||||
|
ASSERT_DRV(err)
|
||||||
|
|
||||||
|
err, minor = cuda.cuDeviceGetAttribute(cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, gpu_id)
|
||||||
|
ASSERT_DRV(err)
|
||||||
|
|
||||||
|
return (major * 10) + minor
|
||||||
|
|
||||||
|
def compile_cuda_program(cuda_src_path: pathlib.Path, func_name, gpu_id):
|
||||||
src = cuda_src_path.read_text()
|
src = cuda_src_path.read_text()
|
||||||
|
|
||||||
|
cuda_incl_path = _find_cuda_incl_path()
|
||||||
|
if not cuda_incl_path:
|
||||||
|
raise RuntimeError("Unable to determine CUDA include path. Please set CUDA_PATH environment variable.")
|
||||||
|
|
||||||
opts = [
|
opts = [
|
||||||
# https://docs.nvidia.com/cuda/nvrtc/index.html#group__options
|
# https://docs.nvidia.com/cuda/nvrtc/index.html#group__options
|
||||||
b'--gpu-architecture=compute_86',
|
f'--gpu-architecture=compute_{_get_compute_capability(gpu_id)}'.encode(),
|
||||||
b'--ptxas-options=--warn-on-spills', # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-passing-specific-phase-options-ptxas-options
|
b'--ptxas-options=--warn-on-spills', # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-passing-specific-phase-options-ptxas-options
|
||||||
b'-use_fast_math',
|
b'-use_fast_math',
|
||||||
b'--include-path=/usr/local/cuda-11.6/include/',
|
b'--include-path=' + str(cuda_incl_path).encode(),
|
||||||
b'-std=c++17',
|
b'-std=c++17',
|
||||||
b'-default-device',
|
b'-default-device',
|
||||||
#b'--restrict',
|
#b'--restrict',
|
||||||
|
@@ -184,7 +198,7 @@ class CudaCallableFunc(CallableFunc):
|
||||||
|
|
||||||
ptx = _PTX_CACHE.get(self.cuda_src_path)
|
ptx = _PTX_CACHE.get(self.cuda_src_path)
|
||||||
if not ptx:
|
if not ptx:
|
||||||
_PTX_CACHE[self.cuda_src_path] = ptx = compile_cuda_program(self.cuda_src_path, self.func_name)
|
_PTX_CACHE[self.cuda_src_path] = ptx = compile_cuda_program(self.cuda_src_path, self.func_name, gpu_id)
|
||||||
|
|
||||||
self.kernel = get_func_from_ptx(ptx, self.func_name)
|
self.kernel = get_func_from_ptx(ptx, self.func_name)
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче