зеркало из https://github.com/microsoft/antares.git
dynamic check runtime backend (#41)
This commit is contained in:
Родитель
aedcadc34d
Коммит
5328db69b6
2
Makefile
2
Makefile
|
@ -1,5 +1,5 @@
|
|||
COMPUTE_V1 ?= - einstein_v2("output0[N] = input0[N] + input1[N]", input_dict={"input0": {"dtype": "float32", "shape": [1024 * 512]}, "input1": {"dtype": "float32", "shape": [1024 * 512]}})
|
||||
BACKEND ?= c-rocm
|
||||
BACKEND ?=
|
||||
TUNER ?=
|
||||
STEP ?= 0
|
||||
CONFIG ?=
|
||||
|
|
|
@ -82,7 +82,18 @@ def translate_code(code):
|
|||
return '%s\n%s%s' % (get_kernel_metadata(), defs, code)
|
||||
|
||||
def device_properties():
|
||||
return tvm.runtime.ndarray.gpu(0)
|
||||
props = tvm.runtime.ndarray.gpu(0)
|
||||
with open('%s/device_properties.cfg' % os.environ['ANTARES_DRIVER_PATH'], 'r') as fp:
|
||||
mem_bandwith = 2.5e-7
|
||||
while True:
|
||||
line = fp.readline()
|
||||
if not line:
|
||||
break
|
||||
key, val = line.split(': ')
|
||||
if key in ('GlobalMemoryBusWidth', 'MemoryClockRate'):
|
||||
mem_bandwith *= float(val)
|
||||
props.mem_bandwith = mem_bandwith
|
||||
return props
|
||||
|
||||
def compile_source(code):
|
||||
if 'HTTP_SERVICE' in os.environ:
|
||||
|
|
|
@ -9,6 +9,16 @@ fi
|
|||
|
||||
# Valid Backends: c-cuda, c-rocm, c-mcpu, c-hlsl, c-gc
|
||||
|
||||
if [[ "$BACKEND" == "" ]]; then
|
||||
if [ -e /dev/nvidia-modeset ]; then
|
||||
BACKEND=c-cuda
|
||||
elif [ -e /dev/kfd ]; then
|
||||
BACKEND=c-rocm
|
||||
elif grep Microsoft /proc/sys/kernel/osrelease >/dev/null; then
|
||||
BACKEND=c-hlsl
|
||||
fi
|
||||
fi
|
||||
|
||||
export BACKEND=${BACKEND:-c-rocm}
|
||||
export ANTARES_DRIVER_PATH=/tmp/libAntares
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include <hip/hip_runtime.h>
|
||||
#define Q(attr_key) ((0 == hipDeviceGetAttribute(&val, hipDeviceAttribute ## attr_key, 0)) ? printf("%s: %d\n", #attr_key, val) : (exit(1), 0))
|
||||
#define hipDeviceAttributeMultiProcessorCount hipDeviceAttributeMultiprocessorCount
|
||||
#define hipDeviceAttributeGlobalMemoryBusWidth hipDeviceAttributeMemoryBusWidth
|
||||
#define CHECK_ENV() assert(getenv("BACKEND") != NULL), assert(strcmp(getenv("BACKEND"), "c-rocm") == 0);
|
||||
#endif
|
||||
|
||||
|
@ -31,5 +32,7 @@ int main() {
|
|||
Q(MaxBlockDimX);
|
||||
Q(MaxBlockDimY);
|
||||
Q(MaxBlockDimZ);
|
||||
Q(GlobalMemoryBusWidth);
|
||||
Q(MemoryClockRate);
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -79,7 +79,6 @@ def scan_items(root, ast, range_book):
|
|||
tensor_name = root._value['tensor']._value['name']
|
||||
current_range = []
|
||||
for i, sub in enumerate(root._value['index']):
|
||||
tensor_index = i
|
||||
index_range = infer_range(sub, ax_rank)
|
||||
if index_range == '*':
|
||||
index_range = [0, None, 0, ast['props']['data_axes'][i]['range'] - 1]
|
||||
|
@ -94,7 +93,7 @@ def scan_items(root, ast, range_book):
|
|||
current_range[i] = [0, None, 0, ast['props']['data_axes'][i]['range'] - 1]
|
||||
range_book[tensor_name] = current_range
|
||||
|
||||
def auto_shard_on_ast(ast):
|
||||
def compute(ast):
|
||||
if backend not in ['c-gc']:
|
||||
return
|
||||
|
||||
|
|
|
@ -409,8 +409,18 @@ def emit_tvm_ir(exprss, input_dict):
|
|||
arg_props['_out'].sort(key=lambda x: x['name'])
|
||||
os.environ['GLOBAL_ARG_PROPS'] = json.dumps(arg_props)
|
||||
|
||||
from lang.auto_shard import auto_shard_on_ast
|
||||
auto_shard_on_ast(ast)
|
||||
try:
|
||||
from lang import auto_shard
|
||||
auto_shard.compute(ast)
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
from lang import simplify
|
||||
simplify.compute(ast)
|
||||
except:
|
||||
pass
|
||||
|
||||
bias_axis_body = ''
|
||||
|
||||
def emit_input_body(input_dict):
|
||||
|
|
|
@ -113,7 +113,7 @@ std::pair<void *, void *> create_tensor_memory(const tensor_property &tp) {
|
|||
int main(int argc, char** argv)
|
||||
{
|
||||
if (0 != cudaSetDevice(0))
|
||||
throw std::runtime_error("GPU device not found.");
|
||||
throw std::runtime_error("GPU device `" + std::string(getenv("BACKEND")) + "` is not found.");
|
||||
|
||||
std::ifstream t("my_kernel.cc");
|
||||
std::string source((std::istreambuf_iterator<char>(t)), std::istreambuf_iterator<char>());
|
||||
|
|
Загрузка…
Ссылка в новой задаче