fix for pytorch < 1.6 (#6300)
This commit is contained in:
Родитель
2804fff839
Коммит
118ecfd427
|
@ -1400,7 +1400,7 @@ class ReformerLayer(nn.Module):
|
||||||
|
|
||||||
# randomize seeds
|
# randomize seeds
|
||||||
# use cuda generator if available
|
# use cuda generator if available
|
||||||
if len(torch.cuda.default_generators) > 0:
|
if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0:
|
||||||
# GPU
|
# GPU
|
||||||
device_idx = torch.cuda.current_device()
|
device_idx = torch.cuda.current_device()
|
||||||
self.attention_seed = torch.cuda.default_generators[device_idx].seed()
|
self.attention_seed = torch.cuda.default_generators[device_idx].seed()
|
||||||
|
@ -1420,7 +1420,7 @@ class ReformerLayer(nn.Module):
|
||||||
"""
|
"""
|
||||||
# randomize seeds
|
# randomize seeds
|
||||||
# use cuda generator if available
|
# use cuda generator if available
|
||||||
if len(torch.cuda.default_generators) > 0:
|
if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0:
|
||||||
# GPU
|
# GPU
|
||||||
device_idx = torch.cuda.current_device()
|
device_idx = torch.cuda.current_device()
|
||||||
self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed()
|
self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed()
|
||||||
|
|
Загрузка…
Ссылка в новой задаче