update torch version and solve compat issue with isinf

This commit is contained in:
erogol 2020-05-22 13:09:07 +02:00
Родитель 46522e213e
Коммит c27fd4238a
3 изменённых файлов: 11 добавлений и 5 удалений

Просмотреть файл

@ -1,5 +1,5 @@
numpy>=1.16.0 numpy>=1.16.0
torch>=0.4.1 torch>=1.5
librosa>=0.5.1 librosa>=0.5.1
Unidecode>=0.4.20 Unidecode>=0.4.20
tensorboard tensorboard

Просмотреть файл

@ -92,7 +92,7 @@ setup(
}, },
install_requires=[ install_requires=[
"scipy>=0.19.0", "scipy>=0.19.0",
"torch>=0.4.1", "torch>=1.5",
"numpy>=1.16.0", "numpy>=1.16.0",
"librosa==0.6.2", "librosa==0.6.2",
"unidecode==0.4.20", "unidecode==0.4.20",

Просмотреть файл

@ -9,9 +9,15 @@ def check_update(model, grad_clip, ignore_stopnet=False):
grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip) grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
else: else:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
if torch.isinf(grad_norm): # compatibility with different torch versions
print(" | > Gradient is INF !!") if isinstance(grad_norm, float):
skip_flag = True if np.isinf(grad_norm):
print(" | > Gradient is INF !!")
skip_flag = True
else:
if torch.isinf(grad_norm):
print(" | > Gradient is INF !!")
skip_flag = True
return grad_norm, skip_flag return grad_norm, skip_flag