зеркало из https://github.com/microsoft/DeepSpeed.git
assumption of torch.initial_seed function accepting seed arg in DeepSpeedAccelerator abstract class is incorrect (#5569)
pytorch API reference - https://pytorch.org/docs/stable/generated/torch.initial_seed.html fix return value of manual_seed api for hpu --------- Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
This commit is contained in:
Родитель
b6e24adb43
Коммит
ac935c7fde
|
@ -81,7 +81,7 @@ class DeepSpeedAccelerator(ABC):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def initial_seed(self, seed):
|
def initial_seed(self):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
|
|
@ -100,8 +100,8 @@ class CPU_Accelerator(DeepSpeedAccelerator):
|
||||||
def manual_seed_all(self, seed):
|
def manual_seed_all(self, seed):
|
||||||
return torch.manual_seed(seed)
|
return torch.manual_seed(seed)
|
||||||
|
|
||||||
def initial_seed(self, seed):
|
def initial_seed(self):
|
||||||
return torch.initial_seed(seed)
|
return torch.initial_seed()
|
||||||
|
|
||||||
def default_generator(self, device_index):
|
def default_generator(self, device_index):
|
||||||
return torch.default_generator
|
return torch.default_generator
|
||||||
|
|
|
@ -99,8 +99,8 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
|
||||||
def manual_seed_all(self, seed):
|
def manual_seed_all(self, seed):
|
||||||
return torch.cuda.manual_seed_all(seed)
|
return torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
def initial_seed(self, seed):
|
def initial_seed(self):
|
||||||
return torch.cuda.initial_seed(seed)
|
return torch.cuda.initial_seed()
|
||||||
|
|
||||||
def default_generator(self, device_index):
|
def default_generator(self, device_index):
|
||||||
return torch.cuda.default_generators[device_index]
|
return torch.cuda.default_generators[device_index]
|
||||||
|
|
|
@ -74,13 +74,13 @@ class HPU_Accelerator(DeepSpeedAccelerator):
|
||||||
return self.hpu.random.get_rng_state()
|
return self.hpu.random.get_rng_state()
|
||||||
|
|
||||||
def manual_seed(self, seed):
|
def manual_seed(self, seed):
|
||||||
self.hpu.random.manual_seed(seed)
|
return self.hpu.random.manual_seed(seed)
|
||||||
|
|
||||||
def manual_seed_all(self, seed):
|
def manual_seed_all(self, seed):
|
||||||
self.hpu.random.manual_seed_all(seed)
|
self.hpu.random.manual_seed_all(seed)
|
||||||
|
|
||||||
def initial_seed(self, seed):
|
def initial_seed(self):
|
||||||
self.hpu.random.initial_seed(seed)
|
return self.hpu.random.initial_seed()
|
||||||
|
|
||||||
def default_generator(self, device_index):
|
def default_generator(self, device_index):
|
||||||
return self.hpu.random.default_generators[device_index]
|
return self.hpu.random.default_generators[device_index]
|
||||||
|
|
|
@ -77,7 +77,7 @@ class MPS_Accelerator(DeepSpeedAccelerator):
|
||||||
def seed(self):
|
def seed(self):
|
||||||
return torch.mps.seed()
|
return torch.mps.seed()
|
||||||
|
|
||||||
def initial_seed(self, seed):
|
def initial_seed(self):
|
||||||
return
|
return
|
||||||
|
|
||||||
def default_generator(self, device_index):
|
def default_generator(self, device_index):
|
||||||
|
|
|
@ -84,8 +84,8 @@ class NPU_Accelerator(DeepSpeedAccelerator):
|
||||||
def manual_seed_all(self, seed):
|
def manual_seed_all(self, seed):
|
||||||
return torch.npu.manual_seed_all(seed)
|
return torch.npu.manual_seed_all(seed)
|
||||||
|
|
||||||
def initial_seed(self, seed):
|
def initial_seed(self):
|
||||||
return torch.npu.initial_seed(seed)
|
return torch.npu.initial_seed()
|
||||||
|
|
||||||
def default_generator(self, device_index):
|
def default_generator(self, device_index):
|
||||||
return torch.npu.default_generators[device_index]
|
return torch.npu.default_generators[device_index]
|
||||||
|
|
|
@ -74,8 +74,8 @@ class XPU_Accelerator(DeepSpeedAccelerator):
|
||||||
def manual_seed_all(self, seed):
|
def manual_seed_all(self, seed):
|
||||||
return torch.xpu.manual_seed_all(seed)
|
return torch.xpu.manual_seed_all(seed)
|
||||||
|
|
||||||
def initial_seed(self, seed):
|
def initial_seed(self):
|
||||||
return torch.xpu.initial_seed(seed)
|
return torch.xpu.initial_seed()
|
||||||
|
|
||||||
def default_generator(self, device_index):
|
def default_generator(self, device_index):
|
||||||
return torch.xpu.default_generators[device_index]
|
return torch.xpu.default_generators[device_index]
|
||||||
|
|
Загрузка…
Ссылка в новой задаче