# Copyright (c) Microsoft Corporation. # Licensed under the MIT license. import os, sys, time, math import subprocess import hashlib import traceback EVAL_PROPERTIES={} def init(**kwargs): backend_root = kwargs['backend_root'] backend = os.path.basename(backend_root) source_root = f'{backend_root}/../../graph_evaluator' global eval_client try: import importlib eval_client = importlib.import_module('backends.%s.evaluator.client' % backend) out = eval_client.init(**kwargs) if hasattr(eval_client, 'EVAL_PROPERTIES'): global EVAL_PROPERTIES EVAL_PROPERTIES = eval_client.EVAL_PROPERTIES return out except ModuleNotFoundError: pass except: traceback.print_exc() exit(1) assert os.path.exists(f'{backend_root}/include/backend.hpp') evaluator_path = '%s/evaluator.%s' % (os.environ['ANTARES_DRIVER_PATH'], backend) if backend_root: with open(f'{backend_root}/include/backend.hpp', 'r') as fp: eval_flags_pref = f'//; eval_flags({backend}):' eval_flags, compiler = '', 'g++' while True: line = fp.readline() if not line: break line = line.strip() if line.startswith(eval_flags_pref): eval_flags = line[len(eval_flags_pref):].strip() if eval_flags.startswith('['): idx = eval_flags.index(']') eval_flags, compiler = eval_flags[idx+1:].strip(), eval_flags[1:idx].strip() else: eval_flags += ' -lpthread' break compile_flags = f'-D__BACKEND__=\\"{backend}\\" -D__BACKEND_{backend[backend.index("-")+1:]}__ -std=c++17 -Wno-string-compare -Wno-unused-result -Wno-unused-value {eval_flags}' EVAL_PROPERTIES['compiler'], EVAL_PROPERTIES['compile_flags'] = compiler, compile_flags compile_flags += f' -I{backend_root}/include' if 0 != os.system(f"diff {backend_root}/include/backend.hpp {os.environ['ANTARES_DRIVER_PATH']}/backend.hpp_@{backend} >/dev/null 2>&1"): error_info = f"SDK for `{backend}` is not configured correctly, please look into the error messages and reconfigure the corresponding environment." compile_cmd = f'{compiler} {source_root}/run_graph.cpp -o {evaluator_path}.tmp {compile_flags}' sys.stdout.write('\033[91m') print(f'\n[EvalAgent] Compiling Evaluator: {compile_cmd}') compile_stat = os.system(f'timeout 30s {compile_cmd}') sys.stdout.write('\033[0m\n') assert compile_stat == 0, error_info os.system(f"cp {backend_root}/include/backend.hpp {os.environ['ANTARES_DRIVER_PATH']}/backend.hpp_@{backend}") os.system(f'mv {evaluator_path}.tmp {evaluator_path} >/dev/null 2>&1') is_wsl = 1 if (os.environ.get('IS_WSL', '0') == '1') else 0 def eval(kernel_path, **kwargs): dev_id = kwargs['dev_id'] backend_root = kwargs['backend_root'] backend = os.path.basename(backend_root) evaluator_path = '%s/evaluator.%s' % (os.environ['ANTARES_DRIVER_PATH'], backend) if not os.path.exists(evaluator_path): global eval_client return eval_client.eval(kernel_path, **kwargs) is_wsl = 1 if (os.environ.get('IS_WSL', '0') == '1') else 0 with open(evaluator_path, 'rb') as fp: exec_magic = fp.read(2) if is_wsl == 0 and exec_magic == b'MZ': print(f"Antares should run under WSL-1/2 for this backend({backend}), otherwise, evaluation would be skipped.") exit(1) launcher = f'{backend_root}/launcher.sh' if not os.path.exists(launcher): launcher = '' flags = [] if int(kwargs.get("compile", 0)) == 0: flags += ['--dev', str(dev_id)] if int(os.environ.get('PROGRESS', 0)) > 0: flags += ['--progress'] debug_cnt = int(os.environ.get('AB_DEBUG', 0)) if debug_cnt > 0: flags += ['--debug', str(debug_cnt)] value_absdir = os.environ.get('VALUE_PATH', '.').strip() if value_absdir: value_absdir = value_absdir if not value_absdir.startswith('.') else os.path.join(os.environ['WORKDIR'], value_absdir) flags += ['--value_absdir', value_absdir] timeout = str(kwargs.get("expected_timeout", "")).strip() if timeout: flags += ['--timeout', timeout] else: flags += ['--compile'] flags = ' '.join(flags) exec_cmd = f'sh -c "cd {os.path.dirname(kernel_path)} && BACKEND={backend} {launcher} {evaluator_path} my_kernel.cc {flags}" || true' try: output = subprocess.check_output(exec_cmd, shell=True).decode() except: output = '' results = {} for line in output.split('\n'): if line.startswith('- '): key, val = line[2:].split(': ') val = val.strip() if val[-1].isdigit(): results[key] = float(val) else: results[key] = val return results