From 706f9b6f7eabf29aaa2f0d97c0d17fc92696c253 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 2 May 2017 10:14:45 -0700 Subject: [PATCH] [CODEGEN/RUNTIME] Metal support, runtime improvement. (#111) * [CODEGEN/RUNTIME] Metal support, runtime improvement. * Fix case when no device is available --- Makefile | 29 +- docs/api/python/ndarray.rst | 1 + include/tvm/codegen.h | 7 - include/tvm/runtime/c_runtime_api.h | 4 +- include/tvm/runtime/config.h | 7 + include/tvm/runtime/module.h | 4 +- make/config.mk | 13 +- python/tvm/__init__.py | 2 +- python/tvm/_ffi/ndarray.py | 125 ++++---- python/tvm/build.py | 3 +- python/tvm/codegen.py | 17 -- python/tvm/contrib/cc_compiler.py | 2 + python/tvm/contrib/metal_compiler.py | 50 ++++ python/tvm/ndarray.py | 62 +++- src/api/api_codegen.cc | 5 - src/codegen/build_common.h | 34 +++ src/codegen/build_cuda.cc | 17 +- src/codegen/build_metal.cc | 47 +++ src/codegen/build_opencl.cc | 27 +- src/codegen/codegen.cc | 5 - src/codegen/codegen_c.cc | 24 +- src/codegen/codegen_c.h | 12 +- src/codegen/codegen_cuda.cc | 17 +- src/codegen/codegen_cuda.h | 4 +- src/codegen/codegen_metal.cc | 203 +++++++++++++ src/codegen/codegen_metal.h | 34 +++ src/codegen/codegen_opencl.cc | 33 ++- src/codegen/codegen_opencl.h | 9 +- src/codegen/codegen_source_base.h | 6 + src/codegen/source_module.cc | 52 ++++ src/codegen/verilog/verilog_module.cc | 6 +- src/codegen/verilog/vpi_device_api.cc | 14 +- src/pass/lower_thread_allreduce.cc | 11 +- src/pass/make_api.cc | 6 + src/pass/split_host_device.cc | 36 +-- src/pass/storage_sync.cc | 19 +- src/runtime/c_runtime_api.cc | 67 ++++- src/runtime/cpu_device_api.cc | 14 +- src/runtime/cuda/cuda_device_api.cc | 31 +- src/runtime/cuda/cuda_module.cc | 34 +-- src/runtime/device_api.h | 29 +- src/runtime/meta_data.h | 1 - src/runtime/metal/metal_common.h | 98 +++++++ src/runtime/metal/metal_device_api.mm | 240 +++++++++++++++ src/runtime/metal/metal_module.h | 36 +++ src/runtime/metal/metal_module.mm | 273 ++++++++++++++++++ src/runtime/module.cc | 16 +- src/runtime/opencl/opencl_common.h | 19 +- src/runtime/opencl/opencl_device_api.cc | 127 ++++---- src/runtime/opencl/opencl_module.cc | 49 +--- src/runtime/opencl/opencl_module.h | 2 +- src/runtime/pack_args.h | 233 +++++++++++++++ src/runtime/void_addr_args.h | 164 ----------- tests/python/integration/test_dot.py | 3 +- tests/python/integration/test_ewise.py | 44 +-- tests/python/integration/test_gemm.py | 33 ++- tests/python/integration/test_reduce.py | 24 +- tests/python/integration/test_scan.py | 15 +- tests/python/perf/gemm_square.py | 4 +- tests/python/unittest/test_codegen_device.py | 13 +- tests/python/unittest/test_codegen_extern.py | 6 +- tests/python/unittest/test_codegen_llvm.py | 8 +- .../python/unittest/test_codegen_vm_basic.py | 2 +- tests/python/unittest/test_module_load.py | 3 +- tests/python/unittest/test_runtime_ndarray.py | 13 +- tests/travis/run_test.sh | 7 +- .../integration/test_codegen_verilog.py | 5 +- tutorials/python/get_started.py | 1 - tutorials/python/intrin_math.py | 1 - 69 files changed, 1939 insertions(+), 623 deletions(-) create mode 100644 python/tvm/contrib/metal_compiler.py create mode 100644 src/codegen/build_common.h create mode 100644 src/codegen/build_metal.cc create mode 100644 src/codegen/codegen_metal.cc create mode 100644 src/codegen/codegen_metal.h create mode 100644 src/codegen/source_module.cc create mode 100644 src/runtime/metal/metal_common.h create mode 100644 src/runtime/metal/metal_device_api.mm create mode 100644 src/runtime/metal/metal_module.h create mode 100644 src/runtime/metal/metal_module.mm create mode 100644 src/runtime/pack_args.h delete mode 100644 src/runtime/void_addr_args.h diff --git a/Makefile b/Makefile index 47f41fda..ff77eff7 100644 --- a/Makefile +++ b/Makefile @@ -18,7 +18,9 @@ all: lib/libtvm.so lib/libtvm_runtime.so lib/libtvm.a LIB_HALIDE_IR = HalideIR/lib/libHalideIR.a SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc) +METAL_SRC = $(wildcard src/runtime/metal/*.mm) ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) +METAL_OBJ = $(patsubst src/%.mm, build/%.o, $(METAL_SRC)) ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR) RUNTIME_SRC = $(wildcard src/runtime/*.cc src/runtime/*/*.cc) @@ -29,6 +31,7 @@ ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR) export LDFLAGS = -pthread -lm export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\ -Iinclude -Idlpack/include -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0 +export OBJCFLAGS= -fobjc-arc ifdef CUDA_PATH NVCC=$(CUDA_PATH)/bin/nvcc @@ -36,7 +39,7 @@ ifdef CUDA_PATH LDFLAGS += -L$(CUDA_PATH)/lib64 endif -ifeq ($(USE_CUDA), 1) +ifeq ($(ENABLE_CUDA), 1) CFLAGS += -DTVM_CUDA_RUNTIME=1 LDFLAGS += -lcuda -lcudart -lnvrtc else @@ -45,9 +48,10 @@ endif FRAMEWORKS= -ifeq ($(USE_OPENCL), 1) +UNAME_S := $(shell uname -s) + +ifeq ($(ENABLE_OPENCL), 1) CFLAGS += -DTVM_OPENCL_RUNTIME=1 - UNAME_S := $(shell uname -s) ifeq ($(UNAME_S), Darwin) FRAMEWORKS += -framework OpenCL else @@ -57,10 +61,20 @@ else CFLAGS += -DTVM_OPENCL_RUNTIME=0 endif +ifeq ($(ENABLE_METAL), 1) + CFLAGS += -DTVM_METAL_RUNTIME=1 + LDFLAGS += -lObjc + ALL_DEP += $(METAL_OBJ) + RUNTIME_DEP += $(METAL_OBJ) + FRAMEWORKS += -framework Metal -framework Foundation +else + CFLAGS += -DTVM_METAL_RUNTIME=0 +endif + # llvm configuration LLVM_CONFIG=llvm-config -ifeq ($(USE_LLVM), 1) +ifeq ($(ENABLE_LLVM), 1) LLVM_VERSION=$(shell $(LLVM_CONFIG) --version| cut -b 1,3) LLVM_INCLUDE=$(filter -I%, $(shell $(LLVM_CONFIG) --cxxflags)) LDFLAGS += $(shell $(LLVM_CONFIG) --ldflags --libs --system-libs) @@ -87,6 +101,11 @@ build/%.o: src/%.cc $(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d $(CXX) -c $(CFLAGS) -c $< -o $@ +build/%.o: src/%.mm + @mkdir -p $(@D) + $(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d + $(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@ + lib/libtvm.so: $(ALL_DEP) @mkdir -p $(@D) $(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS) @@ -105,7 +124,7 @@ LIBHALIDEIR: + cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR) cpplint: - python2 dmlc-core/scripts/lint.py tvm cpp include src verilog + python dmlc-core/scripts/lint.py tvm cpp include src verilog pylint: pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc diff --git a/docs/api/python/ndarray.rst b/docs/api/python/ndarray.rst index 044212bb..54964f1b 100644 --- a/docs/api/python/ndarray.rst +++ b/docs/api/python/ndarray.rst @@ -12,4 +12,5 @@ tvm.ndarray .. autofunction:: tvm.cpu .. autofunction:: tvm.gpu .. autofunction:: tvm.opencl +.. autofunction:: tvm.metal .. autofunction:: tvm.ndarray.array diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h index e85d571a..d23d9e33 100644 --- a/include/tvm/codegen.h +++ b/include/tvm/codegen.h @@ -31,13 +31,6 @@ using runtime::TVMRetValue; */ runtime::Module Build(const Array& funcs, const std::string& target); - -/*! - * \param target The target to be queried. - * \return Whether target is enabled. - */ -bool TargetEnabled(const std::string& target); - } // namespace codegen } // namespace tvm diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 5b9a0549..f2b35b47 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -41,6 +41,8 @@ typedef int64_t tvm_index_t; /*! \brief Extension device types in TVM */ typedef enum { + /*! \brief Metal buffer. */ + kMetal = 8, /*! \brief Simulated on board RAM */ kVPI = 9 } TVMDeviceExtType; @@ -360,7 +362,7 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); TVM_DLL int TVMFuncListGlobalNames(int *out_size, const char*** out_array); -// Array related apis for quick proptying +// Array related apis for quick proptyping /*! * \brief Allocate a nd-array's memory, * including space of shape, of given spec. diff --git a/include/tvm/runtime/config.h b/include/tvm/runtime/config.h index 92a737f8..73857f1a 100644 --- a/include/tvm/runtime/config.h +++ b/include/tvm/runtime/config.h @@ -20,4 +20,11 @@ #define TVM_OPENCL_RUNTIME 0 #endif +/*! + *\brief whether to use metal runtime + */ +#ifndef TVM_METAL_RUNTIME +#define TVM_METAL_RUNTIME 0 +#endif + #endif // TVM_RUNTIME_CONFIG_H_ diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 7b6c74ed..f6364174 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -145,8 +145,8 @@ class ModuleNode { namespace symbol { /*! \brief Global variable to store module context. */ constexpr const char* tvm_module_ctx = "__tvm_module_ctx"; -/*! \brief Local function to set the device during API entry. */ -constexpr const char* tvm_entry_setdevice = "__tvm_entry_setdevice"; +/*! \brief global function to set device */ +constexpr const char* tvm_set_device = "__tvm_set_device"; /*! \brief Auxiliary counter to global barrier. */ constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state"; /*! \brief Prepare the global barrier before kernels that uses global barrier. */ diff --git a/make/config.mk b/make/config.mk index f0680eb6..17c12c0a 100644 --- a/make/config.mk +++ b/make/config.mk @@ -34,16 +34,19 @@ ADD_CFLAGS = # matrix computation libraries for CPU/GPU #--------------------------------------------- -# whether use CUDA during compile -USE_CUDA = 1 +# whether enable CUDA during compile +ENABLE_CUDA = 1 -# whether use OpenCL during compile -USE_OPENCL = 0 +# whether enable OpenCL during compile +ENABLE_OPENCL = 0 + +# whether enable Metal during compile +ENABLE_METAL = 0 # whether build with LLVM support # This requires llvm-config to be in your PATH # Requires LLVM version >= 4.0 -USE_LLVM = 0 +ENABLE_LLVM = 0 # add the path to CUDA library to link and compile flag # if you have already add them to environment variable. diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index ced5c7dd..58f866e6 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -16,7 +16,7 @@ from . import node from . import ir_builder from . import ndarray as nd -from .ndarray import cpu, gpu, opencl, cl, vpi +from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi from ._ffi.function import Function from ._ffi.base import TVMError, __version__ diff --git a/python/tvm/_ffi/ndarray.py b/python/tvm/_ffi/ndarray.py index fa3c9b9f..11023a41 100644 --- a/python/tvm/_ffi/ndarray.py +++ b/python/tvm/_ffi/ndarray.py @@ -4,7 +4,8 @@ from __future__ import absolute_import import ctypes import numpy as np -from .base import _LIB, check_call, c_array +from .base import _LIB, check_call, c_array, string_types +from .. import _api_internal tvm_shape_index_t = ctypes.c_int64 @@ -63,22 +64,62 @@ class TVMType(ctypes.Structure): def __ne__(self, other): return not self.__eq__(other) + class TVMContext(ctypes.Structure): """TVM context strucure.""" _fields_ = [("device_id", ctypes.c_int), ("device_type", ctypes.c_int)] - MASK2STR = { 1 : 'cpu', 2 : 'gpu', 4 : 'opencl', + 8 : 'metal', 9 : 'vpi' } - def __init__(self, device_id, device_type): + STR2MASK = { + 'cpu': 1, + 'gpu': 2, + 'cuda': 2, + 'cl': 4, + 'opencl': 4, + 'metal': 8, + 'vpi': 9 + } + def __init__(self, device_type, device_id): super(TVMContext, self).__init__() self.device_id = device_id self.device_type = device_type + @property + def exist(self): + """Whether this device exist.""" + return _api_internal._GetDeviceAttr( + self.device_type, self.device_id, 0) != 0 + + @property + def max_threads_per_block(self): + """Maximum number of threads on each block.""" + return _api_internal._GetDeviceAttr( + self.device_type, self.device_id, 1) + + @property + def warp_size(self): + """Number of threads that executes in concurrent.""" + return _api_internal._GetDeviceAttr( + self.device_type, self.device_id, 2) + + def sync(self): + """Synchronize until jobs finished at the context.""" + check_call(_LIB.TVMSynchronize(self, None)) + + def __eq__(self, other): + return (isinstance(other, TVMContext) and + self.device_id == other.device_id and + self.device_type == other.device_type) + + def __ne__(self, other): + return not self.__eq__(other) + def __repr__(self): return "%s(%d)" % ( TVMContext.MASK2STR[self.device_type], self.device_id) @@ -97,48 +138,38 @@ class TVMArray(ctypes.Structure): TVMArrayHandle = ctypes.POINTER(TVMArray) - -def cpu(dev_id=0): - """Construct a CPU device +def context(dev_type, dev_id=0): + """Construct a TVM context with given device type and id. Parameters ---------- + dev_type: int or str + The device type mask or name of the device. + dev_id : int, optional The integer device id + + Returns + ------- + ctx: TVMContext + The corresponding context. + + Examples + -------- + Context can be used to create reflection of context by + string representation of the device type. + + .. code-block:: python + + assert tvm.context("cpu", 1) == tvm.cpu(1) + assert tvm.context("gpu", 0) == tvm.gpu(0) + assert tvm.context("cuda", 0) == tvm.gpu(0) """ - return TVMContext(dev_id, 1) - - -def gpu(dev_id=0): - """Construct a CPU device - - Parameters - ---------- - dev_id : int, optional - The integer device id - """ - return TVMContext(dev_id, 2) - - -def opencl(dev_id=0): - """Construct a OpenCL device - - Parameters - ---------- - dev_id : int, optional - The integer device id - """ - return TVMContext(dev_id, 4) - -def vpi(dev_id=0): - """Construct a VPI simulated device - - Parameters - ---------- - dev_id : int, optional - The integer device id - """ - return TVMContext(dev_id, 9) + if isinstance(dev_type, string_types): + if not dev_type in TVMContext.STR2MASK: + raise ValueError("Unknown device type %s" % dev_type) + dev_type = TVMContext.STR2MASK[dev_type] + return TVMContext(dev_type, dev_id) def numpyasarray(np_data): @@ -154,10 +185,11 @@ def numpyasarray(np_data): arr.dtype = TVMType(np.dtype(data.dtype).name) arr.ndim = data.ndim # CPU device - arr.ctx = cpu(0) + arr.ctx = context(1, 0) return arr, shape -def empty(shape, dtype="float32", ctx=cpu(0)): + +def empty(shape, dtype="float32", ctx=context(1, 0)): """Create an empty array given shape and device Parameters @@ -185,17 +217,6 @@ def empty(shape, dtype="float32", ctx=cpu(0)): return _CLASS_NDARRAY(handle) -def sync(ctx): - """Synchronize all the context - - Parameters - ---------- - ctx : TVMContext - The context to be synced - """ - check_call(_LIB.TVMSynchronize(ctx, None)) - - class NDArrayBase(object): """A simple Device/CPU Array object in runtime.""" __slots__ = ["handle", "is_view"] diff --git a/python/tvm/build.py b/python/tvm/build.py index 5b926bd7..a261e1b1 100644 --- a/python/tvm/build.py +++ b/python/tvm/build.py @@ -10,6 +10,7 @@ from . import schedule from . import expr from . import ir_pass from . import collections +from . import module from . import codegen @@ -149,7 +150,7 @@ def build(sch, fsplits[0] = ir_pass.LowerPackedCall(fsplits[0]) if len(fsplits) > 1: if not target_host: - target_host = "llvm" if codegen.enabled("llvm") else "stackvm" + target_host = "llvm" if module.enabled("llvm") else "stackvm" mhost = codegen.build_module(fsplits[0], target_host) if target: mdev = codegen.build_module(fsplits[1:], target) diff --git a/python/tvm/codegen.py b/python/tvm/codegen.py index 2021c30a..d8097a0b 100644 --- a/python/tvm/codegen.py +++ b/python/tvm/codegen.py @@ -19,21 +19,4 @@ def build_module(lowered_func, target): """ return _Build(lowered_func, target) - -def enabled(target): - """Whether target is enabled for codegen. - - Parameters - ---------- - target : str - The target module type. - - Returns - ------- - enabled : boolean - Whether the target module is enabled. - """ - return _Enabled(target) - - _init_api("tvm.codegen") diff --git a/python/tvm/contrib/cc_compiler.py b/python/tvm/contrib/cc_compiler.py index 8f0c057c..af599ed5 100644 --- a/python/tvm/contrib/cc_compiler.py +++ b/python/tvm/contrib/cc_compiler.py @@ -24,6 +24,8 @@ def create_shared(path_target, objects, """ cmd = [cc] cmd += ["-shared"] + if sys.platform == "darwin": + cmd += ["-undefined", "dynamic_lookup"] cmd += ["-o", path_target] cmd += objects if options: diff --git a/python/tvm/contrib/metal_compiler.py b/python/tvm/contrib/metal_compiler.py new file mode 100644 index 00000000..b77fd28e --- /dev/null +++ b/python/tvm/contrib/metal_compiler.py @@ -0,0 +1,50 @@ +# pylint: disable=invalid-name +"""Utility to invoke metal compiler in the CLI system""" +from __future__ import absolute_import as _abs +import sys +import subprocess +from . import util + +def compile_source(code, path_target=None): + """Compile metal with CLI tool from env. + + Parameters + ---------- + code : str + The cuda code. + + path_target : str, optional + Output file. + + Return + ------ + metallib : bytearray + The bytearray of the metallib + """ + temp = util.tempdir() + temp_code = temp.relpath("my_lib.metal") + temp_ir = temp.relpath("my_lib.air") + temp_target = temp.relpath("my_lib.metallib") + + with open(temp_code, "w") as out_file: + out_file.write(code) + file_target = path_target if path_target else temp_target + + cmd1 = ["xcrun", "-sdk", "macosx", "metal", "-O3"] + cmd1 += [temp_code, "-o", temp_ir] + cmd2 = ["xcrun", "-sdk", "macosx", "metallib"] + cmd2 += [temp_ir, "-o", file_target] + proc = subprocess.Popen( + ' '.join(cmd1) + ";" + ' '.join(cmd2), + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + (out, _) = proc.communicate() + if proc.returncode != 0: + sys.stderr.write("Compilation error:\n") + sys.stderr.write(out) + sys.stderr.flush() + libbin = None + else: + libbin = bytearray(open(file_target, "rb").read()) + return libbin diff --git a/python/tvm/ndarray.py b/python/tvm/ndarray.py index 487621b8..2e49c3e3 100644 --- a/python/tvm/ndarray.py +++ b/python/tvm/ndarray.py @@ -8,11 +8,9 @@ from __future__ import absolute_import as _abs import numpy as _np from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase -from ._ffi.ndarray import cpu, gpu, opencl, vpi, empty, sync +from ._ffi.ndarray import context, empty from ._ffi.ndarray import _set_class_ndarray -cl = opencl - class NDArray(NDArrayBase): """Lightweight NDArray class of TVM runtime. @@ -27,6 +25,64 @@ class NDArray(NDArrayBase): pass +def cpu(dev_id=0): + """Construct a CPU device + + Parameters + ---------- + dev_id : int, optional + The integer device id + """ + return TVMContext(1, dev_id) + + +def gpu(dev_id=0): + """Construct a CPU device + + Parameters + ---------- + dev_id : int, optional + The integer device id + """ + return TVMContext(2, dev_id) + + +def opencl(dev_id=0): + """Construct a OpenCL device + + Parameters + ---------- + dev_id : int, optional + The integer device id + """ + return TVMContext(4, dev_id) + + +def metal(dev_id=0): + """Construct a metal device + + Parameters + ---------- + dev_id : int, optional + The integer device id + """ + return TVMContext(8, dev_id) + + +def vpi(dev_id=0): + """Construct a VPI simulated device + + Parameters + ---------- + dev_id : int, optional + The integer device id + """ + return TVMContext(9, dev_id) + +cl = opencl +mtl = metal + + def array(arr, ctx=cpu(0)): """Create an array from source arr. diff --git a/src/api/api_codegen.cc b/src/api/api_codegen.cc index ae182e90..37e0717f 100644 --- a/src/api/api_codegen.cc +++ b/src/api/api_codegen.cc @@ -20,10 +20,5 @@ TVM_REGISTER_API("codegen._Build") *ret = Build(args[0], args[1]); } }); - -TVM_REGISTER_API("codegen._Enabled") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = TargetEnabled(args[0]); - }); } // namespace codegen } // namespace tvm diff --git a/src/codegen/build_common.h b/src/codegen/build_common.h new file mode 100644 index 00000000..fdb53d36 --- /dev/null +++ b/src/codegen/build_common.h @@ -0,0 +1,34 @@ +/*! + * Copyright (c) 2017 by Contributors + * Common build utilities + * \file build_common.h + */ +#ifndef TVM_CODEGEN_BUILD_COMMON_H_ +#define TVM_CODEGEN_BUILD_COMMON_H_ + +#include +#include +#include +#include "../runtime/meta_data.h" + +namespace tvm { +namespace codegen { +// Extract function information from device function. +inline std::unordered_map +ExtractFuncInfo(const Array& funcs) { + std::unordered_map fmap; + for (LoweredFunc f : funcs) { + runtime::FunctionInfo info; + for (size_t i = 0; i < f->args.size(); ++i) { + info.arg_types.push_back(Type2TVMType(f->args[i].type())); + } + for (size_t i = 0; i < f->thread_axis.size(); ++i) { + info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag); + } + fmap[f->name] = info; + } + return fmap; +} +} // namespace codegen +} // namespace tvm +#endif // TVM_CODEGEN_BUILD_COMMON_H_ diff --git a/src/codegen/build_cuda.cc b/src/codegen/build_cuda.cc index 4f9c4ee9..aaac4115 100644 --- a/src/codegen/build_cuda.cc +++ b/src/codegen/build_cuda.cc @@ -6,11 +6,10 @@ #include #include #include "./codegen_cuda.h" +#include "./build_common.h" #if TVM_CUDA_RUNTIME - #include -#include "../runtime/meta_data.h" #include "../runtime/cuda/cuda_common.h" #include "../runtime/cuda/cuda_module.h" @@ -71,19 +70,7 @@ runtime::Module BuildCUDA(Array funcs) { } else { ptx = NVRTCCompile(code); } - - std::unordered_map fmap; - for (LoweredFunc f : funcs) { - runtime::FunctionInfo info; - for (size_t i = 0; i < f->args.size(); ++i) { - info.arg_types.push_back(Type2TVMType(f->args[i].type())); - } - for (size_t i = 0; i < f->thread_axis.size(); ++i) { - info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag); - } - fmap[f->name] = info; - } - return CUDAModuleCreate(ptx, fmt, fmap, code); + return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(funcs), code); } TVM_REGISTER_API("codegen.build_cuda") diff --git a/src/codegen/build_metal.cc b/src/codegen/build_metal.cc new file mode 100644 index 00000000..f2a7e14f --- /dev/null +++ b/src/codegen/build_metal.cc @@ -0,0 +1,47 @@ +/*! + * Copyright (c) 2017 by Contributors + * Build metal modules from source. + * \file build_metal.cc + */ +#include +#include +#include "./codegen_metal.h" +#include "./build_common.h" + +#if TVM_METAL_RUNTIME +#include "../runtime/metal/metal_module.h" +#endif // TVM_METAL_RUNTIME + +namespace tvm { +namespace codegen { + +runtime::Module BuildMetal(Array funcs) { + using tvm::runtime::Registry; + bool output_ssa = false; + CodeGenMetal cg; + cg.Init(output_ssa); + for (LoweredFunc f : funcs) { + cg.AddFunction(f); + } + std::string code = cg.Finish(); +#if TVM_METAL_RUNTIME + std::string fmt = "metal"; + std::string source = ""; + if (const auto* f = Registry::Get("tvm_callback_metal_compile")) { + source = code; + code = (*f)(code).operator std::string(); + fmt = "metallib"; + } + return MetalModuleCreate(code, fmt, ExtractFuncInfo(funcs), source); +#else + LOG(WARNING) << "Metal runtime not enabled, return a source module..."; + return SourceModuleCreate(code, "metal"); +#endif // TVM_METAL_RUNTIME +} + +TVM_REGISTER_API("codegen.build_metal") +.set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = BuildMetal(args[0]); + }); +} // namespace codegen +} // namespace tvm diff --git a/src/codegen/build_opencl.cc b/src/codegen/build_opencl.cc index 9580fe41..499c88a0 100644 --- a/src/codegen/build_opencl.cc +++ b/src/codegen/build_opencl.cc @@ -6,39 +6,29 @@ #include #include #include "./codegen_opencl.h" +#include "./build_common.h" #if TVM_OPENCL_RUNTIME - -#include "../runtime/meta_data.h" -#include "../runtime/opencl/opencl_common.h" #include "../runtime/opencl/opencl_module.h" +#endif // TVM_OPENCL_RUNTIME namespace tvm { namespace codegen { runtime::Module BuildOpenCL(Array funcs) { - std::ostringstream os; bool output_ssa = false; CodeGenOpenCL cg; cg.Init(output_ssa); - for (LoweredFunc f : funcs) { cg.AddFunction(f); } std::string code = cg.Finish(); - - std::unordered_map fmap; - for (LoweredFunc f : funcs) { - runtime::FunctionInfo info; - for (size_t i = 0; i < f->args.size(); ++i) { - info.arg_types.push_back(Type2TVMType(f->args[i].type())); - } - for (size_t i = 0; i < f->thread_axis.size(); ++i) { - info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag); - } - fmap[f->name] = info; - } - return OpenCLModuleCreate(code, "cl", fmap); +#if TVM_OPENCL_RUNTIME + return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs)); +#else + LOG(WARNING) << "OpenCL runtime not enabled, return a source module..."; + return SourceModuleCreate(code, "cl"); +#endif // TVM_OPENCL_RUNTIME } TVM_REGISTER_API("codegen.build_opencl") @@ -47,4 +37,3 @@ TVM_REGISTER_API("codegen.build_opencl") }); } // namespace codegen } // namespace tvm -#endif // TVM_OPENCL_RUNTIME diff --git a/src/codegen/codegen.cc b/src/codegen/codegen.cc index f9c9ee68..03cf8ff4 100644 --- a/src/codegen/codegen.cc +++ b/src/codegen/codegen.cc @@ -32,10 +32,5 @@ runtime::Module Build(const Array& funcs, return m; } -bool TargetEnabled(const std::string& target) { - std::string build_f_name = "codegen.build_" + target; - return runtime::Registry::Get(build_f_name) != nullptr; -} - } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 51de1efa..38a57cb5 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -24,7 +24,6 @@ void CodeGenC::InitFuncState(LoweredFunc f) { void CodeGenC::AddFunction(LoweredFunc f) { // clear previous generated state. this->InitFuncState(f); - // skip the first underscore, so SSA variable starts from _1 GetUniqueName("_"); // add to alloc buffer type. @@ -41,8 +40,8 @@ void CodeGenC::AddFunction(LoweredFunc f) { auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, stream); + stream << ' '; } - stream << ' '; } if (handle_data_type_.count(v.get())) { PrintType(handle_data_type_.at(v.get()), stream); @@ -246,9 +245,9 @@ void CodeGenC::PrintVecStore(const Variable* buffer, stream << ref << " = " << value << ";\n"; } -void CodeGenC::PrintThreadIndexExpr( - std::string thread_tag, std::ostream& os) { // NOLINT(*) - os << thread_tag; +void CodeGenC::BindThreadIndex(const IterVar& iv) { + CHECK(!var_idmap_.count(iv->var.get())); + var_idmap_[iv->var.get()] = iv->thread_tag; } void CodeGenC::PrintStorageSync(const Call* op) { // NOLINT(*) @@ -674,6 +673,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) { const Variable* buffer = op->buffer_var.as(); std::string scope = alloc_storage_scope_.at(buffer); PrintStorageScope(scope, stream); + stream << ' '; PrintType(op->type, stream); stream << ' '<< vid << '[' << constant_size << "];\n"; @@ -687,13 +687,7 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) { IterVar iv(op->node.node_); if (iv->thread_tag.length() != 0) { if (!var_idmap_.count(iv->var.get())) { - this->PrintIndent(); - PrintType(iv->var.type(), stream); - stream << ' ' - << AllocVarID(iv->var.get()) - << " = "; - PrintThreadIndexExpr(iv->thread_tag, stream); - stream << ";\n"; + BindThreadIndex(iv); } } } else if (op->attr_key == ir::attr::storage_scope) { @@ -740,7 +734,11 @@ void CodeGenC::VisitStmt_(const For* op) { void CodeGenC::VisitStmt_(const IfThenElse* op) { std::string cond = PrintExpr(op->condition); PrintIndent(); - stream << "if (" << cond << ") {\n"; + if (cond[0] == '(' && cond[cond.length() - 1] == ')') { + stream << "if " << cond << " {\n"; + } else { + stream << "if (" << cond << ") {\n"; + } int then_scope = BeginScope(); PrintStmt(op->then_case); this->EndScope(then_scope); diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index 7fec8e64..510f6ddc 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -121,12 +121,10 @@ class CodeGenC : virtual void PrintType(Type t, std::ostream& os) const; // NOLINT(*) /*! * \brief Print expr representing the thread tag - * \param tag The tag in the thread. - * \param os The strean to output to + * \param IterVar iv The thread index to be binded; */ - virtual void PrintThreadIndexExpr( - std::string tag, std::ostream& os); // NOLINT(*) - virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*) + virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*) + virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*) virtual void PrintStorageSync(const Call* op); // NOLINT(*) // Binary vector op. virtual void PrintVecBinaryOp( @@ -169,12 +167,12 @@ class CodeGenC : const std::string& target, const std::string& src, Type t) final; /*! \brief the storage scope of allocation */ std::unordered_map alloc_storage_scope_; + /*! \brief the data type of allocated buffers */ + std::unordered_map handle_data_type_; private: /*! \brief whether to print in SSA form */ bool print_ssa_form_{false}; - /*! \brief the data type of allocated buffers */ - std::unordered_map handle_data_type_; /*! \brief set of volatile buf access */ std::unordered_set volatile_buf_; }; diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index d2a1a27d..5589d4b8 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -141,7 +141,9 @@ void CodeGenCUDA::PrintVecElemStore( void CodeGenCUDA::PrintStorageSync(const Call* op) { const std::string& sync = op->args[0].as()->value; - if (sync == "shared") { + if (sync == "warp") { + // DO nothing. + } else if (sync == "shared") { this->PrintIndent(); this->stream << "__syncthreads();\n"; } else if (sync == "global") { @@ -182,7 +184,7 @@ void CodeGenCUDA::PrintStorageScope( const std::string& scope, std::ostream& os) { // NOLINT(*) CHECK_NE(scope, "global"); if (scope == "shared") { - os << "__shared__ "; + os << "__shared__"; } } @@ -203,6 +205,17 @@ void CodeGenCUDA::VisitStmt_(const Evaluate *op) { } } +void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) + std::string v = PrintExpr(op->value); + os << "make_"; + PrintType(op->type, os); + os << "("; + for (int i = 0; i < op->lanes; ++i) { + if (i != 0) os << ", "; + os << v; + } + os << ')'; +} } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index 5df85490..2cd7bc0d 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -14,7 +14,7 @@ namespace tvm { namespace codegen { -class CodeGenCUDA : public CodeGenC { +class CodeGenCUDA final : public CodeGenC { public: void Init(bool output_ssa); void AddFunction(LoweredFunc f); @@ -31,6 +31,8 @@ class CodeGenCUDA : public CodeGenC { void PrintVecElemStore( const std::string& vec, Type t, int i, const std::string& value) final; + // overload visitor + void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitStmt_(const Evaluate *op) final; private: diff --git a/src/codegen/codegen_metal.cc b/src/codegen/codegen_metal.cc new file mode 100644 index 00000000..7f8c1dd9 --- /dev/null +++ b/src/codegen/codegen_metal.cc @@ -0,0 +1,203 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file codegen_metal.cc + */ +#include +#include +#include +#include +#include "./codegen_metal.h" +#include "../runtime/thread_storage_scope.h" + +namespace tvm { +namespace codegen { + +void CodeGenMetal::InitFuncState(LoweredFunc f) { + CodeGenC::InitFuncState(f); + // analyze the data; + for (Var arg : f->args) { + if (arg.type().is_handle()) { + alloc_storage_scope_[arg.get()] = "global"; + } + } +} + +CodeGenMetal::CodeGenMetal() { + decl_stream << "#include \n"; + decl_stream << "using namespace metal;\n\n"; + decl_stream << "union __TVMArgUnion {\n" + << " int v_int;\n" + << "};\n\n"; +} + +void CodeGenMetal::AddFunction(LoweredFunc f) { + // clear previous generated state. + this->InitFuncState(f); + // skip the first underscore, so SSA variable starts from _1 + GetUniqueName("_"); + // add to alloc buffer type. + for (const auto & kv : f->handle_data_type) { + RegisterHandleType(kv.first.get(), kv.second.type()); + } + // Function header. + this->stream << "kernel void " << f->name << "(\n"; + // Buffer arguments + size_t num_buffer = 0; + for (size_t i = 0; i < f->args.size(); ++i, ++num_buffer) { + Var v = f->args[i]; + if (!v.type().is_handle()) break; + stream << " "; + std::string vid = AllocVarID(v.get()); + auto it = alloc_storage_scope_.find(v.get()); + CHECK(it != alloc_storage_scope_.end()); + PrintStorageScope(it->second, stream); + stream << ' '; + if (handle_data_type_.count(v.get())) { + PrintType(handle_data_type_.at(v.get()), stream); + stream << "*"; + } else { + PrintType(v.type(), stream); + } + stream << ' ' << vid + << " [[ buffer(" << i << ") ]],\n"; + } + // Setup normal arguments. + size_t nargs = f->args.size() - num_buffer; + std::string varg = GetUniqueName("arg"); + if (nargs != 0) { + std::string arg_buf_type = f->name + "_args_t"; + stream << " constant " << arg_buf_type << "& " << varg + << " [[ buffer(" << num_buffer << ") ]],\n"; + // declare the struct + decl_stream << "struct " << arg_buf_type << " {\n"; + for (size_t i = num_buffer; i < f->args.size(); ++i) { + Var v = f->args[i]; + CHECK(!v.type().is_handle()); + std::string vid = AllocVarID(v.get()); + std::ostringstream vref; + if (v.type().bits() == 32) { + decl_stream << " "; + PrintType(v.type(), decl_stream); + decl_stream << " " << vid << ";\n"; + vref << varg << "." << vid; + } else { + // For non 32bit type, ref through arg union. + decl_stream << " __TVMArgUnion " << vid << ";\n"; + vref << varg << "." << vid << ".v_"; + PrintType(v.type(), vref); + } + var_idmap_[v.get()] = vref.str(); + } + decl_stream << "};\n\n"; + } + // Setup the thread group info. + CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx"); + CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx"); + int work_dim = 0; + for (IterVar iv : f->thread_axis) { + runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); + work_dim = std::max(work_dim, scope.dim_index + 1); + } + if (work_dim != 0) { + // use ushort by default for now + stream << " "; + PrintType(UInt(16, work_dim), stream); + stream << " blockIdx [[threadgroup_position_in_grid]],\n"; + stream << " "; + PrintType(UInt(16, work_dim), stream); + stream << " threadIdx [[thread_position_in_threadgroup]]\n"; + } + // bind thread axis + for (IterVar iv : f->thread_axis) { + CHECK(!var_idmap_.count(iv->var.get())); + if (work_dim <= 1) { + var_idmap_[iv->var.get()] = + iv->thread_tag.substr(0, iv->thread_tag.length() - 2); + } else { + var_idmap_[iv->var.get()] = iv->thread_tag; + } + } + // the function scope. + stream << ") {\n"; + int func_scope = this->BeginScope(); + this->PrintStmt(f->body); + this->EndScope(func_scope); + this->PrintIndent(); + this->stream << "}\n\n"; +} + +void CodeGenMetal::PrintType(Type t, std::ostream& os) const { // NOLINT(*) + int lanes = t.lanes(); + if (t.is_handle()) { + CHECK_EQ(lanes, 1) + << "do not yet support vector types"; + os << "void*"; return; + } + bool fail = false; + if (t.is_float()) { + switch (t.bits()) { + case 16: os << "half"; break; + case 32: os << "float"; break; + default: fail = true; break; + } + if (!fail && lanes == 1) return; + if (!fail && (lanes >= 2 && lanes <= 4)) { + os << lanes; return; + } + } else if (t.is_uint() || t.is_int()) { + if (t.is_uint()) { + os << 'u'; + } + if (t.bits() == 8 && t.lanes() == 4) { + // directly 4 8 bit int in integer. + os << "int"; return; + } + switch (t.bits()) { + case 8: os << "char"; break; + case 16: os << "short"; break; + case 32: os << "int"; break; + case 1: os << "bool"; break; + default: fail = true; break; + } + if (!fail && lanes == 1) return; + if (!fail && (lanes >= 2 && lanes <= 4)) { + os << lanes; return; + } + } + LOG(FATAL) << "Cannot convert type " << t << " to Metal type"; +} + +void CodeGenMetal::PrintStorageSync(const Call* op) { + const std::string& sync = op->args[0].as()->value; + if (sync == "warp") { + this->PrintIndent(); + this->stream << "simdgroup_barrier(mem_flags::mem_threadgroup);\n"; + } else if (sync == "shared") { + this->PrintIndent(); + this->stream << "threadgroup_barrier(mem_flags::mem_threadgroup);\n"; + } else if (sync == "global") { + LOG(FATAL) << "global barrier not supported"; + } +} + +void CodeGenMetal::PrintStorageScope( + const std::string& scope, std::ostream& os) { // NOLINT(*) + if (scope == "global") { + os << "device"; + } else if (scope == "shared") { + os << "threadgroup"; + } +} + +void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) + std::string v = PrintExpr(op->value); + PrintType(op->type, os); + os << "("; + for (int i = 0; i < op->lanes; ++i) { + if (i != 0) os << ", "; + os << v; + } + os << ')'; +} +} // namespace codegen +} // namespace tvm diff --git a/src/codegen/codegen_metal.h b/src/codegen/codegen_metal.h new file mode 100644 index 00000000..7331670d --- /dev/null +++ b/src/codegen/codegen_metal.h @@ -0,0 +1,34 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file codegen_metal.h + * \brief Generate Metal device code. + */ +#ifndef TVM_CODEGEN_CODEGEN_METAL_H_ +#define TVM_CODEGEN_CODEGEN_METAL_H_ + +#include +#include +#include +#include "./codegen_c.h" + +namespace tvm { +namespace codegen { + +class CodeGenMetal final : public CodeGenC { + public: + CodeGenMetal(); + void AddFunction(LoweredFunc f); + // override print thread tag. + void PrintArgUnionDecl(); + void InitFuncState(LoweredFunc f) final; + void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) + void PrintStorageSync(const Call* op) final; // NOLINT(*) + void PrintType(Type t, std::ostream& os) const final; // NOLINT(*) + + // overload visitor + void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) +}; +} // namespace codegen +} // namespace tvm + +#endif // TVM_CODEGEN_CODEGEN_METAL_H_ diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc index 11c86a8c..f2ad3fe5 100644 --- a/src/codegen/codegen_opencl.cc +++ b/src/codegen/codegen_opencl.cc @@ -1,6 +1,6 @@ /*! * Copyright (c) 2017 by Contributors - * \file codegen_cuda.cc + * \file codegen_opencl.cc */ #include #include @@ -22,19 +22,20 @@ void CodeGenOpenCL::InitFuncState(LoweredFunc f) { } void CodeGenOpenCL::AddFunction(LoweredFunc f) { - this->stream << " __kernel "; + this->stream << "__kernel "; CodeGenC::AddFunction(f); } -void CodeGenOpenCL::PrintThreadIndexExpr( - std::string tag, std::ostream& os) { // NOLINT(*) - runtime::ThreadScope ts = runtime::ThreadScope::make(tag); +void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) { + CHECK(!var_idmap_.count(iv->var.get())); + runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); + std::ostringstream os; if (ts.rank == 1) { os << "get_local_id(" << ts.dim_index << ")"; } else { - os << "get_global_id(" << ts.dim_index << ")" - << " / get_local_size(" << ts.dim_index << ")"; + os << "get_group_id(" << ts.dim_index << ")"; } + var_idmap_[iv->var.get()] = os.str(); } void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const { // NOLINT(*) @@ -115,7 +116,9 @@ void CodeGenOpenCL::PrintVecStore(const Variable* buffer, void CodeGenOpenCL::PrintStorageSync(const Call* op) { const std::string& sync = op->args[0].as()->value; - if (sync == "shared") { + if (sync == "warp") { + LOG(FATAL) << "warp sync not supported in opencl"; + } else if (sync == "shared") { this->PrintIndent(); this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n"; } else if (sync == "global") { @@ -128,8 +131,20 @@ void CodeGenOpenCL::PrintStorageScope( if (scope == "global") { os << "__global"; } else if (scope == "shared") { - os << "__local "; + os << "__local"; } } + +void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) + std::string v = PrintExpr(op->value); + os << '('; + PrintType(op->type, os); + os << ")("; + for (int i = 0; i < op->lanes; ++i) { + if (i != 0) os << ", "; + os << v; + } + os << ')'; +} } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen_opencl.h b/src/codegen/codegen_opencl.h index b96f891e..9a00b5e0 100644 --- a/src/codegen/codegen_opencl.h +++ b/src/codegen/codegen_opencl.h @@ -1,7 +1,7 @@ /*! * Copyright (c) 2017 by Contributors * \file codegen_opencl.h - * \brief Utility to generate opencl code + * \brief Generate OpenCL device code. */ #ifndef TVM_CODEGEN_CODEGEN_OPENCL_H_ #define TVM_CODEGEN_CODEGEN_OPENCL_H_ @@ -14,13 +14,12 @@ namespace tvm { namespace codegen { -class CodeGenOpenCL : public CodeGenC { +class CodeGenOpenCL final : public CodeGenC { public: void AddFunction(LoweredFunc f); // override print thread tag. void InitFuncState(LoweredFunc f) final; - void PrintThreadIndexExpr( - std::string tag, std::ostream& os) final; // NOLINT(*) + void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const Call* op) final; // NOLINT(*) void PrintType(Type t, std::ostream& os) const final; // NOLINT(*) @@ -32,6 +31,8 @@ class CodeGenOpenCL : public CodeGenC { // the address of load/store void PrintVecAddr(const Variable* buffer, Type t, Expr base, std::ostream& os); // NOLINT(*) + // overload visitor + void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) }; } // namespace codegen diff --git a/src/codegen/codegen_source_base.h b/src/codegen/codegen_source_base.h index 6f2f25d2..0ee5b71d 100644 --- a/src/codegen/codegen_source_base.h +++ b/src/codegen/codegen_source_base.h @@ -102,6 +102,12 @@ class CodeGenSourceBase { int indent_{0}; }; +/*! + * \brief Create a source module for viewing. + * \param code The code to be viewed. + * \param fmt The code. format. + */ +runtime::Module SourceModuleCreate(std::string code, std::string fmt); } // namespace codegen } // namespace tvm #endif // TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_ diff --git a/src/codegen/source_module.cc b/src/codegen/source_module.cc new file mode 100644 index 00000000..0bd727a2 --- /dev/null +++ b/src/codegen/source_module.cc @@ -0,0 +1,52 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file source_module.cc + * \brief Source code module, only for viewing + */ +#include +#include "./codegen_source_base.h" + +namespace tvm { +namespace codegen { + +using runtime::TVMArgs; +using runtime::TVMRetValue; +using runtime::PackedFunc; +// Simulator function +class SourceModuleNode : public runtime::ModuleNode { + public: + SourceModuleNode(std::string code, + std::string fmt) + : code_(code), fmt_(fmt) {} + const char* type_key() const { + return "source"; + } + void PreCompile(const std::string& name, TVMContext ctx) final { + } + PackedFunc GetFunction( + const std::string& name, + const std::shared_ptr& sptr_to_self) final { + LOG(FATAL) << "Source module cannot execute, to get executable module" + << " build TVM with \'" << fmt_ << "\' runtime support"; + return PackedFunc(); + } + void SaveToFile(const std::string& file_name, + const std::string& format) final { + LOG(FATAL) << "not implemented"; + } + std::string GetSource(const std::string& format) final { + return code_; + } + + private: + std::string code_; + std::string fmt_; +}; + +runtime::Module SourceModuleCreate(std::string code, std::string fmt) { + std::shared_ptr n = + std::make_shared(code, fmt); + return runtime::Module(n); +} +} // namespace codegen +} // namespace tvm diff --git a/src/codegen/verilog/verilog_module.cc b/src/codegen/verilog/verilog_module.cc index e0cc9be5..53215ad9 100644 --- a/src/codegen/verilog/verilog_module.cc +++ b/src/codegen/verilog/verilog_module.cc @@ -31,11 +31,7 @@ class VerilogModuleNode : public runtime::ModuleNode { const std::string& name, const std::shared_ptr& sptr_to_self) final { CHECK(sptr_to_self.get() == this); - if (name == runtime::symbol::tvm_entry_setdevice) { - return PackedFunc([](const TVMArgs& args, TVMRetValue* rv){}); - } - CHECK(m_.fmap.count(name)) << "Cannot find function " << name << " in the module"; - + if (!m_.fmap.count(name)) return PackedFunc(); auto f = [sptr_to_self, name, this](const runtime::TVMArgs& args, TVMRetValue* rv) { auto* fsim = runtime::Registry::Get("tvm_callback_verilog_simulator"); CHECK(fsim != nullptr) diff --git a/src/codegen/verilog/vpi_device_api.cc b/src/codegen/verilog/vpi_device_api.cc index b2a62df3..af886207 100644 --- a/src/codegen/verilog/vpi_device_api.cc +++ b/src/codegen/verilog/vpi_device_api.cc @@ -16,7 +16,7 @@ namespace tvm { namespace codegen { /*! \brief Simulated device ram */ -class VPIDeviceAPI : public runtime::DeviceAPI { +class VPIDeviceAPI final : public runtime::DeviceAPI { public: VPIDeviceAPI() { static const size_t kAllocAlign = 32U; @@ -44,6 +44,12 @@ class VPIDeviceAPI : public runtime::DeviceAPI { if (ptr + size >= ram_max_) return nullptr; return (char*)(&ram_[0]) + ptr; // NOLINT(*) } + void SetDevice(int dev_id) final {} + void GetAttr(int dev_id, runtime::DeviceAttrKind kind, TVMRetValue* rv) final { + if (kind == runtime::kExist) { + *rv = 1; + } + } void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final { static const size_t kAllocAlign = 32U; // always align to 32 bytes at least. @@ -80,16 +86,18 @@ class VPIDeviceAPI : public runtime::DeviceAPI { free_blocks_.insert({b.size, head}); } void CopyDataFromTo(const void* from, + size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, TVMContext ctx_to, TVMStreamHandle stream) final { if (static_cast(ctx_from.device_type) == kVPI) { - from = RealAddr(from, size); + from = RealAddr(static_cast(from) + from_offset, size); } if (static_cast(ctx_to.device_type) == kVPI) { - to = RealAddr(to, size); + to = RealAddr(static_cast(to) + to_offset, size); } memcpy(to, from, size); } diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index a3961695..c271c61e 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -156,7 +156,7 @@ class ThreadAllreduceBuilder : public IRMutator { seq.emplace_back(Store::make( shared_buf, value, BufIndex(reduce_index, group_index, reduce_extent))); - seq.emplace_back(SyncThread()); + seq.emplace_back(SyncThread("shared")); seq.emplace_back(MakeBufAllreduce( combiner, value.type(), shared_buf, reduce_index, group_index, reduce_extent, threadx_extent)); @@ -202,7 +202,7 @@ class ThreadAllreduceBuilder : public IRMutator { reduce_align = reduce_align >> 1; Expr cond = reduce_index < (reduce_extent - reduce_align); seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align))); - seq.emplace_back(SyncThread()); + seq.emplace_back(SyncThread("shared")); } CHECK(threadx_extent >= 1 && warp_size_ >= 1); // normal synchronization @@ -211,7 +211,7 @@ class ThreadAllreduceBuilder : public IRMutator { reduce_align = reduce_align >> 1; Expr cond = reduce_index < reduce_align; seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align))); - seq.emplace_back(SyncThread()); + seq.emplace_back(SyncThread("shared")); } // in warp synchronization. std::vector in_warp_seq; @@ -219,6 +219,7 @@ class ThreadAllreduceBuilder : public IRMutator { while (reduce_align > 1) { reduce_align = reduce_align >> 1; in_warp_seq.emplace_back(freduce(reduce_align)); + seq.emplace_back(SyncThread("warp")); } if (in_warp_seq.size() != 0) { Stmt warp_body = MergeSeq(in_warp_seq); @@ -249,10 +250,10 @@ class ThreadAllreduceBuilder : public IRMutator { return ret; } // sync thread op. - static Stmt SyncThread() { + static Stmt SyncThread(const std::string& sync) { return Evaluate::make( Call::make(Int(32), intrinsic::tvm_storage_sync, - {StringImm::make("shared")}, + {StringImm::make(sync)}, Call::Intrinsic)); } // The local buffer index. diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index 20c87c15..6e9e95fb 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -244,6 +244,12 @@ LoweredFunc MakeAPI(Stmt body, node, attr::device_context_id, device_id, nop)); seq_init.push_back(AttrStmt::make( node, attr::device_context_type, device_type, nop)); + Stmt set_device = IfThenElse::make( + device_type != kCPU, Evaluate::make(Call::make( + Int(32), intrinsic::tvm_call_packed, + {StringImm::make(runtime::symbol::tvm_set_device), + device_type, device_id}, Call::Intrinsic))); + body = Block::make(set_device, body); } n->body = MergeNest({seq_init, seq_check}, body); LoweredFunc f(n); diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc index e1f7083d..a6e99f6a 100644 --- a/src/pass/split_host_device.cc +++ b/src/pass/split_host_device.cc @@ -162,20 +162,6 @@ class HostDeviceSplitter : public IRMutator { std::shared_ptr n = std::make_shared(*f.operator->()); n->body = this->Mutate(f->body); - - if (f->is_packed_func && device_funcs_.size() != 0) { - // insert auto set device from device function. - Array args = {StringImm::make(runtime::symbol::tvm_entry_setdevice)}; - for (Var arg : f->args) { - args.push_back(arg); - } - n->body = Block::make( - Evaluate::make(Call::make( - Int(32), intrinsic::tvm_call_packed, - args, Call::Intrinsic)), - n->body); - } - Array ret{LoweredFunc(n)}; for (LoweredFunc x : device_funcs_) { ret.push_back(x); @@ -193,14 +179,21 @@ class HostDeviceSplitter : public IRMutator { m.visit_thread_extent_ = false; n->body = m.Mutate(body); n->name = os.str(); - n->args = m.undefined_; n->thread_axis = m.thread_axis_; - - // improve the handle data type - for (Var arg : n->args) { - auto it = handle_data_type_.find(arg.get()); - if (it != handle_data_type_.end()) { - n->handle_data_type.Set(arg, it->second); + // Strictly order the arguments: Var pointers, positional arguments. + for (Var v : m.undefined_) { + if (v.type().is_handle()) { + n->args.push_back(v); + // mark handle data type. + auto it = handle_data_type_.find(v.get()); + if (it != handle_data_type_.end()) { + n->handle_data_type.Set(v, it->second); + } + } + } + for (Var v : m.undefined_) { + if (!v.type().is_handle()) { + n->args.push_back(v); } } LoweredFunc f_device(n); @@ -209,7 +202,6 @@ class HostDeviceSplitter : public IRMutator { for (Var arg : n->args) { call_args.push_back(arg); } - for (Expr ext : m.thread_extent_) { call_args.push_back(ext); } diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc index 7b5e9deb..f874f272 100644 --- a/src/pass/storage_sync.cc +++ b/src/pass/storage_sync.cc @@ -54,14 +54,17 @@ class StorageSyncPlanner : public IRVisitor { if (!in_device_env_) return; if (const Call* call = op->value.as()) { if (call->is_intrinsic(intrinsic::tvm_storage_sync)) { - StorageScope scope = StorageScope::make(call->args[0].as()->value); - if (scope.rank <= sync_scope_.rank) { - CHECK_EQ(curr_stmt_.access.size(), 0U); - curr_stmt_.access.emplace_back( - AccessEntry(nullptr, Expr(), kSync, scope)); - // push to the scope - scope_.back().push_back(curr_stmt_); - curr_stmt_.access.clear(); + const std::string& s = call->args[0].as()->value; + if (s != "warp") { + StorageScope scope = StorageScope::make(s); + if (scope.rank <= sync_scope_.rank) { + CHECK_EQ(curr_stmt_.access.size(), 0U); + curr_stmt_.access.emplace_back( + AccessEntry(nullptr, Expr(), kSync, scope)); + // push to the scope + scope_.back().push_back(curr_stmt_); + curr_stmt_.access.clear(); + } } } } diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 4f059726..08288c91 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -25,8 +25,11 @@ class DeviceAPIManager { public: static const int kMaxDeviceAPI = 32; // Get API - static DeviceAPI* Get(TVMContext ctx) { - return Global()->GetAPI(ctx.device_type); + static DeviceAPI* Get(const TVMContext& ctx) { + return Get(ctx.device_type); + } + static DeviceAPI* Get(int dev_type, bool allow_missing = false) { + return Global()->GetAPI(dev_type, allow_missing); } private: @@ -42,20 +45,25 @@ class DeviceAPIManager { return &inst; } // Get or initialize API. - DeviceAPI* GetAPI(DLDeviceType type) { - if (api_[type] != nullptr) return api_[type]; - std::lock_guard lock(mutex_); - if (api_[type] != nullptr) return api_[type]; - std::string factory = "device_api." + DeviceName(type); - auto* f = Registry::Get(factory); - CHECK(f != nullptr) - << "Device API " << DeviceName(type) << " is not enabled."; - void* ptr = (*f)(); - api_[type] = static_cast(ptr); - return api_[type]; - } + DeviceAPI* GetAPI(int type, bool allow_missing); }; +DeviceAPI* DeviceAPIManager::GetAPI(int type, bool allow_missing) { + if (api_[type] != nullptr) return api_[type]; + std::lock_guard lock(mutex_); + if (api_[type] != nullptr) return api_[type]; + std::string factory = "device_api." + DeviceName(type); + auto* f = Registry::Get(factory); + if (f == nullptr) { + CHECK(allow_missing) + << "Device API " << DeviceName(type) << " is not enabled."; + return nullptr; + } + void* ptr = (*f)(); + api_[type] = static_cast(ptr); + return api_[type]; +} + inline TVMArray* TVMArrayCreate_() { TVMArray* arr = new TVMArray(); @@ -352,8 +360,9 @@ int TVMArrayCopyFromTo(TVMArrayHandle from, << "Can not copy across different ctx types directly"; } DeviceAPIManager::Get(ctx)->CopyDataFromTo( - from->data, to->data, from_size, - from->ctx, to->ctx, stream); + from->data, from->byte_offset, + to->data, to->byte_offset, + from_size, from->ctx, to->ctx, stream); API_END(); } @@ -362,3 +371,29 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) { DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream); API_END(); } + +// set device api +TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) +.set_body([](TVMArgs args, TVMRetValue *ret) { + int dev_type = args[0]; + int dev_id = args[1]; + DeviceAPIManager::Get(dev_type)->SetDevice(dev_id); + }); + +// set device api +TVM_REGISTER_GLOBAL("_GetDeviceAttr") +.set_body([](TVMArgs args, TVMRetValue *ret) { + int dev_type = args[0]; + int dev_id = args[1]; + DeviceAttrKind kind = static_cast(args[2].operator int()); + if (kind == kExist) { + DeviceAPI* api = DeviceAPIManager::Get(dev_type, true); + if (api != nullptr) { + api->GetAttr(dev_id, kind, ret); + } else { + *ret = 0; + } + } else { + DeviceAPIManager::Get(dev_type)->GetAttr(dev_id, kind, ret); + } + }); diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index fa95d261..86f85d46 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -11,8 +11,14 @@ namespace tvm { namespace runtime { -class CPUDeviceAPI : public DeviceAPI { +class CPUDeviceAPI final : public DeviceAPI { public: + void SetDevice(int dev_id) final {} + void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final { + if (kind == kExist) { + *rv = 1; + } + } void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final { void* ptr; #if _MSC_VER @@ -34,12 +40,16 @@ class CPUDeviceAPI : public DeviceAPI { } void CopyDataFromTo(const void* from, + size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, TVMContext ctx_to, TVMStreamHandle stream) final { - memcpy(to, from, size); + memcpy(static_cast(to) + to_offset, + static_cast(from) + from_offset, + size); } void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 141158cb..9c617a6e 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -15,8 +15,33 @@ namespace tvm { namespace runtime { -class CUDADeviceAPI : public DeviceAPI { +class CUDADeviceAPI final : public DeviceAPI { public: + void SetDevice(int dev_id) final { + CUDA_CALL(cudaSetDevice(dev_id)); + } + void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final { + int value; + switch (kind) { + case kExist: + value = ( + cudaDeviceGetAttribute( + &value, cudaDevAttrMaxThreadsPerBlock, dev_id) + == cudaSuccess); + break; + case kMaxThreadsPerBlock: { + CUDA_CALL(cudaDeviceGetAttribute( + &value, cudaDevAttrMaxThreadsPerBlock, dev_id)); + break; + } + case kWarpSize: { + CUDA_CALL(cudaDeviceGetAttribute( + &value, cudaDevAttrWarpSize, dev_id)); + break; + } + } + *rv = value; + } void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final { CUDA_CALL(cudaSetDevice(ctx.device_id)); CHECK_EQ(256 % alignment, 0U) @@ -32,12 +57,16 @@ class CUDADeviceAPI : public DeviceAPI { } void CopyDataFromTo(const void* from, + size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, TVMContext ctx_to, TVMStreamHandle stream) final { cudaStream_t cu_stream = static_cast(stream); + from = static_cast(from) + from_offset; + to = static_cast(to) + to_offset; if (ctx_from.device_type == kGPU && ctx_to.device_type == kGPU) { CUDA_CALL(cudaSetDevice(ctx_from.device_id)); if (ctx_from.device_id == ctx_to.device_id) { diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 9be1c4c4..bf4cbf98 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -14,7 +14,7 @@ #include #include #include "./cuda_common.h" -#include "../void_addr_args.h" +#include "../pack_args.h" #include "../thread_storage_scope.h" #include "../meta_data.h" #include "../file_util.h" @@ -214,41 +214,13 @@ class CUDAPrepGlobalBarrier { mutable std::array pcache_; }; - -void AutoSetCUDADevice(const TVMArgs& args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 3); - TVMValue* values = static_cast(args[0].operator void*()); - int* type_codes = static_cast(args[1].operator void*()); - int num_args = args[2].operator int(); - - int device_id = -1; - for (int i = 0; i < num_args; ++i) { - if (type_codes[i] == kArrayHandle) { - TVMContext ctx = static_cast(values[i].v_handle)->ctx; - CHECK_EQ(ctx.device_type, kGPU) - << "All operands need to be GPU"; - if (device_id == -1) { - device_id = ctx.device_id; - } else { - CHECK_EQ(device_id, ctx.device_id) - << "Operands comes from different devices "; - } - } - } - CHECK_NE(device_id, -1) - << "Cannot detect device id from list"; - CUDA_CALL(cudaSetDevice(device_id)); -} - PackedFunc CUDAModuleNode::GetFunction( const std::string& name, const std::shared_ptr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; - if (name == symbol::tvm_entry_setdevice) { - return PackedFunc(AutoSetCUDADevice); - } else if (name == symbol::tvm_prepare_global_barrier) { + if (name == symbol::tvm_prepare_global_barrier) { return PackedFunc(CUDAPrepGlobalBarrier(this, sptr_to_self)); } auto it = fmap_.find(name); @@ -256,7 +228,7 @@ PackedFunc CUDAModuleNode::GetFunction( const FunctionInfo& info = it->second; CUDAWrappedFunc f; f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags); - return PackFromVoidAddrArgs(f, info.arg_types); + return PackFuncVoidAddr(f, info.arg_types); } Module CUDAModuleCreate( diff --git a/src/runtime/device_api.h b/src/runtime/device_api.h index 699655b0..f444997c 100644 --- a/src/runtime/device_api.h +++ b/src/runtime/device_api.h @@ -13,10 +13,29 @@ namespace tvm { namespace runtime { +enum DeviceAttrKind : int { + kExist = 0, + kMaxThreadsPerBlock = 1, + kWarpSize = 2 +}; + class DeviceAPI { public: /*! \brief virtual destructor */ virtual ~DeviceAPI() {} + /*! + * \brief Set the environment device id to dev_id + * \param dev_id The device id. + * \return The allocated device pointer + */ + virtual void SetDevice(int dev_id) = 0; + /*! + * \brief Get attribute of specified device. + * \param dev_id The device id + * \param kind The result kind + * \param rv The return value. + */ + virtual void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) = 0; /*! * \brief Allocate a data space on device. * \param ctx The device context to perform operation. @@ -36,13 +55,18 @@ class DeviceAPI { * \brief copy data from one place to another * \param dev The device to perform operation. * \param from The source array. + * \param from_offset The byte offeset in the from. * \param to The target array. + * \param to_offset The byte offset in the to. * \param size The size of the memory * \param ctx_from The source context * \param ctx_to The target context + * \param stream Optional stream object. */ virtual void CopyDataFromTo(const void* from, + size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, TVMContext ctx_to, @@ -59,11 +83,12 @@ class DeviceAPI { * \brief The name of Device API factory. * \param type The device type. */ -inline std::string DeviceName(DLDeviceType type) { - switch (static_cast(type)) { +inline std::string DeviceName(int type) { + switch (type) { case kCPU: return "cpu"; case kGPU: return "gpu"; case kOpenCL: return "opencl"; + case kMetal: return "metal"; case kVPI: return "vpi"; default: LOG(FATAL) << "unknown type =" << type; return "Unknown"; } diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 69e6f6a5..6632bced 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -30,7 +30,6 @@ struct FunctionInfo { void Save(dmlc::JSONWriter *writer) const; void Load(dmlc::JSONReader *reader); }; - } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_META_DATA_H_ diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h new file mode 100644 index 00000000..742b5ad7 --- /dev/null +++ b/src/runtime/metal/metal_common.h @@ -0,0 +1,98 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file metal_common.h + * \brief Metal common header + */ +#ifndef TVM_RUNTIME_METAL_METAL_COMMON_H_ +#define TVM_RUNTIME_METAL_METAL_COMMON_H_ + +#import +#import +#import +#import +#import +#import + +#include +#include +#include +#include +#include +#include +#include +#include "../device_api.h" + +namespace tvm { +namespace runtime { +namespace metal { +/*! + * \brief Process global Metal workspace. + */ +class MetalWorkspace final : public DeviceAPI { + public: + // the devices + std::vector > devices; + // the queues + std::vector > queues; + // Warp size constant + std::vector warp_size; + // Whether it is initialized. + bool initialized_{false}; + // the mutex for initialization + std::mutex mutex; + // Get command queue for given context. + id GetCommandQueue(TVMContext ctx) { + CHECK_EQ(ctx.device_type, kMetal); + CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < queues.size()) + << "Invalid Metal device_id=" << ctx.device_id; + return queues[ctx.device_id]; + } + // Get device for given context + id GetDevice(TVMContext ctx) { + CHECK_EQ(ctx.device_type, kMetal); + CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < devices.size()) + << "Invalid Metal device_id=" << ctx.device_id; + return devices[ctx.device_id]; + } + // Initialize workspace + // Return false if already initialized, otherwise return true. + void Init(); + // override device API + void SetDevice(int dev_id) final; + void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final; + void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final; + void FreeDataSpace(TVMContext ctx, void* ptr) final; + void CopyDataFromTo(const void* from, + size_t from_size, + void* to, + size_t to_size, + size_t size, + TVMContext ctx_from, + TVMContext ctx_to, + TVMStreamHandle stream) final; + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; + // get the global workspace + static MetalWorkspace* Global(); +}; + +/*! \brief Thread local workspace */ +class MetalThreadEntry { + public: + /*! \brief The current context */ + TVMContext context; + /*! \brief The shared buffer used for copy. */ + std::vector > temp_buffer_; + + MetalThreadEntry() { + context.device_id = 0; + context.device_type = static_cast(kMetal); + } + // Get temp buffer with at least size under ctx. + id GetTempBuffer(TVMContext ctx, size_t size); + // get the global workspace + static MetalThreadEntry* ThreadLocal(); +}; +} // namespace metal +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_METAL_METAL_COMMON_H_ diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm new file mode 100644 index 00000000..701c3944 --- /dev/null +++ b/src/runtime/metal/metal_device_api.mm @@ -0,0 +1,240 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file metal_device_api.mm + */ +#include "./metal_common.h" + +#if TVM_METAL_RUNTIME +#include +#include + +namespace tvm { +namespace runtime { +namespace metal { + +MetalWorkspace* MetalWorkspace::Global() { + static MetalWorkspace inst; + return &inst; +} + +void MetalWorkspace::GetAttr( + int dev_id, DeviceAttrKind kind, TVMRetValue* rv) { + this->Init(); + size_t index = static_cast(dev_id); + if (kind == kExist) { + *rv = int(index< devices.size()); + return; + } + CHECK_LT(index, devices.size()) + << "Invalid device id " << index; + switch (kind) { + case kMaxThreadsPerBlock: { + *rv = static_cast( + [devices[dev_id] maxThreadsPerThreadgroup].width); + break; + } + case kWarpSize: { + // Set warp size to be 1 for safty reason. + *rv = 1; + break; + } + case kExist: break; + } +} + +static const char* kDummyKernel = R"A0B0( +using namespace metal; +// Simple copy kernel +// Just to get threadExecutionWidth from current Metal API. +kernel void CopyKernel( + device float* dst [[buffer(0)]], + device float* src [[buffer(1)]], + ushort2 gid[[thread_position_in_grid]]) { + dst[gid.x] = src[gid.x]; +} +)A0B0"; + +// Hack to get Warp size from device. +// Note that in Metal +// state.threadExecutionWidth can vary per kernel +// maybe due to resource constraint. +// so state.threadExecutionWidth can be smaller than warp size +// For safe issue, turn off warp-aware optimization for now +// But we keep this code. +int GetWarpSize(id dev) { + NSError* error_msg = nil; + id lib = + [dev + newLibraryWithSource: + [NSString stringWithUTF8String:kDummyKernel] + options:nil + error:&error_msg]; + CHECK(lib != nil) << error_msg; + id f = + [lib + newFunctionWithName: + [NSString stringWithUTF8String:"CopyKernel"]]; + CHECK(f!= nil); + id state = + [dev + newComputePipelineStateWithFunction:f + error:&error_msg]; + CHECK(state != nil) << error_msg; + return state.threadExecutionWidth; +} + +void MetalWorkspace::Init() { + if (initialized_) return; + std::lock_guard(this->mutex); + if (initialized_) return; + initialized_ = true; + if (devices.size() != 0) return; + NSArray>* devs = MTLCopyAllDevices(); + for (size_t i = 0; i < devs.count; ++i) { + id d = [devs objectAtIndex:i]; + devices.push_back(d); + queues.push_back([d newCommandQueue]); + LOG(INFO) << "Intializing Metal device " << i + << ", name=" << d.name; + warp_size.push_back(GetWarpSize(d)); + } +} + +void MetalWorkspace::SetDevice(int dev_id) { + MetalThreadEntry::ThreadLocal()->context.device_id = dev_id; +} + +void* MetalWorkspace::AllocDataSpace( + TVMContext ctx, size_t size, size_t alignment) { + this->Init(); + id dev = GetDevice(ctx); + // allocate buffer in GPU only mode. + id buf = [ + dev newBufferWithLength:size + options:MTLResourceStorageModePrivate]; + // retain ARC to keep it alive before release. + return (__bridge_retained void*)(buf); +} + +void MetalWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) { + // release the ptr. + CFBridgingRelease(ptr); +} + +void MetalWorkspace::CopyDataFromTo(const void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t size, + TVMContext ctx_from, + TVMContext ctx_to, + TVMStreamHandle stream) { + this->Init(); + CHECK(stream == nullptr); + TVMContext ctx = ctx_from; + if (ctx_from.device_type == kCPU) ctx = ctx_to; + id queue = GetCommandQueue(ctx); + id cb = [queue commandBuffer]; + id encoder = [cb blitCommandEncoder]; + int from_dev_type = static_cast(ctx_from.device_type); + int to_dev_type = static_cast(ctx_to.device_type); + + if (from_dev_type == kMetal && to_dev_type == kMetal) { + CHECK_EQ(ctx_from.device_id, ctx_to.device_id) + << "Metal disallow cross device copy."; + [encoder copyFromBuffer:(__bridge id)(from) + sourceOffset:from_offset + toBuffer:(__bridge id)(to) + destinationOffset:to_offset + size:size]; + [encoder endEncoding]; + [cb commit]; + } else if (from_dev_type == kMetal && to_dev_type == kCPU) { + // copy to a local buffer before get into global buffer. + id from_buf = (__bridge id)(from); + if (from_buf.storageMode != MTLStorageModeShared) { + id temp = MetalThreadEntry::ThreadLocal() + ->GetTempBuffer(ctx_from, size); + [encoder copyFromBuffer:from_buf + sourceOffset:from_offset + toBuffer:temp + destinationOffset:0 + size:size]; + [encoder endEncoding]; + [cb commit]; + [cb waitUntilCompleted]; + memcpy(static_cast(to) + to_offset, + static_cast([temp contents]), + size); + } else { + memcpy(static_cast(to) + to_offset, + static_cast([from_buf contents]) + from_offset, + size); + } + } else if (from_dev_type == kCPU && to_dev_type == kMetal) { + id to_buf = (__bridge id)(to); + if (to_buf.storageMode == MTLStorageModeShared) { + id temp = MetalThreadEntry::ThreadLocal() + ->GetTempBuffer(ctx_to, size); + memcpy([temp contents], + static_cast(from) + from_offset, + size); + [encoder copyFromBuffer:temp + sourceOffset:0 + toBuffer:to_buf + destinationOffset:to_offset + size:size]; + [encoder endEncoding]; + [cb commit]; + } else { + memcpy(static_cast([to_buf contents]) + to_offset, + static_cast(from) + from_offset, + size); + } + } else { + LOG(FATAL) << "Expect copy from/to Metal or between Metal" + << ", from=" << from_dev_type + << ", to=" << to_dev_type; + } +} + +void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) { + CHECK(stream == nullptr); + // commit an empty command buffer and wait until it completes. + id queue = GetCommandQueue(ctx); + id cb = [queue commandBuffer]; + [cb commit]; + [cb waitUntilCompleted]; +} + +id MetalThreadEntry::GetTempBuffer(TVMContext ctx, size_t size) { + if (temp_buffer_.size() <= static_cast(ctx.device_id)) { + temp_buffer_.resize(ctx.device_id + 1, nil); + } + if (temp_buffer_[ctx.device_id] == nil || + temp_buffer_[ctx.device_id].length < size) { + id dev = MetalWorkspace::Global()->GetDevice(ctx); + temp_buffer_[ctx.device_id] = [ + dev newBufferWithLength:size + options:MTLStorageModeShared]; + } + return temp_buffer_[ctx.device_id]; +} + +typedef dmlc::ThreadLocalStore MetalThreadStore; + +MetalThreadEntry* MetalThreadEntry::ThreadLocal() { + return MetalThreadStore::Get(); +} + +TVM_REGISTER_GLOBAL("device_api.metal") +.set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = MetalWorkspace::Global(); + *rv = static_cast(ptr); + }); + +} // namespace metal +} // namespace runtime +} // namespace tvm + +#endif // TVM_METAL_RUNTIME diff --git a/src/runtime/metal/metal_module.h b/src/runtime/metal/metal_module.h new file mode 100644 index 00000000..bb2f9c86 --- /dev/null +++ b/src/runtime/metal/metal_module.h @@ -0,0 +1,36 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file metal_module.h + * \brief Execution handling of Metal kernels + */ +#ifndef TVM_RUNTIME_METAL_METAL_MODULE_H_ +#define TVM_RUNTIME_METAL_METAL_MODULE_H_ + +#include +#include +#include +#include +#include +#include "../meta_data.h" + +namespace tvm { +namespace runtime { +/*! \brief Maximum number of GPU supported in MetalModule. */ +static constexpr const int kMetalMaxNumDevice = 32; + +/*! + * \brief create a metal module from data. + * + * \param data The data content. + * \param fmt The format of the data, can be "metal" or "metallib" + * \param fmap The map function information map of each function. + * \param source Optional, source file + */ +Module MetalModuleCreate( + std::string data, + std::string fmt, + std::unordered_map fmap, + std::string source); +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_METAL_METAL_MODULE_H_ diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm new file mode 100644 index 00000000..e2b39bdd --- /dev/null +++ b/src/runtime/metal/metal_module.mm @@ -0,0 +1,273 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file metal_module.cc + */ +#include "./metal_module.h" + +#if TVM_METAL_RUNTIME + +#include +#include +#include +#include +#include +#include "./metal_common.h" +#include "../pack_args.h" +#include "../thread_storage_scope.h" +#include "../meta_data.h" +#include "../file_util.h" + +namespace tvm { +namespace runtime { + +// Module to support thread-safe multi-GPU execution. +// cuModule is a per-GPU module +// The runtime will contain a per-device module table +// The modules will be lazily loaded +class MetalModuleNode final :public runtime::ModuleNode { + public: + explicit MetalModuleNode(std::string data, + std::string fmt, + std::unordered_map fmap, + std::string source) + : data_(data), fmt_(fmt), fmap_(fmap), source_(source) { + } + const char* type_key() const final { + return "metal"; + } + + void PreCompile(const std::string& name, TVMContext ctx) final { + GetPipelineState(ctx.device_id, name); + } + + PackedFunc GetFunction( + const std::string& name, + const std::shared_ptr& sptr_to_self) final; + + void SaveToFile(const std::string& file_name, + const std::string& format) final { + std::string fmt = GetFileFormat(file_name, format); + CHECK_EQ(fmt, fmt_) + << "Can only save to format=" << fmt_; + std::string meta_file = GetMetaFilePath(file_name); + SaveMetaDataToFile(meta_file, fmap_); + SaveBinaryToFile(file_name, data_); + } + + std::string GetSource(const std::string& format) final { + if (format == fmt_) return data_; + if (source_.length() != 0) { + return source_; + } else if (fmt_ == "metal") { + return data_; + } else { + return ""; + } + } + // get a CUfunction from primary context in device_id + id GetPipelineState( + size_t device_id, const std::string& func_name) { + metal::MetalWorkspace* w = metal::MetalWorkspace::Global(); + CHECK_LT(device_id, w->devices.size()); + // start lock scope. + std::lock_guard lock(mutex_); + if (finfo_.size() <= device_id) { + finfo_.resize(device_id + 1, DeviceEntry()); + } + DeviceEntry& e = finfo_[device_id]; + auto it = e.smap.find(func_name); + if (it != e.smap.end()) return it->second; + // compile + NSError* err_msg = nil; + if (e.lib == nil) { + if (fmt_ == "metal") { + e.lib = [ + w->devices[device_id] + newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()] + options:nil + error:&err_msg]; + if (err_msg != nil || e.lib == nil) { + LOG(FATAL) << "Fail to compile metal lib:" + << [[err_msg localizedDescription] UTF8String]; + } + } else { + // Build from library. + auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL); + auto data = dispatch_data_create( + data_.c_str(), data_.length(), q, ^{}); + e.lib = [ + w->devices[device_id] + newLibraryWithData:data + error:&err_msg]; + if (err_msg != nil || e.lib == nil) { + LOG(FATAL) << "Fail to compile metal lib:" + << [[err_msg localizedDescription] UTF8String]; + } + } + } + id f = [ + e.lib + newFunctionWithName: + [NSString stringWithUTF8String:func_name.c_str()]]; + CHECK(f != nil) << "cannot find function " << func_name; + id state = + [w->devices[device_id] + newComputePipelineStateWithFunction:f + error:&err_msg]; + CHECK(state != nil) + << "cannot get state:" << " for function " << func_name + << [[err_msg localizedDescription] UTF8String]; + // The state.threadExecutionWidth can change dynamically according + // to the resource constraint in kernel, so it is not strictly hold + // Turn of warp aware optimziation for now. + // CHECK_EQ(state.threadExecutionWidth, w->warp_size[device_id]); + e.smap[func_name] = state; + return state; + } + + private: + // device specific entry + struct DeviceEntry { + // library + id lib = nil; + // state cache; + std::unordered_map > smap; + }; + // the binary data + std::string data_; + // The format + std::string fmt_; + // function information table. + std::unordered_map fmap_; + // The source + std::string source_; + // function information. + std::vector finfo_; + // internal mutex when updating the module + std::mutex mutex_; +}; + +// a wrapped function class to get packed fucn. +class MetalWrappedFunc { + public: + // initialize the METAL function. + void Init(MetalModuleNode* m, + std::shared_ptr sptr, + const std::string& func_name, + size_t num_buffer_args, + size_t num_pack_args, + const std::vector& thread_axis_tags) { + w_ = metal::MetalWorkspace::Global(); + m_ = m; + sptr_ = sptr; + func_name_ = func_name; + num_buffer_args_ = num_buffer_args; + num_pack_args_ = num_pack_args; + std::fill(scache_.begin(), scache_.end(), (id)nil); + thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); + metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); + int dev_id = t->context.device_id; + scache_[dev_id] = m->GetPipelineState(dev_id, func_name); + } + // invoke the function with void arguments + void operator()(TVMArgs args, + TVMRetValue* rv, + const ArgUnion* pack_args) const { + metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); + int device_id = t->context.device_id; + if (scache_[device_id] == nil) { + scache_[device_id] = m_->GetPipelineState(device_id, func_name_); + } + ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + id queue = w_->GetCommandQueue(t->context); + id cb = [queue commandBuffer]; + id encoder = [cb computeCommandEncoder]; + [encoder setComputePipelineState:scache_[device_id]]; + for (size_t i = 0; i < num_buffer_args_; ++i) { + void* buf = args[i]; + [encoder setBuffer:(__bridge id)(buf) offset:0 atIndex:i]; + } + if (num_pack_args_ != 0) { + [encoder setBytes:pack_args + length:num_pack_args_ * sizeof(ArgUnion) + atIndex:num_buffer_args_]; + } + // launch + MTLSize dimGrid = MTLSizeMake( + wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); + MTLSize dimBlock = MTLSizeMake( + wl.block_dim(0), wl.block_dim(1), wl.work_size[2]); + [encoder dispatchThreadgroups: dimGrid + threadsPerThreadgroup: dimBlock]; + [encoder endEncoding]; + [cb commit]; + } + + private: + // Reference to global workspace. + metal::MetalWorkspace* w_; + // internal module + MetalModuleNode* m_; + // the resource holder + std::shared_ptr sptr_; + // The name of the function. + std::string func_name_; + // Number of buffer arguments + size_t num_buffer_args_; + // number of packed arguments. + size_t num_pack_args_; + // Device state cache per device. + // mark as mutable, to enable lazy initialization + mutable std::array, kMetalMaxNumDevice> scache_; + // thread axis configuration + ThreadAxisConfig thread_axis_cfg_; +}; + +PackedFunc MetalModuleNode::GetFunction( + const std::string& name, + const std::shared_ptr& sptr_to_self) { + CHECK_EQ(sptr_to_self.get(), this); + CHECK_NE(name, symbol::tvm_module_main) + << "Device function do not have main"; + auto it = fmap_.find(name); + if (it == fmap_.end()) return PackedFunc(); + const FunctionInfo& info = it->second; + MetalWrappedFunc f; + size_t num_buffer_args = NumBufferArgs(info.arg_types); + f.Init(this, sptr_to_self, name, + num_buffer_args, info.arg_types.size() - num_buffer_args, + info.thread_axis_tags); + return PackFuncNonBufferArg(f, info.arg_types); +} + +Module MetalModuleCreate( + std::string data, + std::string fmt, + std::unordered_map fmap, + std::string source) { + metal::MetalWorkspace* w = metal::MetalWorkspace::Global(); + w->Init(); + std::shared_ptr n = + std::make_shared(data, fmt, fmap, source); + return Module(n); +} + +// Load module from module. +Module MetalModuleLoad(const std::string& file_name, + const std::string& format) { + std::string data; + std::unordered_map fmap; + std::string fmt = GetFileFormat(file_name, format); + std::string meta_file = GetMetaFilePath(file_name); + LoadBinaryFromFile(file_name, &data); + LoadMetaDataFromFile(meta_file, &fmap); + return MetalModuleCreate(data, fmt, fmap, ""); +} + +TVM_REGISTER_GLOBAL("module.loadfile_metal") +.set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = MetalModuleLoad(args[0], args[1]); + }); +} // namespace runtime +} // namespace tvm +#endif // TVM_METAL_RUNTIME diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 8764bf02..b0827b73 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -84,17 +84,25 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) { } bool RuntimeEnabled(const std::string& target) { - std::string load_f_name; + std::string f_name; if (target == "cpu") { return true; } else if (target == "cuda" || target == "gpu") { - load_f_name = "module.loadfile_ptx"; + f_name = "device_api.gpu"; } else if (target == "cl" || target == "opencl") { - load_f_name = "module.loadfile_cl"; + f_name = "device_api.opencl"; + } else if (target == "mtl" || target == "metal") { + f_name = "device_api.metal"; + } else if (target == "stackvm") { + f_name = "codegen.build_stackvm"; + } else if (target == "llvm") { + f_name = "codegen.build_llvm"; + } else if (target == "vpi" || target == "verilog") { + f_name = "device_api.vpi"; } else { LOG(FATAL) << "Unknown optional runtime " << target; } - return runtime::Registry::Get(load_f_name) != nullptr; + return runtime::Registry::Get(f_name) != nullptr; } TVM_REGISTER_GLOBAL("module._Enabled") diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 364a96c6..35c4f3a2 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -101,12 +101,14 @@ inline const char* CLGetErrorString(cl_int error) { /*! * \brief Process global OpenCL workspace. */ -class OpenCLWorkspace : public DeviceAPI { +class OpenCLWorkspace final : public DeviceAPI { public: // global platform id cl_platform_id platform_id; // global context of this process cl_context context{nullptr}; + // whether the workspace it initialized. + bool initialized_{false}; // the devices std::vector devices; // the queues @@ -126,24 +128,25 @@ class OpenCLWorkspace : public DeviceAPI { OPENCL_CALL(clReleaseContext(context)); } } - // whether the workspace is initialized. - inline bool initialized() const { - return context != nullptr; - } + // Initialzie the device. + void Init(); // get the queue of the context - cl_command_queue GetQueue(TVMContext ctx) const { + cl_command_queue GetQueue(TVMContext ctx) { CHECK_EQ(ctx.device_type, kOpenCL); - CHECK(initialized()) - << "The OpenCL is not initialized"; + this->Init(); CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < queues.size()) << "Invalid OpenCL device_id=" << ctx.device_id; return queues[ctx.device_id]; } // override device API + void SetDevice(int dev_id) final; + void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final; void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final; void FreeDataSpace(TVMContext ctx, void* ptr) final; void CopyDataFromTo(const void* from, + size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, TVMContext ctx_to, diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index cd961286..b0543b48 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -18,8 +18,41 @@ OpenCLWorkspace* OpenCLWorkspace::Global() { return &inst; } +void OpenCLWorkspace::SetDevice(int dev_id) { + OpenCLThreadEntry::ThreadLocal()->context.device_id = dev_id; +} + +void OpenCLWorkspace::GetAttr( + int dev_id, DeviceAttrKind kind, TVMRetValue* rv) { + this->Init(); + size_t index = static_cast(dev_id); + if (kind == kExist) { + *rv = static_cast(index< devices.size()); + return; + } + CHECK_LT(index, devices.size()) + << "Invalid device id " << index; + size_t value; + switch (kind) { + case kMaxThreadsPerBlock: { + OPENCL_CALL(clGetDeviceInfo( + devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE, + sizeof(size_t), &value, nullptr)); + *rv = static_cast(value); + break; + } + case kWarpSize: { + *rv = 1; + break; + } + case kExist: break; + } +} + void* OpenCLWorkspace::AllocDataSpace( TVMContext ctx, size_t size, size_t alignment) { + this->Init(); + CHECK(context != nullptr) << "No OpenCL device"; cl_int err_code; cl_mem mptr = clCreateBuffer( this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code); @@ -33,30 +66,35 @@ void OpenCLWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) { } void OpenCLWorkspace::CopyDataFromTo(const void* from, + size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, TVMContext ctx_to, TVMStreamHandle stream) { + this->Init(); CHECK(stream == nullptr); if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kOpenCL) { OPENCL_CALL(clEnqueueCopyBuffer( this->GetQueue(ctx_to), static_cast((void*)from), // NOLINT(*) static_cast(to), - 0, 0, size, 0, nullptr, nullptr)); + from_offset, to_offset, size, 0, nullptr, nullptr)); } else if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kCPU) { OPENCL_CALL(clEnqueueReadBuffer( this->GetQueue(ctx_from), static_cast((void*)from), // NOLINT(*) - CL_FALSE, 0, size, to, + CL_FALSE, from_offset, size, + static_cast(to) + to_offset, 0, nullptr, nullptr)); OPENCL_CALL(clFinish(this->GetQueue(ctx_from))); } else if (ctx_from.device_type == kCPU && ctx_to.device_type == kOpenCL) { OPENCL_CALL(clEnqueueWriteBuffer( this->GetQueue(ctx_to), static_cast(to), - CL_FALSE, 0, size, from, + CL_FALSE, to_offset, size, + static_cast(from) + from_offset, 0, nullptr, nullptr)); OPENCL_CALL(clFinish(this->GetQueue(ctx_to))); } else { @@ -97,8 +135,9 @@ std::string GetDeviceInfo( std::vector GetPlatformIDs() { cl_uint ret_size; - OPENCL_CALL(clGetPlatformIDs(0, nullptr, &ret_size)); + cl_int code = clGetPlatformIDs(0, nullptr, &ret_size); std::vector ret; + if (code != CL_SUCCESS) return ret; ret.resize(ret_size); OPENCL_CALL(clGetPlatformIDs(ret_size, &ret[0], nullptr)); return ret; @@ -108,11 +147,12 @@ std::vector GetDeviceIDs( cl_platform_id pid, std::string device_type) { cl_device_type dtype = CL_DEVICE_TYPE_ALL; if (device_type == "cpu") dtype = CL_DEVICE_TYPE_CPU; - if (device_type == "gpu") dtype = CL_DEVICE_TYPE_CPU; + if (device_type == "gpu") dtype = CL_DEVICE_TYPE_GPU; if (device_type == "accelerator") dtype = CL_DEVICE_TYPE_ACCELERATOR; cl_uint ret_size; - OPENCL_CALL(clGetDeviceIDs(pid, dtype, 0, nullptr, &ret_size)); + cl_int code = clGetDeviceIDs(pid, dtype, 0, nullptr, &ret_size); std::vector ret; + if (code != CL_SUCCESS) return ret; ret.resize(ret_size); OPENCL_CALL(clGetDeviceIDs(pid, dtype, ret_size, &ret[0], nullptr)); return ret; @@ -127,70 +167,53 @@ bool MatchPlatformInfo( return param_value.find(value) != std::string::npos; } -bool InitOpenCL(TVMArgs args, TVMRetValue* rv) { - cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global(); - std::lock_guard(w->mu); - if (w->initialized()) return false; - // matching conditions - std::string platform_name, device_type; - - for (int i = 0; i < args.num_args; ++i) { - std::string arg = args[i]; - size_t pos = arg.find_first_of('='); - CHECK_EQ(pos, std::string::npos) - << "Argumentes need to be key=value"; - std::string key = arg.substr(0, pos); - std::string val = arg.substr(pos + 1, arg.length() - pos - 1); - if (key == "platform_name") { - platform_name = val; - } else if (key == "device_type") { - device_type = val; - } else { - LOG(FATAL) << "unknown DeviceInit option " << key; - } - } +void OpenCLWorkspace::Init() { + if (initialized_) return; + std::lock_guard(this->mu); + if (initialized_) return; + initialized_ = true; + if (context != nullptr) return; // matched platforms - std::vector platform_matched; - for (cl_platform_id pid : cl::GetPlatformIDs()) { - bool matched = true; - if (!cl::MatchPlatformInfo(pid, CL_PLATFORM_NAME, platform_name)) matched = false; - if (matched) platform_matched.push_back(pid); - } + std::vector platform_matched = cl::GetPlatformIDs(); if (platform_matched.size() == 0) { - LOG(FATAL) << "No OpenCL platform matched given existing options ..."; + LOG(WARNING) << "No OpenCL platform matched given existing options ..."; + return; } if (platform_matched.size() > 1) { LOG(WARNING) << "Multiple OpenCL platforms matched, use the first one ... "; } - w->platform_id = platform_matched[0]; - + this->platform_id = platform_matched[0]; LOG(INFO) << "Initialize OpenCL platform \'" - << cl::GetPlatformInfo(w->platform_id, CL_PLATFORM_NAME) << '\''; + << cl::GetPlatformInfo(this->platform_id, CL_PLATFORM_NAME) << '\''; std::vector devices_matched = - cl::GetDeviceIDs(w->platform_id, device_type); - CHECK_GT(devices_matched.size(), 0U) - << "No OpenCL device any device matched given the options"; - w->devices = devices_matched; + cl::GetDeviceIDs(this->platform_id, "gpu"); + if (devices_matched.size() == 0) { + LOG(WARNING) << "No OpenCL device any device matched given the options"; + return; + } + this->devices = devices_matched; cl_int err_code; - w->context = clCreateContext( - nullptr, w->devices.size(), &(w->devices[0]), + this->context = clCreateContext( + nullptr, this->devices.size(), &(this->devices[0]), nullptr, nullptr, &err_code); OPENCL_CHECK_ERROR(err_code); - CHECK_EQ(w->queues.size(), 0U); - for (size_t i = 0; i < w->devices.size(); ++i) { - cl_device_id did = w->devices[i]; - w->queues.push_back( - clCreateCommandQueue(w->context, did, 0, &err_code)); + CHECK_EQ(this->queues.size(), 0U); + for (size_t i = 0; i < this->devices.size(); ++i) { + cl_device_id did = this->devices[i]; + this->queues.push_back( + clCreateCommandQueue(this->context, did, 0, &err_code)); OPENCL_CHECK_ERROR(err_code); LOG(INFO) << "opencl(" << i << ")=\'" << cl::GetDeviceInfo(did, CL_DEVICE_NAME) << "\' cl_device_id=" << did; } - return true; } -TVM_REGISTER_GLOBAL("module.init_opencl") -.set_body(InitOpenCL); +bool InitOpenCL(TVMArgs args, TVMRetValue* rv) { + cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global(); + w->Init(); + return true; +} TVM_REGISTER_GLOBAL("device_api.opencl") .set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 89293c50..72c9550e 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -11,7 +11,7 @@ #include #include #include -#include "../void_addr_args.h" +#include "../pack_args.h" #include "../thread_storage_scope.h" #include "../meta_data.h" #include "../file_util.h" @@ -90,7 +90,8 @@ class OpenCLModuleNode : public ModuleNode { // Initialize the programs void InitProgram() { cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global(); - CHECK(w->initialized()); + w->Init(); + CHECK(w->context != nullptr) << "No OpenCL device"; if (fmt_ == "cl") { const char* s = data_.c_str(); size_t len = data_.length(); @@ -179,6 +180,7 @@ class OpenCLWrappedFunc { std::string func_name, std::vector arg_size, const std::vector& thread_axis_tags) { + w_ = cl::OpenCLWorkspace::Global(); m_ = m; sptr_ = sptr; entry_ = entry; @@ -190,9 +192,7 @@ class OpenCLWrappedFunc { void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { - cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global(); cl::OpenCLThreadEntry* t = cl::OpenCLThreadEntry::ThreadLocal(); - CHECK(w->initialized()); // get the kernel from thread local kernel table. if (entry_.kernel_id >= t->kernel_table.size()) { t->kernel_table.resize(entry_.kernel_id + 1); @@ -200,13 +200,13 @@ class OpenCLWrappedFunc { const auto& e = t->kernel_table[entry_.kernel_id]; cl_kernel kernel = e.kernel; if (kernel == nullptr || e.version != entry_.version) { - kernel = m_->InstallKernel(w, t, func_name_, entry_); + kernel = m_->InstallKernel(w_, t, func_name_, entry_); } // setup arguments. for (cl_uint i = 0; i < arg_size_.size(); ++i) { OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], void_args[i])); } - cl_command_queue queue = w->GetQueue(t->context); + cl_command_queue queue = w_->GetQueue(t->context); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); cl_uint work_dim = static_cast(thread_axis_cfg_.work_dim()); for (cl_uint i = 0; i < work_dim; ++i) { @@ -221,6 +221,8 @@ class OpenCLWrappedFunc { } private: + // global workspace. + cl::OpenCLWorkspace* w_; // The module OpenCLModuleNode* m_; // resource handle @@ -235,45 +237,12 @@ class OpenCLWrappedFunc { ThreadAxisConfig thread_axis_cfg_; }; -/*! - * \brief Automatically detect and set cuda device. - * \param args The arguments. - */ -void AutoSetOpenCLDevice(const TVMArgs& args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 3); - TVMValue* values = static_cast(args[0].operator void*()); - int* type_codes = static_cast(args[1].operator void*()); - int num_args = args[2].operator int(); - - // TODO(tqchen): merge this with CUDA logic. - int device_id = -1; - for (int i = 0; i < num_args; ++i) { - if (type_codes[i] == kArrayHandle) { - TVMContext ctx = static_cast(values[i].v_handle)->ctx; - CHECK_EQ(ctx.device_type, kOpenCL) - << "All operands need to be OpenCL"; - if (device_id == -1) { - device_id = ctx.device_id; - } else { - CHECK_EQ(device_id, ctx.device_id) - << "Operands comes from different devices "; - } - } - } - CHECK_NE(device_id, -1) - << "Cannot detect device id from list"; - cl::OpenCLThreadEntry::ThreadLocal()->context.device_id = device_id; -} - PackedFunc OpenCLModuleNode::GetFunction( const std::string& name, const std::shared_ptr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; - if (name == symbol::tvm_entry_setdevice) { - return PackedFunc(AutoSetOpenCLDevice); - } auto it = fmap_.find(name); if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; @@ -289,7 +258,7 @@ PackedFunc OpenCLModuleNode::GetFunction( // initialize the wrapped func. f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.thread_axis_tags); - return PackFromVoidAddrArgs(f, info.arg_types); + return PackFuncVoidAddr(f, info.arg_types); } Module OpenCLModuleCreate( diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 8acced9f..85c50e3e 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -18,7 +18,7 @@ namespace runtime { /*! * \brief create a cuda module from data. * - * \param data The module data, can be ptx, cubin + * \param data The module data. * \param fmt The format of the data, can be "clbin", "cl" * \param fmap The map function information map of each function. */ diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h new file mode 100644 index 00000000..7f6fab6f --- /dev/null +++ b/src/runtime/pack_args.h @@ -0,0 +1,233 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file pack_args.h + * \brief Utility to pack TVMArgs to other type-erased fution calling convention. + * + * Two type erased function signatures are supported. + * - cuda_style(void** args, int num_args); + * - Pack everything by address + * - metal_style(void** buffers, int num_buffers, + * union_32bit args[N], int num_args); + * - Pack buffer by address, pack rest parameter into 32bit union buffer. + */ +#ifndef TVM_RUNTIME_PACK_ARGS_H_ +#define TVM_RUNTIME_PACK_ARGS_H_ + +#include +#include + +namespace tvm { +namespace runtime { +/*! + * \brief argument union type of 32bit. + * Choose 32 bit because most GPU API do not work well with 64 bit. + */ +union ArgUnion { + int32_t v_int32; + uint32_t v_uint32; + float v_float32; +}; +/*! + * \brief Create a packed function from void addr types. + * + * \param f with signiture (TVMArgs args, TVMRetValue* rv, void* void_args) + * \param arg_types The arguments that wish to get from + * \tparam T the function type + * + * \return The wrapped packed function. + */ +template +inline PackedFunc PackFuncVoidAddr(F f, const std::vector& arg_types); +/*! + * \brief Create a packed function that from function only packs buffer arguments. + * + * \param f with signiture (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args) + * \param arg_types The arguments that wish to get from + * \tparam T the function type + * + * \return The wrapped packed function. + */ +template +inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_types); +/*! + * \brief Extract number of buffer argument from the argument types. + * \param arg_types The argument types. + * \return number of buffer arguments + */ +inline size_t NumBufferArgs(const std::vector& arg_types); + +// implementations details +namespace detail { +template +class TempArray { + public: + explicit TempArray(int size) {} + T* data() { + return data_; + } + private: + T data_[kSize]; +}; +template +class TempArray { + public: + explicit TempArray(int size) : data_(size) {} + T* data() { + return data_.data(); + } + private: + std::vector data_; +}; + +/*! \brief conversion code used in void arg. */ +enum ArgConvertCode { + INT64_TO_INT64, + INT64_TO_INT32, + INT64_TO_UINT32, + FLOAT64_TO_FLOAT32, + FLOAT64_TO_FLOAT64, + HANDLE_TO_HANDLE +}; + +inline ArgConvertCode GetArgConvertCode(TVMType t) { + CHECK_EQ(t.lanes, 1U) + << "Cannot pass vector type argument to devic function for now"; + if (t.code == kInt) { + if (t.bits == 64U) return INT64_TO_INT64; + if (t.bits == 32U) return INT64_TO_INT32; + } else if (t.code == kUInt) { + if (t.bits == 32U) return INT64_TO_UINT32; + } else if (t.code == kFloat) { + if (t.bits == 64U) return FLOAT64_TO_FLOAT64; + if (t.bits == 32U) return FLOAT64_TO_FLOAT32; + } else if (t.code == kHandle) { + return HANDLE_TO_HANDLE; + } + LOG(FATAL) << "Cannot handle " << t << " as device function argument"; + return HANDLE_TO_HANDLE; +} + +template +inline PackedFunc PackFuncVoidAddr_(F f, const std::vector& codes) { + int num_args = static_cast(codes.size()); + auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) { + TempArray addr_(num_args); + TempArray holder_(num_args); + void** addr = addr_.data(); + ArgUnion* holder = holder_.data(); + for (int i = 0; i < num_args; ++i) { + switch (codes[i]) { + case INT64_TO_INT64: + case FLOAT64_TO_FLOAT64: + case HANDLE_TO_HANDLE: { + addr[i] = (void*)&(args.values[i]); // NOLINT(*) + break; + } + case INT64_TO_INT32: { + holder[i].v_int32 = static_cast(args.values[i].v_int64); + addr[i] = &(holder[i]); + break; + } + case INT64_TO_UINT32 : { + holder[i].v_uint32 = static_cast(args.values[i].v_int64); + addr[i] = &(holder[i]); + break; + } + case FLOAT64_TO_FLOAT32: { + holder[i].v_float32 = static_cast(args.values[i].v_float64); + addr[i] = &(holder[i]); + break; + } + } + } + f(args, ret, addr); + }; + return PackedFunc(ret); +} + +template +inline PackedFunc PackFuncNonBufferArg_( + F f, int base, const std::vector& codes) { + int num_args = static_cast(codes.size()); + auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) { + TempArray holder_(num_args); + ArgUnion* holder = holder_.data(); + for (int i = 0; i < num_args; ++i) { + switch (codes[i]) { + case INT64_TO_INT64: + case FLOAT64_TO_FLOAT64: { + LOG(FATAL) << "Donot support 64bit argument to device function"; break; + } + case INT64_TO_INT32: { + holder[i].v_int32 = static_cast(args.values[base + i].v_int64); + break; + } + case INT64_TO_UINT32 : { + holder[i].v_uint32 = static_cast(args.values[base + i].v_int64); + break; + } + case FLOAT64_TO_FLOAT32: { + holder[i].v_float32 = static_cast(args.values[base + i].v_float64); + break; + } + case HANDLE_TO_HANDLE: { + LOG(FATAL) << "not reached"; break; + } + } + } + f(args, ret, holder); + }; + return PackedFunc(ret); +} +} // namespace detail + +template +inline PackedFunc PackFuncVoidAddr(F f, const std::vector& arg_types) { + std::vector codes(arg_types.size()); + for (size_t i = 0; i < arg_types.size(); ++i) { + codes[i] = detail::GetArgConvertCode(arg_types[i]); + } + size_t num_void_args = arg_types.size(); + // specialization + if (num_void_args <= 4) { + return detail::PackFuncVoidAddr_<4>(f, codes); + } else if (num_void_args <= 8) { + return detail::PackFuncVoidAddr_<8>(f, codes); + } else { + return detail::PackFuncVoidAddr_<0>(f, codes); + } +} + +inline size_t NumBufferArgs(const std::vector& arg_types) { + size_t base = arg_types.size(); + for (size_t i = 0; i < arg_types.size(); ++i) { + if (arg_types[i].code != kHandle) { + base = i; break; + } + } + for (size_t i = base; i < arg_types.size(); ++i) { + CHECK(arg_types[i].code != kHandle) + << "Device function need to be organized"; + } + return base; +} + +template +inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_types) { + size_t num_buffer = NumBufferArgs(arg_types); + std::vector codes; + for (size_t i = num_buffer; i < arg_types.size(); ++i) { + codes.push_back(detail::GetArgConvertCode(arg_types[i])); + } + int base = static_cast(num_buffer); + size_t nargs = codes.size(); + // specialization + if (nargs <= 4) { + return detail::PackFuncNonBufferArg_<4>(f, base, codes); + } else { + return detail::PackFuncNonBufferArg_<0>(f, base, codes); + } +} +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_PACK_ARGS_H_ diff --git a/src/runtime/void_addr_args.h b/src/runtime/void_addr_args.h deleted file mode 100644 index 6f627339..00000000 --- a/src/runtime/void_addr_args.h +++ /dev/null @@ -1,164 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file void_addr_args.h - * \brief Utility to convert TVMArgs to void* array type-erasure function call. - * - * Array of argument address is a typical way of type-erasure for functions. - * The function signiture looks like function(void** args, int num_args); - * Where args takes the address of each input. - */ -#ifndef TVM_RUNTIME_VOID_ADDR_ARGS_H_ -#define TVM_RUNTIME_VOID_ADDR_ARGS_H_ - -#include -#include - -namespace tvm { -namespace runtime { - -/*! - * \brief Create a packed function from void addr types - * \param f with signiture (TVMArgs args, TVMRetValue* rv, void* void_args) - * \param arg_types The arguments that wish to get from - * \tparam T the function type - * - * \return The wrapped packed function. - */ -template -inline PackedFunc PackFromVoidAddrArgs( - F f, const std::vector& arg_types); - -// implementations details -namespace detail { -/*! - * \brief void addr argument data content - * holder in case conversion is needed. - */ -union VoidArgHolder { - int32_t v_int32; - uint32_t v_uint32; - float v_float32; -}; - -template -class VoidAddrArray { - public: - explicit VoidAddrArray(int num_args) { - } - void** addr() { - return addr_; - } - VoidArgHolder* holder() { - return holder_; - } - - private: - void* addr_[MAX_NARG]; - VoidArgHolder holder_[MAX_NARG]; -}; - -template<> -class VoidAddrArray<0> { - public: - explicit VoidAddrArray(int num_args) - : addr_(num_args), holder_(num_args) { - } - void** addr() { - return addr_.data(); - } - VoidArgHolder* holder() { - return holder_.data(); - } - - private: - std::vector addr_; - std::vector holder_; -}; - -/*! \brief conversion code used in void arg. */ -enum VoidArgConvertCode { - INT64_TO_INT64, - INT64_TO_INT32, - INT64_TO_UINT32, - FLOAT64_TO_FLOAT32, - FLOAT64_TO_FLOAT64, - HANDLE_TO_HANDLE -}; - -template -inline PackedFunc PackFromVoidAddrArgs_( - F f, const std::vector& codes) { - int num_args = static_cast(codes.size()); - auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) { - VoidAddrArray temp(num_args); - void** addr = temp.addr(); - VoidArgHolder* holder = temp.holder(); - for (int i = 0; i < num_args; ++i) { - switch (codes[i]) { - case INT64_TO_INT64: - case FLOAT64_TO_FLOAT64: - case HANDLE_TO_HANDLE: { - addr[i] = (void*)&(args.values[i]); // NOLINT(*) - break; - } - case INT64_TO_INT32: { - holder[i].v_int32 = static_cast(args.values[i].v_int64); - addr[i] = &(holder[i]); - break; - } - case INT64_TO_UINT32 : { - holder[i].v_uint32 = static_cast(args.values[i].v_int64); - addr[i] = &(holder[i]); - break; - } - case FLOAT64_TO_FLOAT32: { - holder[i].v_float32 = static_cast(args.values[i].v_float64); - addr[i] = &(holder[i]); - break; - } - } - } - f(args, ret, addr); - }; - return PackedFunc(ret); -} - -inline VoidArgConvertCode GetVoidArgConvertCode(TVMType t) { - CHECK_EQ(t.lanes, 1U); - if (t.code == kInt) { - if (t.bits == 64U) return INT64_TO_INT64; - if (t.bits == 32U) return INT64_TO_INT32; - } else if (t.code == kUInt) { - if (t.bits == 32U) return INT64_TO_UINT32; - } else if (t.code == kFloat) { - if (t.bits == 64U) return FLOAT64_TO_FLOAT64; - if (t.bits == 32U) return FLOAT64_TO_FLOAT32; - } else if (t.code == kHandle) { - return HANDLE_TO_HANDLE; - } - LOG(FATAL) << "Cannot handle " << t; - return HANDLE_TO_HANDLE; -} - -} // namespace detail - -template -inline PackedFunc PackFromVoidAddrArgs( - F f, const std::vector& arg_types) { - std::vector codes(arg_types.size()); - for (size_t i = 0; i < arg_types.size(); ++i) { - codes[i] = detail::GetVoidArgConvertCode(arg_types[i]); - } - size_t num_void_args = arg_types.size(); - // specialization - if (num_void_args <= 4) { - return detail::PackFromVoidAddrArgs_<4>(f, codes); - } else if (num_void_args <= 8) { - return detail::PackFromVoidAddrArgs_<8>(f, codes); - } else { - return detail::PackFromVoidAddrArgs_<0>(f, codes); - } -} -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_VOID_ADDR_ARGS_H_ diff --git a/tests/python/integration/test_dot.py b/tests/python/integration/test_dot.py index ab0b32e2..5bd3e4ec 100644 --- a/tests/python/integration/test_dot.py +++ b/tests/python/integration/test_dot.py @@ -17,6 +17,7 @@ def lower(s, args, name="mydot"): stmt = tvm.ir_pass.CanonicalSimplify(stmt) stmt = tvm.ir_pass.Simplify(stmt) fapi = tvm.ir_pass.MakeAPI(stmt, name, arg_list, 0) + fapi = tvm.ir_pass.LowerPackedCall(fapi) return fapi @@ -35,7 +36,7 @@ def test_dot(): fapi = lower(s, [A, B, C]) def verify(target): - if not tvm.codegen.enabled(target): + if not tvm.module.enabled(target): print("Target %s is not enabled" % target) return f = tvm.codegen.build_module(fapi, target) diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py index 860d98eb..5e4c7bf3 100644 --- a/tests/python/integration/test_ewise.py +++ b/tests/python/integration/test_ewise.py @@ -1,5 +1,6 @@ import tvm import numpy as np +import time def test_exp(): # graph @@ -8,21 +9,21 @@ def test_exp(): B = tvm.compute(A.shape, lambda *i: tvm.exp(A(*i)), name='B') s = tvm.create_schedule(B.op) # create iter var and assign them tags. - num_thread = 64 + num_thread = 8 bx, tx = s[B].split(B.op.axis[0], factor=num_thread) s[B].bind(bx, tvm.thread_axis("blockIdx.x")) s[B].bind(tx, tvm.thread_axis("threadIdx.x")) # one line to build the function. def check_device(device, host="stackvm"): - if not tvm.codegen.enabled(host): + if not tvm.module.enabled(host): return - if not tvm.codegen.enabled(device): + if not tvm.module.enabled(device): return fexp = tvm.build(s, [A, B], device, host, name="myexp") - ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) + ctx = tvm.context(device, 0) # launch the kernel. n = 1024 a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx) @@ -31,8 +32,6 @@ def test_exp(): np.testing.assert_allclose( b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5) - if tvm.module.enabled("opencl"): - tvm.module.init_opencl() check_device("cuda", "llvm") check_device("opencl") @@ -46,7 +45,7 @@ def test_log_llvm(): # create iter var and assign them tags. bx, tx = s[B].split(B.op.axis[0], factor=32) # one line to build the function. - if not tvm.codegen.enabled("llvm"): + if not tvm.module.enabled("llvm"): return flog = tvm.build(s, [A, B], @@ -60,17 +59,26 @@ def test_log_llvm(): np.testing.assert_allclose( b.asnumpy(), np.log(a.asnumpy()), rtol=1e-5) +from tvm.contrib import nvcc_compiler + +@tvm.register_func +def tvm_callback_cuda_compile(code): + print(code) + ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_52"]) + return ptx def test_add(): # graph n = tvm.convert(1024) A = tvm.placeholder((n,), name='A') B = tvm.placeholder((n,), name='B') - C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + bias = tvm.var("bias", dtype="float32") + scale = tvm.var("scale", dtype="float32") + C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name='C') # schedule s = tvm.create_schedule(C.op) # create iter var and assign them tags. - num_thread = 256 + num_thread = 32 bx, x = s[C].split(C.op.axis[0], factor=num_thread*4) tx, x = s[C].split(x, nparts=num_thread) _, x = s[C].split(x, factor=4) @@ -80,26 +88,28 @@ def test_add(): # one line to build the function. def check_device(device): - if not tvm.codegen.enabled(device): + if not tvm.module.enabled(device): print("skip because %s is not enabled.." % device) return - fadd = tvm.build(s, [A, B, C], + fadd = tvm.build(s, [A, B, C, bias, scale], device, name="myadd") - ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) + ctx = tvm.context(device, 0) + print(fadd.imported_modules[0].get_source()) # launch the kernel. n = 1024 a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx) b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx) c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) - fadd(a, b, c) + vbias = np.random.uniform() + vscale = np.random.uniform() + fadd(a, b, c, vbias, vscale) np.testing.assert_allclose( - c.asnumpy(), a.asnumpy() + b.asnumpy()) + c.asnumpy(), a.asnumpy() + b.asnumpy() * vscale + vbias, rtol=1e-6) - if tvm.module.enabled("opencl"): - tvm.module.init_opencl() - check_device("cuda") check_device("opencl") + check_device("metal") + check_device("cuda") if __name__ == "__main__": diff --git a/tests/python/integration/test_gemm.py b/tests/python/integration/test_gemm.py index d5e0780d..c907d6ab 100644 --- a/tests/python/integration/test_gemm.py +++ b/tests/python/integration/test_gemm.py @@ -1,6 +1,13 @@ import tvm from tvm.contrib import nvcc_compiler +from tvm.contrib import metal_compiler import numpy as np +import time + +#@tvm.register_func +def tvm_callback_metal_compile(code): + lib = metal_compiler.compile_source(code) + return lib def test_gemm(): # graph @@ -63,15 +70,14 @@ def test_gemm(): s = s.normalize() # one line to build the function. - def check_device(device, host="stackvm"): - if not tvm.codegen.enabled(host): - return - if not tvm.codegen.enabled(device): + def check_device(device): + if not tvm.module.enabled(device): + print("skip because %s is not enabled.." % device) return - f = tvm.build(s, [A, B, C], device, host, + f = tvm.build(s, [A, B, C], device, max_auto_unroll_step=max_auto_unroll_step) - ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) + ctx = tvm.context(device, 0) # launch the kernel. n = nn m = n @@ -81,15 +87,20 @@ def test_gemm(): a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(b_np, ctx) c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx) - for i in range(4): - f(a, b, c) + f(a, b, c) + ctx.sync() + tbegin = time.time() + f(a, b, c) + tpush = time.time() + ctx.sync() + tend = time.time() + print("launch=%g sec, exec=%g sec" % (tpush - tbegin, tend - tbegin)) np.testing.assert_allclose( c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5) - check_device("cuda") - if tvm.module.enabled("opencl"): - tvm.module.init_opencl() + check_device("metal") check_device("opencl") + check_device("cuda") if __name__ == "__main__": test_gemm() diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index ecab552e..4d30c73c 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -19,16 +19,16 @@ def test_reduce_prims(): # one line to build the function. def check_device(device, host="stackvm"): - if not tvm.codegen.enabled(host): + if not tvm.module.enabled(host): return - if not tvm.codegen.enabled(device): + if not tvm.module.enabled(device): + print("skip because %s is not enabled.." % device) return - ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) + ctx = tvm.context(device, 0) freduce = tvm.build(s, args=[A, B], target=device, target_host=host, name="myreduce") - print(freduce.imported_modules[0].get_source()) # launch the kernel. n = 1028 m = 129 @@ -41,9 +41,7 @@ def test_reduce_prims(): res[:2] = 0 np.testing.assert_allclose(npy, res, rtol=1e-4) - if tvm.module.enabled("opencl"): - tvm.module.init_opencl() - + check_device("metal") check_device("cuda") check_device("opencl") test_prim(tvm.sum, np.sum) @@ -64,7 +62,7 @@ def test_rfactor(): s[BF].parallel(BF.op.axis[0]) # one line to build the function. def check_target(target="llvm"): - if not tvm.codegen.enabled(target): + if not tvm.module.enabled(target): return ctx = tvm.cpu(0) fapi = tvm.lower(s, args=[A, B]) @@ -105,15 +103,14 @@ def test_rfactor_threads(): # one line to build the function. def check_target(device, host="stackvm"): - if not tvm.codegen.enabled(device): + if not tvm.module.enabled(device): + print("skip because %s is not enabled.." % device) return - ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) + ctx = tvm.context(device, 0) fapi = tvm.lower(s, args=[A, B]) - fapi2 = tvm.ir_pass.LowerThreadAllreduce(fapi, 32) fsum = tvm.build(fapi, target=device, name="mysum") - print(fsum.imported_modules[0].get_source()) # launch the kernel. n = nn m = mm @@ -125,9 +122,8 @@ def test_rfactor_threads(): np.testing.assert_allclose( b.asnumpy(), res, rtol=1e-4) - if tvm.module.enabled("opencl"): - tvm.module.init_opencl() check_target("cuda") + check_target("metal") check_target("opencl") if __name__ == "__main__": diff --git a/tests/python/integration/test_scan.py b/tests/python/integration/test_scan.py index b7ba5171..2f9d29e9 100644 --- a/tests/python/integration/test_scan.py +++ b/tests/python/integration/test_scan.py @@ -23,15 +23,14 @@ def test_scan(): s[s_update].bind(xi, thread_x) # one line to build the function. - def check_device(device, host="stackvm"): - if not tvm.codegen.enabled(host): - return - if not tvm.codegen.enabled(device): + def check_device(device): + if not tvm.module.enabled(device): + print("skip because %s is not enabled.." % device) return fscan = tvm.build(s, [X, res], - device, host, + device, name="myscan") - ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) + ctx = tvm.context(device, 0) # launch the kernel. n = 1024 m = 10 @@ -42,10 +41,8 @@ def test_scan(): np.testing.assert_allclose( b.asnumpy(), np.cumsum(a_np, axis=0)) - if tvm.module.enabled("opencl"): - tvm.module.init_opencl() - check_device("cuda") + check_device("metal") check_device("opencl") diff --git a/tests/python/perf/gemm_square.py b/tests/python/perf/gemm_square.py index 418427d9..87a85a1f 100644 --- a/tests/python/perf/gemm_square.py +++ b/tests/python/perf/gemm_square.py @@ -99,9 +99,9 @@ def test_gemm(): # correctness def check_device(device, host="stackvm"): - if not tvm.codegen.enabled(host): + if not tvm.module.enabled(host): return - if not tvm.codegen.enabled(device): + if not tvm.module.enabled(device): return f = tvm.build(s, [A, B, C], device, host, max_auto_unroll_step=max_auto_unroll_step) diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index 6a563cba..1e3c4a53 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -29,17 +29,16 @@ def test_add_pipeline(): fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0]) def check_target(device, host="stackvm"): - if not tvm.codegen.enabled(host): + if not tvm.module.enabled(host): return - if not tvm.codegen.enabled(device): + if not tvm.module.enabled(device): return - ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) + ctx = tvm.context(device, 0) mhost = tvm.codegen.build_module(fsplits[0], host) mdev = tvm.codegen.build_module(fsplits[1:], device) mhost.import_module(mdev) code = mdev.get_source() f = mhost.entry_func - # launch the kernel. n = 1027 a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx) @@ -50,11 +49,11 @@ def test_add_pipeline(): c.asnumpy(), a.asnumpy() + b.asnumpy()) def check_module_save(device, host="stackvm"): - if not tvm.codegen.enabled(host): + if not tvm.module.enabled(host): return - if not tvm.codegen.enabled(device): + if not tvm.module.enabled(device): return - ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) + ctx = tvm.context(device, 0) fmt = "ptx" if device == "cuda" else "cl" mhost = tvm.codegen.build_module(fsplits[0], host) mdev = tvm.codegen.build_module(fsplits[1:], device) diff --git a/tests/python/unittest/test_codegen_extern.py b/tests/python/unittest/test_codegen_extern.py index f58278b9..c283772d 100644 --- a/tests/python/unittest/test_codegen_extern.py +++ b/tests/python/unittest/test_codegen_extern.py @@ -19,7 +19,7 @@ def test_add_pipeline(): s = tvm.create_schedule(C.op) def check_llvm(): - if not tvm.codegen.enabled("llvm"): + if not tvm.module.enabled("llvm"): return # build and invoke the kernel. f = tvm.build(s, [A, C], "llvm") @@ -51,7 +51,7 @@ def test_pack_buffer_simple(): def check_target(target): - if not tvm.codegen.enabled(target): + if not tvm.module.enabled(target): return # build and invoke the kernel. f = tvm.build(s, [A, C], target) @@ -81,7 +81,7 @@ def test_pack_buffer_intermediate(): s = tvm.create_schedule(C.op) def check_target(target): - if not tvm.codegen.enabled(target): + if not tvm.module.enabled(target): return # build and invoke the kernel. f = tvm.build(s, [A, C], target) diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index 51d9f72f..8f7b54fa 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -12,7 +12,7 @@ def test_llvm_add_pipeline(): s[C].parallel(xo) s[C].vectorize(xi) def check_llvm(): - if not tvm.codegen.enabled("llvm"): + if not tvm.module.enabled("llvm"): return # build and invoke the kernel. f = tvm.build(s, [A, B, C], "llvm") @@ -30,7 +30,7 @@ def test_llvm_add_pipeline(): def test_llvm_flip_pipeline(): def check_llvm(nn, base): - if not tvm.codegen.enabled("llvm"): + if not tvm.module.enabled("llvm"): return n = tvm.convert(nn) A = tvm.placeholder((n + base), name='A') @@ -57,7 +57,7 @@ def test_llvm_flip_pipeline(): def test_llvm_madd_pipeline(): def check_llvm(nn, base, stride): - if not tvm.codegen.enabled("llvm"): + if not tvm.module.enabled("llvm"): return n = tvm.convert(nn) A = tvm.placeholder((n + base, stride), name='A') @@ -89,7 +89,7 @@ def test_llvm_temp_space(): s = tvm.create_schedule(C.op) def check_llvm(): - if not tvm.codegen.enabled("llvm"): + if not tvm.module.enabled("llvm"): return # build and invoke the kernel. f = tvm.build(s, [A, C], "llvm") diff --git a/tests/python/unittest/test_codegen_vm_basic.py b/tests/python/unittest/test_codegen_vm_basic.py index bcb0c536..82051bf9 100644 --- a/tests/python/unittest/test_codegen_vm_basic.py +++ b/tests/python/unittest/test_codegen_vm_basic.py @@ -3,7 +3,7 @@ import numpy as np def run_jit(fapi, check): for target in ["llvm", "stackvm"]: - if not tvm.codegen.enabled(target): + if not tvm.module.enabled(target): continue f = tvm.codegen.build_module(fapi, target) s = f.get_source() diff --git a/tests/python/unittest/test_module_load.py b/tests/python/unittest/test_module_load.py index 77831e8d..602cb25a 100644 --- a/tests/python/unittest/test_module_load.py +++ b/tests/python/unittest/test_module_load.py @@ -22,7 +22,7 @@ print("Finish runtime checking...") """ def test_dso_module_load(): - if not tvm.codegen.enabled("llvm"): + if not tvm.module.enabled("llvm"): return dtype = 'int64' temp = util.tempdir() @@ -38,6 +38,7 @@ def test_dso_module_load(): tvm.make.Load(dtype, Ab.data, i) + 1, i + 1)) fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0) + fapi = tvm.ir_pass.LowerPackedCall(fapi) m = tvm.codegen.build_module(fapi, "llvm") for name in names: m.save(name) diff --git a/tests/python/unittest/test_runtime_ndarray.py b/tests/python/unittest/test_runtime_ndarray.py index 0dc6e3f1..c8d3020b 100644 --- a/tests/python/unittest/test_runtime_ndarray.py +++ b/tests/python/unittest/test_runtime_ndarray.py @@ -2,14 +2,14 @@ import tvm import numpy as np def enabled_ctx_list(): - if tvm.module.enabled("opencl"): - tvm.module.init_opencl() - ctx_list = [('cpu', tvm.cpu(0)), ('gpu', tvm.gpu(0)), ('cl', tvm.opencl(0)), - ('cpu', tvm.vpi(0))] - ctx_list = [x[1] for x in ctx_list if tvm.module.enabled(x[0])] + ('metal', tvm.metal(0)), + ('vpi', tvm.vpi(0))] + for k, v in ctx_list: + assert tvm.context(k, 0) == v + ctx_list = [x[1] for x in ctx_list if x[1].exist] return ctx_list ENABLED_CTX_LIST = enabled_ctx_list() @@ -29,7 +29,8 @@ def test_nd_create(): np.testing.assert_equal(x, y.asnumpy()) np.testing.assert_equal(x, z.asnumpy()) # no need here, just to test usablity - tvm.nd.sync(ctx) + ctx.sync() + if __name__ == "__main__": test_nd_create() diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh index cce0c5a0..30652a0f 100755 --- a/tests/travis/run_test.sh +++ b/tests/travis/run_test.sh @@ -18,16 +18,17 @@ if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then fi cp make/config.mk config.mk -echo "USE_CUDA=0" >> config.mk +echo "ENABLE_CUDA=0" >> config.mk if [ ${TRAVIS_OS_NAME} == "osx" ]; then - echo "USE_OPENCL=1" >> config.mk + echo "ENABLE_OPENCL=1" >> config.mk + echo "ENABLE_METAL=1" >> config.mk else # use g++-4.8 for linux if [ ${CXX} == "g++" ]; then export CXX=g++-4.8 fi - echo "USE_OPENCL=0" >> config.mk + echo "ENABLE_OPENCL=0" >> config.mk fi if [ ${TASK} == "verilog_test" ] || [ ${TASK} == "all_test" ]; then diff --git a/tests/verilog/integration/test_codegen_verilog.py b/tests/verilog/integration/test_codegen_verilog.py index b53fd25c..97589a18 100644 --- a/tests/verilog/integration/test_codegen_verilog.py +++ b/tests/verilog/integration/test_codegen_verilog.py @@ -40,13 +40,12 @@ def test_add_pipeline(): fapi = lower(s, [A, B, C], "myadd") fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)] fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0]) - print(fsplits[1].body) print("------") def check_target(device, host="stackvm"): - if not tvm.codegen.enabled(host): + if not tvm.module.enabled(host): return - if not tvm.codegen.enabled(device): + if not tvm.module.enabled(device): return ctx = tvm.vpi(0) mhost = tvm.codegen.build_module(fsplits[0], host) diff --git a/tutorials/python/get_started.py b/tutorials/python/get_started.py index c9a00b65..96e82af1 100644 --- a/tutorials/python/get_started.py +++ b/tutorials/python/get_started.py @@ -228,7 +228,6 @@ np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) # The following codeblocks generate opencl code, creates array on opencl # device, and verifies the correctness of the code. # -tvm.module.init_opencl() fadd_cl = tvm.build(s, [A, B, C], "opencl", name="myadd") print("------opencl code------") print(fadd_cl.imported_modules[0].get_source()) diff --git a/tutorials/python/intrin_math.py b/tutorials/python/intrin_math.py index e6ae438b..4ce7a24e 100644 --- a/tutorials/python/intrin_math.py +++ b/tutorials/python/intrin_math.py @@ -65,7 +65,6 @@ print(fcuda.imported_modules[0].get_source()) # We can find that the code works for both CUDA and opencl. # The same tvm.exp can also be used for float64 data types. # -tvm.module.init_opencl() fopencl = tvm.build(s, [A, B], "opencl", name="myexp") print(fopencl.imported_modules[0].get_source())