# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import sys
import time
import importlib
from pathlib import Path
import subprocess
import shlex
import shutil
import tempfile
import distutils.ccompiler
import distutils.log
import distutils.sysconfig
from distutils.errors import CompileError, LinkError
from abc import ABC, abstractmethod
from typing import List

YELLOW = '\033[93m'
END = '\033[0m'
WARNING = f"{YELLOW} [WARNING] {END}"

DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions"
DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0"

try:
    import torch
except ImportError:
    print(f"{WARNING} unable to import torch, please install it if you want to pre-compile any deepspeed ops.")
else:
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])


class MissingCUDAException(Exception):
    pass


class CUDAMismatchException(Exception):
    pass


def installed_cuda_version(name=""):
    import torch.utils.cpp_extension
    cuda_home = torch.utils.cpp_extension.CUDA_HOME
    if cuda_home is None:
        raise MissingCUDAException("CUDA_HOME does not exist, unable to compile CUDA op(s)")
    # Ensure there is not a cuda version mismatch between torch and nvcc compiler
    output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True)
    output_split = output.split()
    release_idx = output_split.index("release")
    release = output_split[release_idx + 1].replace(',', '').split(".")
    # Ignore patch versions, only look at major + minor
    cuda_major, cuda_minor = release[:2]
    return int(cuda_major), int(cuda_minor)
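
# Usage sketch (illustrative, not part of the module): on a machine with a CUDA
# toolkit installed under CUDA_HOME, the parsed `nvcc -V` output yields a
# (major, minor) tuple:
#
#     major, minor = installed_cuda_version()
#     print(f"system CUDA toolkit: {major}.{minor}")  # e.g. "system CUDA toolkit: 11.8"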


def get_default_compute_capabilities():
    compute_caps = DEFAULT_COMPUTE_CAPABILITIES
    import torch.utils.cpp_extension
    if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version()[0] >= 11:
        if installed_cuda_version()[0] == 11 and installed_cuda_version()[1] == 0:
            # Special treatment of CUDA 11.0 because compute_86 is not supported.
            compute_caps += ";8.0"
        else:
            compute_caps += ";8.0;8.6"
    return compute_caps
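
# For example (illustrative, assuming the toolkits named here): with a CUDA 11.8
# toolkit this returns "6.0;6.1;7.0;8.0;8.6", with CUDA 11.0 it returns
# "6.0;6.1;7.0;8.0", and with CUDA 10.x (or no toolkit) just "6.0;6.1;7.0".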


# List compatible minor CUDA versions - so that, for example, PyTorch built with
# cuda-11.0 can be used to build DeepSpeed against a system-wide installed cuda 11.2
cuda_minor_mismatch_ok = {
    10: ["10.0", "10.1", "10.2"],
    11: ["11.0", "11.1", "11.2", "11.3", "11.4", "11.5", "11.6", "11.7", "11.8"],
    12: ["12.0", "12.1", "12.2", "12.3"],
}
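
# Sketch of how this table is consulted (illustrative only): a system toolkit of
# 11.2 and a torch built against 11.0 both appear under key 11, so the minor
# mismatch is tolerated:
#
#     ok = "11.2" in cuda_minor_mismatch_ok[11] and "11.0" in cuda_minor_mismatch_ok[11]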


def assert_no_cuda_mismatch(name=""):
    cuda_major, cuda_minor = installed_cuda_version(name)
    sys_cuda_version = f'{cuda_major}.{cuda_minor}'
    torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
    # This is a show-stopping error, should probably not proceed past this
    if sys_cuda_version != torch_cuda_version:
        if (cuda_major in cuda_minor_mismatch_ok and sys_cuda_version in cuda_minor_mismatch_ok[cuda_major]
                and torch_cuda_version in cuda_minor_mismatch_ok[cuda_major]):
            print(f"Installed CUDA version {sys_cuda_version} does not match the "
                  f"version torch was compiled with {torch.version.cuda} "
                  "but since the APIs are compatible, accepting this combination")
            return True
        elif os.getenv("DS_SKIP_CUDA_CHECK", "0") == "1":
            print(
                f"{WARNING} DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the "
                f"version torch was compiled with {torch.version.cuda}. "
                "Detected `DS_SKIP_CUDA_CHECK=1`: Allowing this combination of CUDA, but it may result in unexpected behavior."
            )
            return True
        raise CUDAMismatchException(
            f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the "
            f"version torch was compiled with {torch.version.cuda}, unable to compile "
            "cuda/cpp extensions without a matching cuda version.")
    return True
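
# Usage sketch (illustrative, not part of the module): callers that want to
# tolerate a toolkit/torch mismatch can opt out of the hard failure via the
# environment:
#
#     os.environ["DS_SKIP_CUDA_CHECK"] = "1"
#     assert_no_cuda_mismatch("my_op")  # warns instead of raising CUDAMismatchException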


class OpBuilder(ABC):
    _rocm_version = None
    _is_rocm_pytorch = None
    _is_sycl_enabled = None
    _loaded_ops = {}

    def __init__(self, name):
        self.name = name
        self.jit_mode = False
        self.build_for_cpu = False
        self.enable_bf16 = False
        self.error_log = None

    @abstractmethod
    def absolute_name(self):
        '''
        Returns absolute build path for cases where the op is pre-installed, e.g., deepspeed.ops.adam.cpu_adam
        will be installed as something like: deepspeed/ops/adam/cpu_adam.so
        '''
        pass

    @abstractmethod
    def sources(self):
        '''
        Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
        '''
        pass

    def hipify_extension(self):
        pass

    def sycl_extension(self):
        pass

    @staticmethod
    def validate_torch_version(torch_info):
        install_torch_version = torch_info['version']
        current_torch_version = ".".join(torch.__version__.split('.')[:2])
        if install_torch_version != current_torch_version:
            raise RuntimeError("PyTorch version mismatch! DeepSpeed ops were compiled and installed "
                               "with a different version than what is being used at runtime. "
                               "Please re-install DeepSpeed or switch torch versions. "
                               f"Installed torch version={install_torch_version}, "
                               f"Runtime torch version={current_torch_version}")

    @staticmethod
    def validate_torch_op_version(torch_info):
        if not OpBuilder.is_rocm_pytorch():
            current_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
            install_cuda_version = torch_info['cuda_version']
            if install_cuda_version != current_cuda_version:
                raise RuntimeError("CUDA version mismatch! DeepSpeed ops were compiled and installed "
                                   "with a different version than what is being used at runtime. "
                                   "Please re-install DeepSpeed or switch torch versions. "
                                   f"Installed CUDA version={install_cuda_version}, "
                                   f"Runtime CUDA version={current_cuda_version}")
        else:
            current_hip_version = ".".join(torch.version.hip.split('.')[:2])
            install_hip_version = torch_info['hip_version']
            if install_hip_version != current_hip_version:
                raise RuntimeError("HIP version mismatch! DeepSpeed ops were compiled and installed "
                                   "with a different version than what is being used at runtime. "
                                   "Please re-install DeepSpeed or switch torch versions. "
                                   f"Installed HIP version={install_hip_version}, "
                                   f"Runtime HIP version={current_hip_version}")

    @staticmethod
    def is_rocm_pytorch():
        if OpBuilder._is_rocm_pytorch is not None:
            return OpBuilder._is_rocm_pytorch

        _is_rocm_pytorch = False
        try:
            import torch
        except ImportError:
            pass
        else:
            if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5):
                _is_rocm_pytorch = hasattr(torch.version, 'hip') and torch.version.hip is not None
                if _is_rocm_pytorch:
                    from torch.utils.cpp_extension import ROCM_HOME
                    _is_rocm_pytorch = ROCM_HOME is not None
        OpBuilder._is_rocm_pytorch = _is_rocm_pytorch
        return OpBuilder._is_rocm_pytorch
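
    # Illustrative check (sketch, not part of the module): the detection result
    # is cached on the class, so repeated calls are cheap:
    #
    #     if OpBuilder.is_rocm_pytorch():
    #         print("building against a ROCm/HIP build of torch")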

    @staticmethod
    def is_sycl_enabled():
        if OpBuilder._is_sycl_enabled is not None:
            return OpBuilder._is_sycl_enabled

        _is_sycl_enabled = False
        try:
            # Probe for the SYCL compatibility tool (c2s) on PATH.
            subprocess.run(["c2s", "--version"], capture_output=True)
        except (OSError, subprocess.SubprocessError):
            pass
        else:
            _is_sycl_enabled = True

        OpBuilder._is_sycl_enabled = _is_sycl_enabled
        return OpBuilder._is_sycl_enabled

    @staticmethod
    def installed_rocm_version():
        if OpBuilder._rocm_version:
            return OpBuilder._rocm_version

        ROCM_MAJOR = '0'
        ROCM_MINOR = '0'
        if OpBuilder.is_rocm_pytorch():
            from torch.utils.cpp_extension import ROCM_HOME
            rocm_ver_file = Path(ROCM_HOME).joinpath(".info/version-dev")
            if rocm_ver_file.is_file():
                with open(rocm_ver_file, 'r') as file:
                    ROCM_VERSION_DEV_RAW = file.read()
            elif "rocm" in torch.__version__:
                ROCM_VERSION_DEV_RAW = torch.__version__.split("rocm")[1]
            else:
                assert False, "Could not detect ROCm version"
            assert ROCM_VERSION_DEV_RAW != "", "Could not detect ROCm version"
            ROCM_MAJOR = ROCM_VERSION_DEV_RAW.split('.')[0]
            ROCM_MINOR = ROCM_VERSION_DEV_RAW.split('.')[1]
        OpBuilder._rocm_version = (int(ROCM_MAJOR), int(ROCM_MINOR))
        return OpBuilder._rocm_version
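
    # Example (illustrative): on a ROCm 5.4 install the cached tuple is (5, 4);
    # on non-ROCm systems it is (0, 0) because the probe above is skipped:
    #
    #     rocm_major, rocm_minor = OpBuilder.installed_rocm_version()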

    def include_paths(self):
        '''
        Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
        '''
        return []

    def nvcc_args(self):
        '''
        Returns optional list of compiler flags to forward to nvcc when building CUDA sources
        '''
        return []

    def cxx_args(self):
        '''
        Returns optional list of compiler flags to forward to the build
        '''
        return []

    def is_compatible(self, verbose=True):
        '''
        Check if all non-python dependencies are satisfied to build this op
        '''
        return True

    def extra_ldflags(self):
        return []

    def has_function(self, funcname, libraries, verbose=False):
        '''
        Test for existence of a function within a tuple of libraries.

        This is used as a smoke test to check whether a certain library is available.
        As a test, this creates a simple C program that calls the specified function,
        and then distutils is used to compile that program and link it with the specified libraries.
        Returns True if both the compile and link are successful, False otherwise.
        '''
        tempdir = None  # we create a temporary directory to hold various files
        filestderr = None  # handle to open file to which we redirect stderr
        oldstderr = None  # file descriptor for stderr
        try:
            # Echo compile and link commands that are used.
            if verbose:
                distutils.log.set_verbosity(1)

            # Create a compiler object.
            compiler = distutils.ccompiler.new_compiler(verbose=verbose)

            # Configure compiler and linker to build according to Python install.
            distutils.sysconfig.customize_compiler(compiler)

            # Create a temporary directory to hold test files.
            tempdir = tempfile.mkdtemp()

            # Define a simple C program that calls the function in question
            prog = "void %s(void); int main(int argc, char** argv) { %s(); return 0; }" % (funcname, funcname)

            # Write the test program to a file.
            filename = os.path.join(tempdir, 'test.c')
            with open(filename, 'w') as f:
                f.write(prog)

            # Redirect stderr file descriptor to a file to silence compile/link warnings.
            if not verbose:
                filestderr = open(os.path.join(tempdir, 'stderr.txt'), 'w')
                oldstderr = os.dup(sys.stderr.fileno())
                os.dup2(filestderr.fileno(), sys.stderr.fileno())

            # Workaround for behavior in distutils.ccompiler.CCompiler.object_filenames()
            # Otherwise, a local directory will be used instead of tempdir
            drive, driveless_filename = os.path.splitdrive(filename)
            root_dir = driveless_filename[0] if os.path.isabs(driveless_filename) else ''
            output_dir = os.path.join(drive, root_dir)

            # Attempt to compile the C program into an object file.
            cflags = shlex.split(os.environ.get('CFLAGS', ""))
            objs = compiler.compile([filename], output_dir=output_dir, extra_preargs=self.strip_empty_entries(cflags))

            # Attempt to link the object file into an executable.
            # Be sure to tack on any libraries that have been specified.
            ldflags = shlex.split(os.environ.get('LDFLAGS', ""))
            compiler.link_executable(objs,
                                     os.path.join(tempdir, 'a.out'),
                                     extra_preargs=self.strip_empty_entries(ldflags),
                                     libraries=libraries)

            # Compile and link succeeded
            return True

        except CompileError:
            return False

        except LinkError:
            return False

        except Exception:
            return False

        finally:
            # Restore stderr file descriptor and close the stderr redirect file.
            if oldstderr is not None:
                os.dup2(oldstderr, sys.stderr.fileno())
            if filestderr is not None:
                filestderr.close()

            # Delete the temporary directory holding the test program and stderr files.
            if tempdir is not None:
                shutil.rmtree(tempdir)
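
    # Usage sketch (illustrative; assumes a system libaio is what you are
    # probing for): an op builder can gate compilation on a library symbol:
    #
    #     if not self.has_function('io_submit', ('aio', )):
    #         self.warning("libaio not found, skipping this op")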

    def strip_empty_entries(self, args):
        '''
        Drop any empty strings from the list of compile and link flags
        '''
        return [x for x in args if len(x) > 0]

    def cpu_arch(self):
        try:
            from cpuinfo import get_cpu_info
        except ImportError as e:
            cpu_info = self._backup_cpuinfo()
            if cpu_info is None:
                return "-march=native"

        try:
            cpu_info = get_cpu_info()
        except Exception as e:
            self.warning(f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), "
                         "falling back to `lscpu` to get this information.")
            cpu_info = self._backup_cpuinfo()
            if cpu_info is None:
                return "-march=native"

        if cpu_info['arch'].startswith('PPC_'):
            # gcc does not provide -march on PowerPC, use -mcpu instead
            return '-mcpu=native'
        return '-march=native'

    def is_cuda_enable(self):
        try:
            assert_no_cuda_mismatch(self.name)
            return '-D__ENABLE_CUDA__'
        except MissingCUDAException:
            print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, "
                  "only cpu ops can be compiled!")
        return '-D__DISABLE_CUDA__'

    def _backup_cpuinfo(self):
        # Construct cpu_info dict from lscpu that is similar to what py-cpuinfo provides
        if not self.command_exists('lscpu'):
            self.warning(f"{self.name} attempted to query 'lscpu' after failing to use py-cpuinfo "
                         "to detect the CPU architecture. 'lscpu' does not appear to exist on "
                         "your system, will fall back to use -march=native and non-vectorized execution.")
            return None
        result = subprocess.check_output('lscpu', shell=True)
        result = result.decode('utf-8').strip().lower()

        cpu_info = {}
        cpu_info['arch'] = None
        cpu_info['flags'] = ""
        if 'genuineintel' in result or 'authenticamd' in result:
            cpu_info['arch'] = 'X86_64'
            # Any AVX-512 feature flag (e.g. avx512f) contains the substring
            # 'avx512', so a single check covers the whole family.
            if 'avx512' in result:
                cpu_info['flags'] += 'avx512,'
            if 'avx2' in result:
                cpu_info['flags'] += 'avx2'
        elif 'ppc64le' in result:
            cpu_info['arch'] = "PPC_"

        return cpu_info

    def simd_width(self):
        try:
            from cpuinfo import get_cpu_info
        except ImportError as e:
            cpu_info = self._backup_cpuinfo()
            if cpu_info is None:
                return '-D__SCALAR__'

        try:
            cpu_info = get_cpu_info()
        except Exception as e:
            self.warning(f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), "
                         "falling back to `lscpu` to get this information.")
            cpu_info = self._backup_cpuinfo()
            if cpu_info is None:
                return '-D__SCALAR__'

        if cpu_info['arch'] == 'X86_64':
            if 'avx512' in cpu_info['flags'] or 'avx512f' in cpu_info['flags']:
                return '-D__AVX512__'
            elif 'avx2' in cpu_info['flags']:
                return '-D__AVX256__'
        return '-D__SCALAR__'
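
    # Illustrative mapping (sketch): on an AVX-512 capable Xeon this returns
    # '-D__AVX512__', on a typical AVX2 desktop '-D__AVX256__', and on anything
    # else (or when detection fails entirely) the scalar fallback:
    #
    #     print(builder.simd_width())  # e.g. '-D__AVX256__'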

    def command_exists(self, cmd):
        if '|' in cmd:
            cmds = cmd.split("|")
        else:
            cmds = [cmd]
        valid = False
        for cmd in cmds:
            result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
            valid = valid or result.wait() == 0

        if not valid and len(cmds) > 1:
            print(f"{WARNING} {self.name} requires one of the following commands '{cmds}', but none exist!")
        elif not valid and len(cmds) == 1:
            print(f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!")
        return valid
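
    # Usage sketch (illustrative): several commands may be OR-ed together with
    # '|', mirroring how some builders probe for alternatives:
    #
    #     self.command_exists('nvcc')        # single command
    #     self.command_exists('hipcc|nvcc')  # True if either is on PATH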

    def warning(self, msg):
        self.error_log = f"{msg}"
        print(f"{WARNING} {msg}")

    def deepspeed_src_path(self, code_path):
        if os.path.isabs(code_path):
            return code_path
        else:
            return os.path.join(Path(__file__).parent.parent.absolute(), code_path)

    def builder(self):
        from torch.utils.cpp_extension import CppExtension
        return CppExtension(name=self.absolute_name(),
                            sources=self.strip_empty_entries(self.sources()),
                            include_dirs=self.strip_empty_entries(self.include_paths()),
                            extra_compile_args={'cxx': self.strip_empty_entries(self.cxx_args())},
                            extra_link_args=self.strip_empty_entries(self.extra_ldflags()))

    def load(self, verbose=True):
        if self.name in __class__._loaded_ops:
            return __class__._loaded_ops[self.name]

        from deepspeed.git_version_info import installed_ops, torch_info
        if installed_ops.get(self.name, False):
            # Ensure the op we're about to load was compiled with the same
            # torch/cuda versions we are currently using at runtime.
            self.validate_torch_version(torch_info)
            if torch.cuda.is_available() and isinstance(self, CUDAOpBuilder):
                self.validate_torch_op_version(torch_info)

            op_module = importlib.import_module(self.absolute_name())
            __class__._loaded_ops[self.name] = op_module
            return op_module
        else:
            return self.jit_load(verbose)
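
    # End-to-end usage sketch (illustrative; `CPUAdamBuilder` stands in for any
    # concrete OpBuilder subclass, and the call into the loaded module is
    # hypothetical):
    #
    #     from deepspeed.ops.op_builder import CPUAdamBuilder
    #     cpu_adam = CPUAdamBuilder().load()  # pre-built module, or JIT compile
    #     # cpu_adam now exposes the compiled op's functions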

    def jit_load(self, verbose=True):
        if not self.is_compatible(verbose):
            raise RuntimeError(
                f"Unable to JIT load the {self.name} op due to a hardware/software compatibility issue. {self.error_log}"
            )
        try:
            import ninja  # noqa: F401 # type: ignore
        except ImportError:
            raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.")

        if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch():
            self.build_for_cpu = not torch.cuda.is_available()

        self.jit_mode = True
        from torch.utils.cpp_extension import load

        start_build = time.time()
        # Resolve any '..' in the source and include paths using os.path.abspath()
        sources = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.sources()]
        extra_include_paths = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.include_paths()]

        # Torch will try to apply whatever CCs are in the arch list at compile time;
        # since we have already set the intended targets ourselves, we know what will
        # be needed at runtime. This prevents CC collisions such as multiple __half
        # implementations. Stash the arch list to reset after the build.
        torch_arch_list = None
        if "TORCH_CUDA_ARCH_LIST" in os.environ:
            torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
            os.environ["TORCH_CUDA_ARCH_LIST"] = ""

        nvcc_args = self.strip_empty_entries(self.nvcc_args())
        cxx_args = self.strip_empty_entries(self.cxx_args())

        if isinstance(self, CUDAOpBuilder):
            if not self.build_for_cpu and self.enable_bf16:
                cxx_args.append("-DBF16_AVAILABLE")
                nvcc_args.append("-DBF16_AVAILABLE")
                nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__")
                nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__")

        if self.is_rocm_pytorch():
            cxx_args.append("-D__HIP_PLATFORM_AMD__=1")

        op_module = load(name=self.name,
                         sources=self.strip_empty_entries(sources),
                         extra_include_paths=self.strip_empty_entries(extra_include_paths),
                         extra_cflags=cxx_args,
                         extra_cuda_cflags=nvcc_args,
                         extra_ldflags=self.strip_empty_entries(self.extra_ldflags()),
                         verbose=verbose)

        build_duration = time.time() - start_build
        if verbose:
            print(f"Time to load {self.name} op: {build_duration} seconds")

        # Reset arch list so we are not silently removing it for other possible use cases
        if torch_arch_list:
            os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list

        __class__._loaded_ops[self.name] = op_module

        return op_module


class CUDAOpBuilder(OpBuilder):

    def compute_capability_args(self, cross_compile_archs=None):
        """
        Returns nvcc compute capability compile flags.

        1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
        2. If neither is set, the default compute capabilities will be used.
        3. Under `jit_mode`, the compute capabilities of all visible cards will be used, plus PTX.

        Format:

        - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:

          TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
          TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...

        - `cross_compile_archs` uses the ; separator.
        """
        ccs = []
        if self.jit_mode:
            # Compile for underlying architectures since we know those at runtime
            for i in range(torch.cuda.device_count()):
                CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i)
                cc = f"{CC_MAJOR}.{CC_MINOR}"
                if cc not in ccs:
                    ccs.append(cc)
            ccs = sorted(ccs)
            ccs[-1] += '+PTX'
        else:
            # Cross-compile mode, compile for various architectures
            # env override takes priority
            cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
            if cross_compile_archs_env is not None:
                if cross_compile_archs is not None:
                    print(
                        f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`"
                    )
                cross_compile_archs = cross_compile_archs_env.replace(' ', ';')
            else:
                if cross_compile_archs is None:
                    cross_compile_archs = get_default_compute_capabilities()
            ccs = cross_compile_archs.split(';')

        ccs = self.filter_ccs(ccs)
        if len(ccs) == 0:
            raise RuntimeError(
                f"Unable to load {self.name} op due to no compute capabilities remaining after filtering")

        args = []
        self.enable_bf16 = True
        for cc in ccs:
            # e.g. '8.6' -> '86' (assumes single-digit major and minor)
            num = cc[0] + cc[2]
            args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
            if cc.endswith('+PTX'):
                args.append(f'-gencode=arch=compute_{num},code=compute_{num}')

            if int(cc[0]) <= 7:
                self.enable_bf16 = False

        return args
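
    # Worked example (illustrative): for ccs == ['7.0', '8.6+PTX'] the emitted
    # flags are:
    #
    #     -gencode=arch=compute_70,code=sm_70
    #     -gencode=arch=compute_86,code=sm_86
    #     -gencode=arch=compute_86,code=compute_86
    #
    # and enable_bf16 ends up False because a CC-7.x target is present.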

    def filter_ccs(self, ccs: List[str]):
        """
        Prune any compute capabilities that are not compatible with the builder. Should log
        which CCs have been pruned.
        """
        return ccs

    def version_dependent_macros(self):
        # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
        version_ge_1_1 = []
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
            version_ge_1_1 = ['-DVERSION_GE_1_1']
        version_ge_1_3 = []
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
            version_ge_1_3 = ['-DVERSION_GE_1_3']
        version_ge_1_5 = []
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
            version_ge_1_5 = ['-DVERSION_GE_1_5']
        return version_ge_1_1 + version_ge_1_3 + version_ge_1_5

    def is_compatible(self, verbose=True):
        return super().is_compatible(verbose)

    def builder(self):
        try:
            if not self.is_rocm_pytorch():
                assert_no_cuda_mismatch(self.name)
            self.build_for_cpu = False
        except MissingCUDAException:
            self.build_for_cpu = True

        if self.build_for_cpu:
            from torch.utils.cpp_extension import CppExtension as ExtensionBuilder
        else:
            from torch.utils.cpp_extension import CUDAExtension as ExtensionBuilder

        compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} if self.build_for_cpu else \
                       {'cxx': self.strip_empty_entries(self.cxx_args()),
                        'nvcc': self.strip_empty_entries(self.nvcc_args())}

        if not self.build_for_cpu and self.enable_bf16:
            compile_args['cxx'].append("-DBF16_AVAILABLE")

        if self.is_rocm_pytorch():
            compile_args['cxx'].append("-D__HIP_PLATFORM_AMD__=1")

        cuda_ext = ExtensionBuilder(name=self.absolute_name(),
                                    sources=self.strip_empty_entries(self.sources()),
                                    include_dirs=self.strip_empty_entries(self.include_paths()),
                                    libraries=self.strip_empty_entries(self.libraries_args()),
                                    extra_compile_args=compile_args,
                                    extra_link_args=self.strip_empty_entries(self.extra_ldflags()))

        if self.is_rocm_pytorch():
            # hip converts paths to absolute, this converts back to relative
            sources = cuda_ext.sources
            curr_file = Path(__file__).parent.parent  # ds root
            for i in range(len(sources)):
                src = Path(sources[i])
                if src.is_absolute():
                    sources[i] = str(src.relative_to(curr_file))
                else:
                    sources[i] = str(src)
            cuda_ext.sources = sources
        return cuda_ext

    def hipify_extension(self):
        if self.is_rocm_pytorch():
            from torch.utils.hipify import hipify_python
            hipify_python.hipify(
                project_directory=os.getcwd(),
                output_directory=os.getcwd(),
                header_include_dirs=self.include_paths(),
                includes=[os.path.join(os.getcwd(), '*')],
                extra_files=[os.path.abspath(s) for s in self.sources()],
                show_detailed=True,
                is_pytorch_extension=True,
                hipify_extra_files_only=True,
            )

    def cxx_args(self):
        if sys.platform == "win32":
            return ['-O2']
        else:
            return ['-O3', '-std=c++17', '-g', '-Wno-reorder']

    def nvcc_args(self):
        if self.build_for_cpu:
            return []
        args = ['-O3']
        if self.is_rocm_pytorch():
            ROCM_MAJOR, ROCM_MINOR = self.installed_rocm_version()
            args += [
                '-std=c++17', '-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_CONVERSIONS__',
                '-U__HIP_NO_HALF2_OPERATORS__',
                '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR,
                '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR
            ]
        else:
            cuda_major, _ = installed_cuda_version()
            args += [
                '-allow-unsupported-compiler' if sys.platform == "win32" else '', '--use_fast_math',
                '-std=c++17' if cuda_major > 10 else '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__',
                '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__'
            ]
        if os.environ.get('DS_DEBUG_CUDA_BUILD', '0') == '1':
            args.append('--ptxas-options=-v')
        args += self.compute_capability_args()
        return args
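
    # Illustrative output (sketch; CUDA 11.x on Linux with one CC-8.0 GPU in
    # jit_mode) - nvcc_args() would resemble:
    #
    #     ['-O3', '--use_fast_math', '-std=c++17',
    #      '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
    #      '-U__CUDA_NO_HALF2_OPERATORS__',
    #      '-gencode=arch=compute_80,code=sm_80',
    #      '-gencode=arch=compute_80,code=compute_80']
    #
    # (the empty '' placeholder for the Windows-only flag is dropped later by
    # strip_empty_entries).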

    def libraries_args(self):
        if self.build_for_cpu:
            return []

        if sys.platform == "win32":
            return ['cublas', 'curand']
        else:
            return []


class TorchCPUOpBuilder(CUDAOpBuilder):

    def extra_ldflags(self):
        if self.build_for_cpu:
            return ['-fopenmp']

        if not self.is_rocm_pytorch():
            return ['-lcurand']

        return []

    def cxx_args(self):
        import torch
        args = []
        if not self.build_for_cpu:
            if not self.is_rocm_pytorch():
                CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64")
                if not os.path.exists(CUDA_LIB64):
                    CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib")
            else:
                CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib")

            args += super().cxx_args()
            args += [
                f'-L{CUDA_LIB64}',
                '-lcudart',
                '-lcublas',
                '-g',
            ]

        CPU_ARCH = self.cpu_arch()
        SIMD_WIDTH = self.simd_width()
        CUDA_ENABLE = self.is_cuda_enable()
        args += [
            CPU_ARCH,
            '-fopenmp',
            SIMD_WIDTH,
            CUDA_ENABLE,
        ]

        return args