зеркало из https://github.com/mozilla/TTS.git
update torch version and solve compat issue with isinf
This commit is contained in:
Родитель
46522e213e
Коммит
c27fd4238a
|
@ -1,5 +1,5 @@
|
||||||
numpy>=1.16.0
|
numpy>=1.16.0
|
||||||
torch>=0.4.1
|
torch>=1.5
|
||||||
librosa>=0.5.1
|
librosa>=0.5.1
|
||||||
Unidecode>=0.4.20
|
Unidecode>=0.4.20
|
||||||
tensorboard
|
tensorboard
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -92,7 +92,7 @@ setup(
|
||||||
},
|
},
|
||||||
install_requires=[
|
install_requires=[
|
||||||
"scipy>=0.19.0",
|
"scipy>=0.19.0",
|
||||||
"torch>=0.4.1",
|
"torch>=1.5",
|
||||||
"numpy>=1.16.0",
|
"numpy>=1.16.0",
|
||||||
"librosa==0.6.2",
|
"librosa==0.6.2",
|
||||||
"unidecode==0.4.20",
|
"unidecode==0.4.20",
|
||||||
|
|
|
@ -9,9 +9,15 @@ def check_update(model, grad_clip, ignore_stopnet=False):
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
|
grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
|
||||||
else:
|
else:
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
||||||
if torch.isinf(grad_norm):
|
# compatibility with different torch versions
|
||||||
print(" | > Gradient is INF !!")
|
if isinstance(grad_norm, float):
|
||||||
skip_flag = True
|
if np.isinf(grad_norm):
|
||||||
|
print(" | > Gradient is INF !!")
|
||||||
|
skip_flag = True
|
||||||
|
else:
|
||||||
|
if torch.isinf(grad_norm):
|
||||||
|
print(" | > Gradient is INF !!")
|
||||||
|
skip_flag = True
|
||||||
return grad_norm, skip_flag
|
return grad_norm, skip_flag
|
||||||
|
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче