зеркало из https://github.com/microsoft/antares.git
124 строки
4.8 KiB
Python
124 строки
4.8 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT license.
|
|
|
|
import os, sys, time, math
|
|
import subprocess
|
|
import hashlib
|
|
import traceback
|
|
|
|
# Evaluator settings shared with the rest of the program. init() either fills
# in 'compiler' and 'compile_flags' here, or rebinds the name to the
# EVAL_PROPERTIES mapping exported by a backend-specific client module.
EVAL_PROPERTIES = {}


def _read_eval_flags(header_path, backend):
    """Return ``(eval_flags, compiler)`` declared for *backend* in a header.

    Looks for the first line starting with ``//; eval_flags(<backend>):``.
    The remainder of that line is either ``[<compiler>] <flags>`` (explicit
    compiler, flags used verbatim) or plain ``<flags>`` (compiler defaults to
    ``g++`` and ``-lpthread`` is appended). Without a matching line the
    result is ``('', 'g++')``.

    Raises:
        ValueError: if the flags start with ``[`` but contain no ``]``.
    """
    prefix = f'//; eval_flags({backend}):'
    with open(header_path, 'r') as fp:
        for raw_line in fp:
            line = raw_line.strip()
            if not line.startswith(prefix):
                continue
            eval_flags = line[len(prefix):].strip()
            if eval_flags.startswith('['):
                idx = eval_flags.index(']')
                return eval_flags[idx + 1:].strip(), eval_flags[1:idx].strip()
            return eval_flags + ' -lpthread', 'g++'
    return '', 'g++'


def init(**kwargs):
    """Prepare the evaluator for the backend located at ``kwargs['backend_root']``.

    If the module ``backends.<backend>.evaluator.client`` can be imported, its
    ``init(**kwargs)`` is called, its ``EVAL_PROPERTIES`` (when present)
    replaces the module-level mapping, and its return value is returned.

    Otherwise the native evaluator is used: compile settings are read from
    ``<backend_root>/include/backend.hpp`` and published in
    ``EVAL_PROPERTIES``, and ``run_graph.cpp`` is rebuilt into
    ``$ANTARES_DRIVER_PATH/evaluator.<backend>`` whenever the header differs
    from the copy cached at ``$ANTARES_DRIVER_PATH/backend.hpp_@<backend>``.

    Args:
        **kwargs: must contain ``backend_root``; everything is forwarded
            unchanged to the backend client's ``init`` when one exists.

    Returns:
        The backend client's ``init`` result, or ``None`` on the native path.

    Raises:
        AssertionError: if ``backend.hpp`` is missing or the compile fails.
        KeyError: if ``backend_root`` or ``ANTARES_DRIVER_PATH`` is missing.
        ValueError: if the backend name contains no ``-``.
        SystemExit: with status 1 if the backend client fails with anything
            other than ``ModuleNotFoundError``.
    """
    global eval_client, EVAL_PROPERTIES

    backend_root = kwargs['backend_root']
    backend = os.path.basename(backend_root)
    source_root = f'{backend_root}/../../graph_evaluator'

    try:
        import importlib
        eval_client = importlib.import_module('backends.%s.evaluator.client' % backend)
        out = eval_client.init(**kwargs)
        if hasattr(eval_client, 'EVAL_PROPERTIES'):
            EVAL_PROPERTIES = eval_client.EVAL_PROPERTIES
        return out
    except ModuleNotFoundError:
        # No dedicated client for this backend: fall through to the native
        # evaluator. NOTE(review): a ModuleNotFoundError raised from inside
        # the client's own init() lands here as well.
        pass
    except BaseException:
        # Kept as broad as the original bare `except:` so that any client
        # failure is reported and terminates the process with status 1.
        traceback.print_exc()
        # sys.exit instead of the `exit` builtin, which is only injected by
        # the `site` module and is not guaranteed to exist.
        sys.exit(1)

    header_path = f'{backend_root}/include/backend.hpp'
    if not os.path.exists(header_path):
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        raise AssertionError(f'Missing backend header: {header_path}')
    evaluator_path = '%s/evaluator.%s' % (os.environ['ANTARES_DRIVER_PATH'], backend)

    if not backend_root:
        return None

    eval_flags, compiler = _read_eval_flags(header_path, backend)

    # The macro suffix is the part of the backend name after its first '-'.
    compile_flags = f'-D__BACKEND__=\\"{backend}\\" -D__BACKEND_{backend[backend.index("-")+1:]}__ -std=c++17 -Wno-string-compare -Wno-unused-result -Wno-unused-value {eval_flags}'
    # Published before the local include path is appended.
    EVAL_PROPERTIES['compiler'], EVAL_PROPERTIES['compile_flags'] = compiler, compile_flags
    compile_flags += f' -I{backend_root}/include'

    # A header identical to the cached copy means no rebuild is attempted.
    if 0 == os.system(f"diff {backend_root}/include/backend.hpp {os.environ['ANTARES_DRIVER_PATH']}/backend.hpp_@{backend} >/dev/null 2>&1"):
        return None

    error_info = f"SDK for `{backend}` is not configured correctly, please look into the error messages and reconfigure the corresponding environment."
    compile_cmd = f'{compiler} {source_root}/run_graph.cpp -o {evaluator_path}.tmp {compile_flags}'
    sys.stdout.write('\033[91m')
    print(f'\n[EvalAgent] Compiling Evaluator: {compile_cmd}')
    # Flush so the colour escape and banner precede the child's own output
    # when stdout is buffered.
    sys.stdout.flush()
    compile_stat = os.system(f'timeout 30s {compile_cmd}')
    sys.stdout.write('\033[0m\n')
    if compile_stat != 0:
        # Explicit raise: with `assert`, `python -O` would skip the check and
        # cache the header even though no evaluator was built.
        raise AssertionError(error_info)
    # Cache the header only after a successful build, then move the new
    # binary into place.
    os.system(f"cp {backend_root}/include/backend.hpp {os.environ['ANTARES_DRIVER_PATH']}/backend.hpp_@{backend}")
    os.system(f'mv {evaluator_path}.tmp {evaluator_path} >/dev/null 2>&1')
    return None
|
|
|
|
def _parse_eval_output(output):
    """Turn evaluator stdout into a ``{key: value}`` dict.

    Only lines of the form ``- <key>: <value>`` are considered; the key ends
    at the first ``': '``. A value whose last character is a digit is stored
    as ``float`` when it parses as one; every other value is stored as the
    stripped string. Lines without a ``': '`` separator are ignored.
    """
    results = {}
    for line in output.split('\n'):
        if not line.startswith('- '):
            continue
        # partition() keeps values that themselves contain ': ' intact.
        key, sep, val = line[2:].partition(': ')
        if not sep:
            continue
        val = val.strip()
        if val and val[-1].isdigit():
            try:
                results[key] = float(val)
                continue
            except ValueError:
                # Ends in a digit but is not a number (e.g. 'v1.2.3').
                pass
        results[key] = val
    return results


def eval(kernel_path, **kwargs):
    """Run the evaluator on a kernel and return its reported results.

    When ``$ANTARES_DRIVER_PATH/evaluator.<backend>`` does not exist, the
    call is delegated to ``eval_client.eval(kernel_path, **kwargs)``.
    Otherwise that binary is executed (through ``<backend_root>/launcher.sh``
    if present) from the directory containing *kernel_path*, with
    ``my_kernel.cc`` as its input, and its stdout is parsed.

    Args:
        kernel_path: path whose directory is used as the working directory.
        **kwargs: requires ``dev_id`` and ``backend_root``. ``compile``
            (default 0) switches to ``--compile`` mode when non-zero;
            ``expected_timeout`` adds ``--timeout`` in run mode.

    Environment:
        ANTARES_DRIVER_PATH (required); IS_WSL; and in run mode PROGRESS,
        AB_DEBUG, VALUE_PATH (paths starting with '.' are resolved against
        WORKDIR).

    Returns:
        dict mapping each reported key to a ``float`` or ``str``; empty when
        the evaluator could not be run or printed nothing parseable.

    Raises:
        SystemExit: with status 1 if the evaluator is a Windows executable
            (``MZ`` header) and ``IS_WSL`` is not ``'1'``.
    """
    dev_id = kwargs['dev_id']
    backend_root = kwargs['backend_root']
    backend = os.path.basename(backend_root)

    evaluator_path = '%s/evaluator.%s' % (os.environ['ANTARES_DRIVER_PATH'], backend)
    if not os.path.exists(evaluator_path):
        global eval_client
        return eval_client.eval(kernel_path, **kwargs)

    is_wsl = 1 if (os.environ.get('IS_WSL', '0') == '1') else 0
    with open(evaluator_path, 'rb') as fp:
        exec_magic = fp.read(2)

    # 'MZ' is the leading signature of a Windows executable.
    if is_wsl == 0 and exec_magic == b'MZ':
        print(f"Antares should run under WSL-1/2 for this backend({backend}), otherwise, evaluation would be skipped.")
        # sys.exit instead of the `exit` builtin, which is only injected by
        # the `site` module and is not guaranteed to exist.
        sys.exit(1)

    launcher = f'{backend_root}/launcher.sh'
    if not os.path.exists(launcher):
        launcher = ''

    flags = []
    if int(kwargs.get("compile", 0)) == 0:
        flags += ['--dev', str(dev_id)]
        if int(os.environ.get('PROGRESS', 0)) > 0:
            flags += ['--progress']
        debug_cnt = int(os.environ.get('AB_DEBUG', 0))
        if debug_cnt > 0:
            flags += ['--debug', str(debug_cnt)]
        value_absdir = os.environ.get('VALUE_PATH', '.').strip()
        if value_absdir:
            value_absdir = value_absdir if not value_absdir.startswith('.') else os.path.join(os.environ['WORKDIR'], value_absdir)
            flags += ['--value_absdir', value_absdir]
        timeout = str(kwargs.get("expected_timeout", "")).strip()
        if timeout:
            flags += ['--timeout', timeout]
    else:
        flags += ['--compile']

    flags = ' '.join(flags)
    # NOTE(review): the command is assembled as a shell string without
    # quoting, so paths containing spaces or shell metacharacters will break.
    # `|| true` makes the shell succeed even when the evaluator fails.
    exec_cmd = f'sh -c "cd {os.path.dirname(kernel_path)} && BACKEND={backend} {launcher} {evaluator_path} my_kernel.cc {flags}" || true'
    try:
        output = subprocess.check_output(exec_cmd, shell=True).decode()
    except Exception:
        # Best effort: a failed launch or undecodable output yields no
        # results. KeyboardInterrupt/SystemExit are no longer swallowed.
        output = ''

    return _parse_eval_output(output)
|