torchy/benchmarks/testdriver.py

33 lines
871 B
Python

import torch
import sys
# Fixed RNG seed so every run generates the same random tensors.
torch.manual_seed(0)

# Feature flags set from the command line; presumably read by the
# benchmark code later in this file (not visible here) -- TODO confirm.
cuda = False
torchscript = False

# The NNC (TensorExpr) fuser starts disabled; '--fuser-nnc' re-enables it.
torch._C._jit_set_texpr_fuser_enabled(False)

for arg in sys.argv[1:]:
    if arg == '--torchy':
        # Imported only on request so baseline runs don't need torchy.
        import torchy
        torchy.enable()
    elif arg == '--cuda':
        cuda = True
        if not torch.cuda.is_available():
            print('UNSUPPORTED: CUDA is not available')
            # Use sys.exit, not the site-provided exit() builtin, which is
            # meant for the interactive shell and is missing under
            # `python -S` or embedded interpreters.
            # NOTE(review): 0x42 presumably tells the calling harness the
            # test is unsupported rather than failed -- confirm.
            sys.exit(0x42)
    elif arg == '--torchscript':
        torchscript = True
    elif arg == '--fuser-nnc':
        torch._C._jit_set_texpr_fuser_enabled(True)
    elif arg == '--nnc-enable-reductions':
        torch._C._jit_set_texpr_reductions_enabled(True)
    elif arg == '--nvfuser':
        # Optional nvfuser debugging knobs. NOTE: `os` is not imported in
        # this file; add `import os` before uncommenting these lines.
        #os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1'
        #os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1'
        #os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0'
        torch._C._jit_set_nvfuser_enabled(True)
    else:
        # Unrecognized flags are reported but do not abort the run.
        print(f'Unknown arg: {arg}')