'''Copyright The Microsoft DeepSpeed Team'''

import os
import pkgutil
import importlib

from .abstract_accelerator import DeepSpeedAccelerator

# During the setup stage torch may not be installed; passing when torch is
# missing allows the op builder related APIs to still be executed.
try:
    import torch.cuda
except ImportError:
    pass

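# A minimal usage sketch (assuming the deepspeed.accelerator.get_accelerator()
# helper, which returns an instance of this class on CUDA machines):
#   from deepspeed.accelerator import get_accelerator
#   if get_accelerator().is_available():
#       get_accelerator().set_device(0)
#       x = get_accelerator().FloatTensor(8).fill_(1.0)  # tensor on 'cuda:0'
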
class CUDA_Accelerator(DeepSpeedAccelerator):
    def __init__(self):
        self._name = 'cuda'
        self._communication_backend_name = 'nccl'

        # begin initialize for create_op_builder()
        # put all valid class name <--> class type mapping into class_dict
        op_builder_dir = self.op_builder_dir()
        op_builder_module = importlib.import_module(op_builder_dir)

        for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]):
            # avoid self references
            if module_name != 'all_ops' and module_name != 'builder':
                module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
                for member_name in module.__dir__():
                    # avoid abstract classes
                    if member_name.endswith('Builder') and member_name not in ("OpBuilder", "CUDAOpBuilder", "TorchCPUOpBuilder"):
                        if member_name not in self.class_dict:
                            self.class_dict[member_name] = getattr(module, member_name)
        # end initialize for create_op_builder()
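        # at this point class_dict maps builder names to builder classes, e.g.
        # (assuming the stock DeepSpeed op_builder package is importable):
        #   {'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>,
        #    'FusedAdamBuilder': <class 'op_builder.fused_adam.FusedAdamBuilder'>, ...}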

    # Device APIs
    def device_name(self, device_index=None):
        if device_index is None:
            return 'cuda'
        return 'cuda:{}'.format(device_index)

    def device(self, device_index=None):
        return torch.cuda.device(device_index)

    def set_device(self, device_index):
        torch.cuda.set_device(device_index)

    def current_device(self):
        return torch.cuda.current_device()

    def current_device_name(self):
        return 'cuda:{}'.format(torch.cuda.current_device())

    def device_count(self):
        return torch.cuda.device_count()

    def synchronize(self, device_index=None):
        return torch.cuda.synchronize(device_index)
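
    # A minimal sketch of the device APIs (assuming at least one CUDA device):
    #   acc = CUDA_Accelerator()
    #   acc.set_device(0)
    #   assert acc.current_device_name() == 'cuda:0'
    #   acc.synchronize()  # wait for all kernels on the current device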

    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        if device_index is None:
            return torch.cuda.set_rng_state(new_state)
        return torch.cuda.set_rng_state(new_state, device_index)

    def get_rng_state(self, device_index=None):
        if device_index is None:
            return torch.cuda.get_rng_state()
        return torch.cuda.get_rng_state(device_index)

    def manual_seed(self, seed):
        return torch.cuda.manual_seed(seed)

    def manual_seed_all(self, seed):
        return torch.cuda.manual_seed_all(seed)

    def initial_seed(self, seed):
        # torch.cuda.initial_seed() takes no arguments; forwarding seed here
        # raised a TypeError
        return torch.cuda.initial_seed()

    def default_generator(self, device_index):
        return torch.cuda.default_generators[device_index]
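
    # e.g. checkpoint/restore of the CUDA RNG state (a sketch):
    #   state = acc.get_rng_state()
    #   ...draw random numbers...
    #   acc.set_rng_state(state)  # restores the saved state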

    # Streams/Events
    @property
    def Stream(self):
        return torch.cuda.Stream

    def stream(self, stream):
        return torch.cuda.stream(stream)

    def current_stream(self, device_index=None):
        return torch.cuda.current_stream(device_index)

    def default_stream(self, device_index=None):
        return torch.cuda.default_stream(device_index)

    @property
    def Event(self):
        return torch.cuda.Event
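
    # e.g. running work on a side stream (a sketch):
    #   s = acc.Stream()
    #   with acc.stream(s):
    #       ...  # kernels launched here are queued on s
    #   acc.current_stream().wait_stream(s)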

    # Memory management
    def empty_cache(self):
        return torch.cuda.empty_cache()

    def memory_allocated(self, device_index=None):
        return torch.cuda.memory_allocated(device_index)

    def max_memory_allocated(self, device_index=None):
        return torch.cuda.max_memory_allocated(device_index)

    def reset_max_memory_allocated(self, device_index=None):
        return torch.cuda.reset_max_memory_allocated(device_index)

    def memory_cached(self, device_index=None):
        return torch.cuda.memory_cached(device_index)

    def max_memory_cached(self, device_index=None):
        return torch.cuda.max_memory_cached(device_index)

    def reset_max_memory_cached(self, device_index=None):
        return torch.cuda.reset_max_memory_cached(device_index)

    def memory_stats(self, device_index=None):
        # the hasattr guards below let older torch versions that lack these
        # APIs fall through and return None
        if hasattr(torch.cuda, 'memory_stats'):
            return torch.cuda.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        if hasattr(torch.cuda, 'reset_peak_memory_stats'):
            return torch.cuda.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, 'memory_reserved'):
            return torch.cuda.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, 'max_memory_reserved'):
            return torch.cuda.max_memory_reserved(device_index)

    def total_memory(self, device_index=None):
        return torch.cuda.get_device_properties(device_index).total_memory
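
    # e.g. peak-memory bookkeeping around a training step (a sketch):
    #   acc.reset_peak_memory_stats()
    #   ...forward/backward...
    #   peak_bytes = acc.max_memory_allocated()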

    # Data types
    def is_bf16_supported(self):
        return torch.cuda.is_bf16_supported()

    def is_fp16_supported(self):
        # compute capability 7.0 (Volta) is used as the cutoff
        major, _ = torch.cuda.get_device_capability()
        return major >= 7

    # Misc
    def amp(self):
        if hasattr(torch.cuda, 'amp'):
            return torch.cuda.amp
        return None

    def is_available(self):
        return torch.cuda.is_available()

    def range_push(self, msg):
        if hasattr(torch.cuda.nvtx, 'range_push'):
            return torch.cuda.nvtx.range_push(msg)

    def range_pop(self):
        if hasattr(torch.cuda.nvtx, 'range_pop'):
            return torch.cuda.nvtx.range_pop()

    def lazy_call(self, callback):
        return torch.cuda._lazy_call(callback)

    def communication_backend_name(self):
        return self._communication_backend_name
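
    # e.g. labeling a region for Nsight/nvprof timelines (a sketch):
    #   acc.range_push('forward')
    #   ...model forward...
    #   acc.range_pop()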

    # Tensor operations
    @property
    def BFloat16Tensor(self):
        return torch.cuda.BFloat16Tensor

    @property
    def ByteTensor(self):
        return torch.cuda.ByteTensor

    @property
    def DoubleTensor(self):
        return torch.cuda.DoubleTensor

    @property
    def FloatTensor(self):
        return torch.cuda.FloatTensor

    @property
    def HalfTensor(self):
        return torch.cuda.HalfTensor

    @property
    def IntTensor(self):
        return torch.cuda.IntTensor

    @property
    def LongTensor(self):
        return torch.cuda.LongTensor

    def pin_memory(self, tensor):
        return tensor.pin_memory()

    def on_accelerator(self, tensor):
        device_str = str(tensor.device)
        return device_str.startswith('cuda:')
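
    # e.g. on_accelerator(torch.ones(1, device='cuda:0')) -> True, while a
    # CPU tensor returns False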

    def op_builder_dir(self):
        try:
            # is op_builder from deepspeed or a 3p version? this should only
            # succeed if it's deepspeed; success also means we're doing a
            # local install and not the JIT compile path
            from op_builder import __deepspeed__  # noqa: F401
            return "op_builder"
        except ImportError:
            return "deepspeed.ops.op_builder"

    # dict that holds the class name <--> class type mapping, e.g.
    # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
    # this dict is filled at init stage
    class_dict = {}

    # create and return an instance of the op builder specified by class_name
    def create_op_builder(self, class_name):
        if class_name in self.class_dict:
            return self.class_dict[class_name]()
        else:
            return None
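
    # e.g. JIT-building an op (a sketch; builder names depend on the
    # installed op_builder package):
    #   builder = acc.create_op_builder('FusedAdamBuilder')
    #   fused_adam_module = builder.load() if builder is not None else None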

    # return the op builder class specified by class_name
    def get_op_builder(self, class_name):
        if class_name in self.class_dict:
            return self.class_dict[class_name]
        else:
            return None

    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension
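
# e.g. consumed by a setuptools-based build (a sketch):
#   cmdclass = {'build_ext': CUDA_Accelerator().build_extension()}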