torchy/benchmarks/testdriver.py

33 строки
871 B
Python
Исходник Обычный вид История

import torch
2021-10-29 12:50:55 +03:00
import sys
# Fix the global RNG seed so randomly generated tensors are reproducible
# from one run to the next.
torch.manual_seed(0)

# Driver feature flags; both start off and are switched on by the
# --cuda / --torchscript command-line options.
cuda = torchscript = False

# The TensorExpr (NNC) fuser is disabled by default; the --fuser-nnc
# option turns it back on.
torch._C._jit_set_texpr_fuser_enabled(False)
# Process the driver's command-line switches left to right. Each switch either
# sets a module-level flag or toggles a TorchScript fuser setting; unknown
# switches are reported on stdout but do not abort the run.
for arg in sys.argv[1:]:
    if arg == '--torchy':
        # Imported only when requested, so runs without --torchy do not
        # need the torchy package to be importable.
        import torchy
        torchy.enable()
    elif arg == '--cuda':
        cuda = True
        if not torch.cuda.is_available():
            print('UNSUPPORTED: CUDA is not available')
            # Use sys.exit, not the site-provided exit(): the latter is an
            # interactive-shell helper and does not exist under `python -S`.
            # NOTE(review): 0x42 presumably tells the calling harness
            # "unsupported / skip" rather than "failed" -- confirm.
            sys.exit(0x42)
    elif arg == '--torchscript':
        torchscript = True
    elif arg == '--fuser-nnc':
        torch._C._jit_set_texpr_fuser_enabled(True)
    elif arg == '--nnc-enable-reductions':
        torch._C._jit_set_texpr_reductions_enabled(True)
    elif arg == '--nvfuser':
        # Optional nvFuser debugging knobs, kept for manual experiments.
        # Uncommenting them also requires `import os` at the top of the file.
        # os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1'
        # os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1'
        # os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0'
        torch._C._jit_set_nvfuser_enabled(True)
    else:
        print(f'Unknown arg: {arg}')