2021-10-15 16:21:21 +03:00
|
|
|
import torch
|
2021-10-29 12:50:55 +03:00
|
|
|
import sys
|
2021-10-15 16:21:21 +03:00
|
|
|
|
|
|
|
# Seed PyTorch's RNG with a fixed value so runs are repeatable.
torch.manual_seed(0)

# Module-level switches; the command-line flags handled below flip them.
cuda = False
torchscript = False

# The texpr fuser starts out disabled; '--fuser-nnc' turns it back on.
torch._C._jit_set_texpr_fuser_enabled(False)

# Apply each command-line flag in the order given. Flags are independent
# switches, so several may be combined in one invocation.
for arg in sys.argv[1:]:
    if arg == '--torchy':
        # Imported only on request, so the module is not needed otherwise.
        import torchy
        torchy.enable()
    elif arg == '--cuda':
        cuda = True
        if not torch.cuda.is_available():
            print('UNSUPPORTED: CUDA is not available')
            # Use sys.exit() rather than the bare exit() helper: the latter
            # is injected by the `site` module and is not guaranteed to
            # exist (e.g. under `python -S`).
            # NOTE(review): status 0x42 presumably tells the caller the
            # configuration is unsupported rather than failed -- confirm.
            sys.exit(0x42)
    elif arg == '--torchscript':
        torchscript = True
    elif arg == '--fuser-nnc':
        torch._C._jit_set_texpr_fuser_enabled(True)
    elif arg == '--nnc-enable-reductions':
        torch._C._jit_set_texpr_reductions_enabled(True)
    elif arg == '--nvfuser':
        # Optional nvfuser environment knobs, left disabled.
        # NOTE(review): `os` is not imported in the visible part of the
        # file; import it before re-enabling any of these.
        #os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1'
        #os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1'
        #os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0'
        torch._C._jit_set_nvfuser_enabled(True)
    else:
        # Unrecognised arguments are reported but do not stop the run.
        print(f'Unknown arg: {arg}')
|