[CODEGEN/RUNTIME] Metal support, runtime improvement. (#111)

* [CODEGEN/RUNTIME] Metal support, runtime improvement.
* Fix case when no device is available

Parent 9ba40dc0fe
Commit 706f9b6f7e

Makefile: 29 lines changed
@@ -18,7 +18,9 @@ all: lib/libtvm.so lib/libtvm_runtime.so lib/libtvm.a
 LIB_HALIDE_IR = HalideIR/lib/libHalideIR.a

 SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc)
+METAL_SRC = $(wildcard src/runtime/metal/*.mm)
 ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
+METAL_OBJ = $(patsubst src/%.mm, build/%.o, $(METAL_SRC))
 ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)

 RUNTIME_SRC = $(wildcard src/runtime/*.cc src/runtime/*/*.cc)

@@ -29,6 +31,7 @@ ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
 export LDFLAGS = -pthread -lm
 export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
	-Iinclude -Idlpack/include -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
+export OBJCFLAGS= -fobjc-arc

 ifdef CUDA_PATH
	NVCC=$(CUDA_PATH)/bin/nvcc

@@ -36,7 +39,7 @@ ifdef CUDA_PATH
 LDFLAGS += -L$(CUDA_PATH)/lib64
 endif

-ifeq ($(USE_CUDA), 1)
+ifeq ($(ENABLE_CUDA), 1)
 CFLAGS += -DTVM_CUDA_RUNTIME=1
 LDFLAGS += -lcuda -lcudart -lnvrtc
 else

@@ -45,9 +48,10 @@ endif

 FRAMEWORKS=

-ifeq ($(USE_OPENCL), 1)
-UNAME_S := $(shell uname -s)
+ifeq ($(ENABLE_OPENCL), 1)
 CFLAGS += -DTVM_OPENCL_RUNTIME=1
+UNAME_S := $(shell uname -s)
 ifeq ($(UNAME_S), Darwin)
 FRAMEWORKS += -framework OpenCL
 else

@@ -57,10 +61,20 @@ else
 CFLAGS += -DTVM_OPENCL_RUNTIME=0
 endif

+ifeq ($(ENABLE_METAL), 1)
+CFLAGS += -DTVM_METAL_RUNTIME=1
+LDFLAGS += -lObjc
+ALL_DEP += $(METAL_OBJ)
+RUNTIME_DEP += $(METAL_OBJ)
+FRAMEWORKS += -framework Metal -framework Foundation
+else
+CFLAGS += -DTVM_METAL_RUNTIME=0
+endif
+
 # llvm configuration
 LLVM_CONFIG=llvm-config

-ifeq ($(USE_LLVM), 1)
+ifeq ($(ENABLE_LLVM), 1)
 LLVM_VERSION=$(shell $(LLVM_CONFIG) --version| cut -b 1,3)
 LLVM_INCLUDE=$(filter -I%, $(shell $(LLVM_CONFIG) --cxxflags))
 LDFLAGS += $(shell $(LLVM_CONFIG) --ldflags --libs --system-libs)

@@ -87,6 +101,11 @@ build/%.o: src/%.cc
	$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
	$(CXX) -c $(CFLAGS) -c $< -o $@

+build/%.o: src/%.mm
+	@mkdir -p $(@D)
+	$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
+	$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
+
 lib/libtvm.so: $(ALL_DEP)
	@mkdir -p $(@D)
	$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)

@@ -105,7 +124,7 @@ LIBHALIDEIR:
	+ cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR)

 cpplint:
-	python2 dmlc-core/scripts/lint.py tvm cpp include src verilog
+	python dmlc-core/scripts/lint.py tvm cpp include src verilog

 pylint:
	pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc
@@ -12,4 +12,5 @@ tvm.ndarray
 .. autofunction:: tvm.cpu
 .. autofunction:: tvm.gpu
 .. autofunction:: tvm.opencl
+.. autofunction:: tvm.metal
 .. autofunction:: tvm.ndarray.array
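The new tvm.metal constructor documented above mirrors the existing device helpers. A minimal usage sketch (whether the device actually exists depends on the local build; Metal requires ENABLE_METAL = 1):

import tvm

ctx = tvm.metal(0)   # TVMContext with device_type 8 (kMetal), device_id 0
print(ctx.exist)     # False on builds without the Metal runtime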
@@ -31,13 +31,6 @@ using runtime::TVMRetValue;
  */
 runtime::Module Build(const Array<LoweredFunc>& funcs,
                       const std::string& target);
-
-/*!
- * \param target The target to be queried.
- * \return Whether target is enabled.
- */
-bool TargetEnabled(const std::string& target);

 }  // namespace codegen
 }  // namespace tvm
@@ -41,6 +41,8 @@ typedef int64_t tvm_index_t;

 /*! \brief Extension device types in TVM */
 typedef enum {
+  /*! \brief Metal buffer. */
+  kMetal = 8,
   /*! \brief Simulated on board RAM */
   kVPI = 9
 } TVMDeviceExtType;
@@ -360,7 +362,7 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
 TVM_DLL int TVMFuncListGlobalNames(int *out_size,
                                    const char*** out_array);

-// Array related apis for quick proptying
+// Array related apis for quick prototyping
 /*!
  * \brief Allocate a nd-array's memory,
  *  including space of shape, of given spec.
@@ -20,4 +20,11 @@
 #define TVM_OPENCL_RUNTIME 0
 #endif

+/*!
+ * \brief whether to use the metal runtime
+ */
+#ifndef TVM_METAL_RUNTIME
+#define TVM_METAL_RUNTIME 0
+#endif
+
 #endif  // TVM_RUNTIME_CONFIG_H_
@@ -145,8 +145,8 @@ class ModuleNode {
 namespace symbol {
 /*! \brief Global variable to store module context. */
 constexpr const char* tvm_module_ctx = "__tvm_module_ctx";
-/*! \brief Local function to set the device during API entry. */
-constexpr const char* tvm_entry_setdevice = "__tvm_entry_setdevice";
+/*! \brief Global function to set device. */
+constexpr const char* tvm_set_device = "__tvm_set_device";
 /*! \brief Auxiliary counter to global barrier. */
 constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
 /*! \brief Prepare the global barrier before kernels that use global barrier. */
@@ -34,16 +34,19 @@ ADD_CFLAGS =
 # matrix computation libraries for CPU/GPU
 #---------------------------------------------

-# whether use CUDA during compile
-USE_CUDA = 1
+# whether enable CUDA during compile
+ENABLE_CUDA = 1

-# whether use OpenCL during compile
-USE_OPENCL = 0
+# whether enable OpenCL during compile
+ENABLE_OPENCL = 0

+# whether enable Metal during compile
+ENABLE_METAL = 0
+
 # whether build with LLVM support
 # This requires llvm-config to be in your PATH
 # Requires LLVM version >= 4.0
-USE_LLVM = 0
+ENABLE_LLVM = 0

 # add the path to CUDA library to link and compile flag
 # if you have already added them to environment variable.
@@ -16,7 +16,7 @@ from . import node
 from . import ir_builder

 from . import ndarray as nd
-from .ndarray import cpu, gpu, opencl, cl, vpi
+from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi

 from ._ffi.function import Function
 from ._ffi.base import TVMError, __version__
@@ -4,7 +4,8 @@
 from __future__ import absolute_import
 import ctypes
 import numpy as np
-from .base import _LIB, check_call, c_array
+from .base import _LIB, check_call, c_array, string_types
+from .. import _api_internal

 tvm_shape_index_t = ctypes.c_int64
@@ -63,22 +64,62 @@ class TVMType(ctypes.Structure):
     def __ne__(self, other):
         return not self.__eq__(other)


 class TVMContext(ctypes.Structure):
     """TVM context structure."""
     _fields_ = [("device_id", ctypes.c_int),
                 ("device_type", ctypes.c_int)]
+    MASK2STR = {
+        1 : 'cpu',
+        2 : 'gpu',
+        4 : 'opencl',
+        8 : 'metal',
+        9 : 'vpi'
+    }
+    STR2MASK = {
+        'cpu': 1,
+        'gpu': 2,
+        'cuda': 2,
+        'cl': 4,
+        'opencl': 4,
+        'metal': 8,
+        'vpi': 9
+    }
-    def __init__(self, device_id, device_type):
+    def __init__(self, device_type, device_id):
         super(TVMContext, self).__init__()
         self.device_id = device_id
         self.device_type = device_type
+
+    @property
+    def exist(self):
+        """Whether this device exists."""
+        return _api_internal._GetDeviceAttr(
+            self.device_type, self.device_id, 0) != 0
+
+    @property
+    def max_threads_per_block(self):
+        """Maximum number of threads on each block."""
+        return _api_internal._GetDeviceAttr(
+            self.device_type, self.device_id, 1)
+
+    @property
+    def warp_size(self):
+        """Number of threads that execute concurrently."""
+        return _api_internal._GetDeviceAttr(
+            self.device_type, self.device_id, 2)
+
+    def sync(self):
+        """Synchronize until jobs finished at the context."""
+        check_call(_LIB.TVMSynchronize(self, None))

     def __eq__(self, other):
         return (isinstance(other, TVMContext) and
                 self.device_id == other.device_id and
                 self.device_type == other.device_type)

     def __ne__(self, other):
         return not self.__eq__(other)

     def __repr__(self):
         return "%s(%d)" % (
             TVMContext.MASK2STR[self.device_type], self.device_id)
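A short sketch of the new context attributes in use; each property goes through _api_internal._GetDeviceAttr, which the runtime side of this commit implements per device:

import tvm

ctx = tvm.gpu(0)
if ctx.exist:
    print(ctx.max_threads_per_block)  # e.g. 1024 on recent NVIDIA GPUs
    print(ctx.warp_size)              # 32 on CUDA devices
    ctx.sync()                        # block until queued work finishes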
@@ -97,48 +138,38 @@ class TVMArray(ctypes.Structure):

 TVMArrayHandle = ctypes.POINTER(TVMArray)


-def cpu(dev_id=0):
-    """Construct a CPU device
+def context(dev_type, dev_id=0):
+    """Construct a TVM context with given device type and id.

     Parameters
     ----------
+    dev_type: int or str
+        The device type mask or name of the device.
+
     dev_id : int, optional
         The integer device id
+
+    Returns
+    -------
+    ctx: TVMContext
+        The corresponding context.
+
+    Examples
+    --------
+    A context can also be constructed from the string name of
+    the device type.
+
+    .. code-block:: python
+
+      assert tvm.context("cpu", 1) == tvm.cpu(1)
+      assert tvm.context("gpu", 0) == tvm.gpu(0)
+      assert tvm.context("cuda", 0) == tvm.gpu(0)
     """
-    return TVMContext(dev_id, 1)
-
-
-def gpu(dev_id=0):
-    """Construct a CPU device
-
-    Parameters
-    ----------
-    dev_id : int, optional
-        The integer device id
-    """
-    return TVMContext(dev_id, 2)
-
-
-def opencl(dev_id=0):
-    """Construct a OpenCL device
-
-    Parameters
-    ----------
-    dev_id : int, optional
-        The integer device id
-    """
-    return TVMContext(dev_id, 4)
-
-
-def vpi(dev_id=0):
-    """Construct a VPI simulated device
-
-    Parameters
-    ----------
-    dev_id : int, optional
-        The integer device id
-    """
-    return TVMContext(dev_id, 9)
+    if isinstance(dev_type, string_types):
+        if dev_type not in TVMContext.STR2MASK:
+            raise ValueError("Unknown device type %s" % dev_type)
+        dev_type = TVMContext.STR2MASK[dev_type]
+    return TVMContext(dev_type, dev_id)


 def numpyasarray(np_data):
@@ -154,10 +185,11 @@ def numpyasarray(np_data):
     arr.dtype = TVMType(np.dtype(data.dtype).name)
     arr.ndim = data.ndim
     # CPU device
-    arr.ctx = cpu(0)
+    arr.ctx = context(1, 0)
     return arr, shape

-def empty(shape, dtype="float32", ctx=cpu(0)):
+
+def empty(shape, dtype="float32", ctx=context(1, 0)):
     """Create an empty array given shape and device

     Parameters
@@ -185,17 +217,6 @@ def empty(shape, dtype="float32", ctx=cpu(0)):
     return _CLASS_NDARRAY(handle)


-def sync(ctx):
-    """Synchronize all the context
-
-    Parameters
-    ----------
-    ctx : TVMContext
-        The context to be synced
-    """
-    check_call(_LIB.TVMSynchronize(ctx, None))
-
-
 class NDArrayBase(object):
     """A simple Device/CPU Array object in runtime."""
     __slots__ = ["handle", "is_view"]
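With the module-level sync helper removed, synchronization now hangs off the context itself. A sketch of the updated idiom:

import tvm

a = tvm.nd.empty((2, 3), dtype="float32", ctx=tvm.cpu(0))
tvm.cpu(0).sync()   # replaces the removed tvm.nd.sync(ctx) helper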
@@ -10,6 +10,7 @@ from . import schedule
 from . import expr
 from . import ir_pass
 from . import collections
+from . import module
 from . import codegen


@@ -149,7 +150,7 @@ def build(sch,
     fsplits[0] = ir_pass.LowerPackedCall(fsplits[0])
     if len(fsplits) > 1:
         if not target_host:
-            target_host = "llvm" if codegen.enabled("llvm") else "stackvm"
+            target_host = "llvm" if module.enabled("llvm") else "stackvm"
         mhost = codegen.build_module(fsplits[0], target_host)
         if target:
             mdev = codegen.build_module(fsplits[1:], target)
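Capability probing moves from codegen.enabled to module.enabled, which checks the runtime function registry. The same fallback pattern from user code, as a sketch:

import tvm

# Pick a host target based on what this build supports.
target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"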
@@ -19,21 +19,4 @@ def build_module(lowered_func, target):
     """
     return _Build(lowered_func, target)

-
-def enabled(target):
-    """Whether target is enabled for codegen.
-
-    Parameters
-    ----------
-    target : str
-        The target module type.
-
-    Returns
-    -------
-    enabled : boolean
-        Whether the target module is enabled.
-    """
-    return _Enabled(target)
-
-
 _init_api("tvm.codegen")
@@ -24,6 +24,8 @@ def create_shared(path_target, objects,
     """
     cmd = [cc]
     cmd += ["-shared"]
+    if sys.platform == "darwin":
+        cmd += ["-undefined", "dynamic_lookup"]
     cmd += ["-o", path_target]
     cmd += objects
     if options:
@@ -0,0 +1,50 @@
+# pylint: disable=invalid-name
+"""Utility to invoke the Metal compiler via the CLI on the system"""
+from __future__ import absolute_import as _abs
+import sys
+import subprocess
+from . import util
+
+def compile_source(code, path_target=None):
+    """Compile Metal source with the CLI compiler from the environment.
+
+    Parameters
+    ----------
+    code : str
+        The Metal source code.
+
+    path_target : str, optional
+        Output file.
+
+    Returns
+    -------
+    metallib : bytearray
+        The bytearray of the metallib
+    """
+    temp = util.tempdir()
+    temp_code = temp.relpath("my_lib.metal")
+    temp_ir = temp.relpath("my_lib.air")
+    temp_target = temp.relpath("my_lib.metallib")
+
+    with open(temp_code, "w") as out_file:
+        out_file.write(code)
+    file_target = path_target if path_target else temp_target
+
+    cmd1 = ["xcrun", "-sdk", "macosx", "metal", "-O3"]
+    cmd1 += [temp_code, "-o", temp_ir]
+    cmd2 = ["xcrun", "-sdk", "macosx", "metallib"]
+    cmd2 += [temp_ir, "-o", file_target]
+    proc = subprocess.Popen(
+        ' '.join(cmd1) + ";" + ' '.join(cmd2),
+        shell=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT)
+    (out, _) = proc.communicate()
+    if proc.returncode != 0:
+        sys.stderr.write("Compilation error:\n")
+        sys.stderr.write(out)
+        sys.stderr.flush()
+        libbin = None
+    else:
+        libbin = bytearray(open(file_target, "rb").read())
+    return libbin
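build_metal.cc (added later in this commit) looks up a global function named tvm_callback_metal_compile and, when present, uses it to turn generated Metal source into a metallib. A sketch of wiring this helper in as that callback; the import path of the new module is an assumption here:

import tvm
from tvm.contrib import metal_util  # assumed location of compile_source

@tvm.register_func("tvm_callback_metal_compile")
def metal_compile(code):
    # Route generated Metal source through xcrun; returns a metallib bytearray.
    return metal_util.compile_source(code)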
@@ -8,11 +8,9 @@ from __future__ import absolute_import as _abs
 import numpy as _np

 from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
-from ._ffi.ndarray import cpu, gpu, opencl, vpi, empty, sync
+from ._ffi.ndarray import context, empty
 from ._ffi.ndarray import _set_class_ndarray

-cl = opencl
-
 class NDArray(NDArrayBase):
     """Lightweight NDArray class of TVM runtime.
@@ -27,6 +25,64 @@ class NDArray(NDArrayBase):
     pass


+def cpu(dev_id=0):
+    """Construct a CPU device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+    """
+    return TVMContext(1, dev_id)
+
+
+def gpu(dev_id=0):
+    """Construct a GPU device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+    """
+    return TVMContext(2, dev_id)
+
+
+def opencl(dev_id=0):
+    """Construct an OpenCL device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+    """
+    return TVMContext(4, dev_id)
+
+
+def metal(dev_id=0):
+    """Construct a Metal device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+    """
+    return TVMContext(8, dev_id)
+
+
+def vpi(dev_id=0):
+    """Construct a VPI simulated device
+
+    Parameters
+    ----------
+    dev_id : int, optional
+        The integer device id
+    """
+    return TVMContext(9, dev_id)
+
+cl = opencl
+mtl = metal


 def array(arr, ctx=cpu(0)):
     """Create an array from source arr.
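With the constructors and aliases in place, the same device can be named several ways, and tvm.context accepts the string forms as well:

import tvm

assert tvm.cl(0) == tvm.opencl(0)
assert tvm.mtl(0) == tvm.metal(0)
assert tvm.context("metal", 0) == tvm.metal(0)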
@@ -20,10 +20,5 @@ TVM_REGISTER_API("codegen._Build")
     *ret = Build(args[0], args[1]);
   }
 });
-
-TVM_REGISTER_API("codegen._Enabled")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = TargetEnabled(args[0]);
-  });
 }  // namespace codegen
 }  // namespace tvm
@@ -0,0 +1,34 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * Common build utilities
+ * \file build_common.h
+ */
+#ifndef TVM_CODEGEN_BUILD_COMMON_H_
+#define TVM_CODEGEN_BUILD_COMMON_H_
+
+#include <tvm/codegen.h>
+#include <unordered_map>
+#include <string>
+#include "../runtime/meta_data.h"
+
+namespace tvm {
+namespace codegen {
+// Extract function information from device function.
+inline std::unordered_map<std::string, runtime::FunctionInfo>
+ExtractFuncInfo(const Array<LoweredFunc>& funcs) {
+  std::unordered_map<std::string, runtime::FunctionInfo> fmap;
+  for (LoweredFunc f : funcs) {
+    runtime::FunctionInfo info;
+    for (size_t i = 0; i < f->args.size(); ++i) {
+      info.arg_types.push_back(Type2TVMType(f->args[i].type()));
+    }
+    for (size_t i = 0; i < f->thread_axis.size(); ++i) {
+      info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag);
+    }
+    fmap[f->name] = info;
+  }
+  return fmap;
+}
+}  // namespace codegen
+}  // namespace tvm
+#endif  // TVM_CODEGEN_BUILD_COMMON_H_
@@ -6,11 +6,10 @@
 #include <tvm/base.h>
 #include <tvm/runtime/config.h>
 #include "./codegen_cuda.h"
+#include "./build_common.h"

 #if TVM_CUDA_RUNTIME

 #include <nvrtc.h>
-#include "../runtime/meta_data.h"
 #include "../runtime/cuda/cuda_common.h"
 #include "../runtime/cuda/cuda_module.h"

@@ -71,19 +70,7 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
   } else {
     ptx = NVRTCCompile(code);
   }
-
-  std::unordered_map<std::string, runtime::FunctionInfo> fmap;
-  for (LoweredFunc f : funcs) {
-    runtime::FunctionInfo info;
-    for (size_t i = 0; i < f->args.size(); ++i) {
-      info.arg_types.push_back(Type2TVMType(f->args[i].type()));
-    }
-    for (size_t i = 0; i < f->thread_axis.size(); ++i) {
-      info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag);
-    }
-    fmap[f->name] = info;
-  }
-  return CUDAModuleCreate(ptx, fmt, fmap, code);
+  return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(funcs), code);
 }

 TVM_REGISTER_API("codegen.build_cuda")
@@ -0,0 +1,47 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * Build metal modules from source.
+ * \file build_metal.cc
+ */
+#include <tvm/base.h>
+#include <tvm/runtime/config.h>
+#include "./codegen_metal.h"
+#include "./build_common.h"
+
+#if TVM_METAL_RUNTIME
+#include "../runtime/metal/metal_module.h"
+#endif  // TVM_METAL_RUNTIME
+
+namespace tvm {
+namespace codegen {
+
+runtime::Module BuildMetal(Array<LoweredFunc> funcs) {
+  using tvm::runtime::Registry;
+  bool output_ssa = false;
+  CodeGenMetal cg;
+  cg.Init(output_ssa);
+  for (LoweredFunc f : funcs) {
+    cg.AddFunction(f);
+  }
+  std::string code = cg.Finish();
+#if TVM_METAL_RUNTIME
+  std::string fmt = "metal";
+  std::string source = "";
+  if (const auto* f = Registry::Get("tvm_callback_metal_compile")) {
+    source = code;
+    code = (*f)(code).operator std::string();
+    fmt = "metallib";
+  }
+  return MetalModuleCreate(code, fmt, ExtractFuncInfo(funcs), source);
+#else
+  LOG(WARNING) << "Metal runtime not enabled, return a source module...";
+  return SourceModuleCreate(code, "metal");
+#endif  // TVM_METAL_RUNTIME
+}
+
+TVM_REGISTER_API("codegen.build_metal")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    *rv = BuildMetal(args[0]);
+  });
+}  // namespace codegen
+}  // namespace tvm
@@ -6,39 +6,29 @@
 #include <tvm/base.h>
 #include <tvm/runtime/config.h>
 #include "./codegen_opencl.h"
+#include "./build_common.h"

 #if TVM_OPENCL_RUNTIME

-#include "../runtime/meta_data.h"
 #include "../runtime/opencl/opencl_common.h"
 #include "../runtime/opencl/opencl_module.h"
+#endif  // TVM_OPENCL_RUNTIME

 namespace tvm {
 namespace codegen {

 runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
-  std::ostringstream os;
   bool output_ssa = false;
   CodeGenOpenCL cg;
   cg.Init(output_ssa);
-
   for (LoweredFunc f : funcs) {
     cg.AddFunction(f);
   }
   std::string code = cg.Finish();
-
-  std::unordered_map<std::string, runtime::FunctionInfo> fmap;
-  for (LoweredFunc f : funcs) {
-    runtime::FunctionInfo info;
-    for (size_t i = 0; i < f->args.size(); ++i) {
-      info.arg_types.push_back(Type2TVMType(f->args[i].type()));
-    }
-    for (size_t i = 0; i < f->thread_axis.size(); ++i) {
-      info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag);
-    }
-    fmap[f->name] = info;
-  }
-  return OpenCLModuleCreate(code, "cl", fmap);
+#if TVM_OPENCL_RUNTIME
+  return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs));
+#else
+  LOG(WARNING) << "OpenCL runtime not enabled, return a source module...";
+  return SourceModuleCreate(code, "cl");
+#endif  // TVM_OPENCL_RUNTIME
 }

 TVM_REGISTER_API("codegen.build_opencl")

@@ -47,4 +37,3 @@ TVM_REGISTER_API("codegen.build_opencl")
   });
 }  // namespace codegen
 }  // namespace tvm
-#endif  // TVM_OPENCL_RUNTIME
@@ -32,10 +32,5 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
   return m;
 }

-bool TargetEnabled(const std::string& target) {
-  std::string build_f_name = "codegen.build_" + target;
-  return runtime::Registry::Get(build_f_name) != nullptr;
-}
-
 }  // namespace codegen
 }  // namespace tvm
@@ -24,7 +24,6 @@ void CodeGenC::InitFuncState(LoweredFunc f) {
 void CodeGenC::AddFunction(LoweredFunc f) {
   // clear previous generated state.
   this->InitFuncState(f);
-
   // skip the first underscore, so SSA variable starts from _1
   GetUniqueName("_");
   // add to alloc buffer type.

@@ -41,8 +40,8 @@ void CodeGenC::AddFunction(LoweredFunc f) {
       auto it = alloc_storage_scope_.find(v.get());
       if (it != alloc_storage_scope_.end()) {
         PrintStorageScope(it->second, stream);
+        stream << ' ';
       }
-      stream << ' ';
     }
     if (handle_data_type_.count(v.get())) {
       PrintType(handle_data_type_.at(v.get()), stream);

@@ -246,9 +245,9 @@ void CodeGenC::PrintVecStore(const Variable* buffer,
   stream << ref << " = " << value << ";\n";
 }

-void CodeGenC::PrintThreadIndexExpr(
-    std::string thread_tag, std::ostream& os) {  // NOLINT(*)
-  os << thread_tag;
+void CodeGenC::BindThreadIndex(const IterVar& iv) {
+  CHECK(!var_idmap_.count(iv->var.get()));
+  var_idmap_[iv->var.get()] = iv->thread_tag;
 }

 void CodeGenC::PrintStorageSync(const Call* op) {  // NOLINT(*)

@@ -674,6 +673,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) {
       const Variable* buffer = op->buffer_var.as<Variable>();
       std::string scope = alloc_storage_scope_.at(buffer);
       PrintStorageScope(scope, stream);
+      stream << ' ';
       PrintType(op->type, stream);
       stream << ' ' << vid << '['
              << constant_size << "];\n";

@@ -687,13 +687,7 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
     IterVar iv(op->node.node_);
     if (iv->thread_tag.length() != 0) {
       if (!var_idmap_.count(iv->var.get())) {
-        this->PrintIndent();
-        PrintType(iv->var.type(), stream);
-        stream << ' '
-               << AllocVarID(iv->var.get())
-               << " = ";
-        PrintThreadIndexExpr(iv->thread_tag, stream);
-        stream << ";\n";
+        BindThreadIndex(iv);
       }
     }
   } else if (op->attr_key == ir::attr::storage_scope) {

@@ -740,7 +734,11 @@ void CodeGenC::VisitStmt_(const For* op) {
 void CodeGenC::VisitStmt_(const IfThenElse* op) {
   std::string cond = PrintExpr(op->condition);
   PrintIndent();
-  stream << "if (" << cond << ") {\n";
+  if (cond[0] == '(' && cond[cond.length() - 1] == ')') {
+    stream << "if " << cond << " {\n";
+  } else {
+    stream << "if (" << cond << ") {\n";
+  }
   int then_scope = BeginScope();
   PrintStmt(op->then_case);
   this->EndScope(then_scope);
@@ -121,12 +121,10 @@ class CodeGenC :
   virtual void PrintType(Type t, std::ostream& os) const;  // NOLINT(*)
   /*!
-   * \brief Print expr representing the thread tag
-   * \param tag The tag in the thread.
-   * \param os The strean to output to
+   * \param iv The thread index to be bound.
    */
-  virtual void PrintThreadIndexExpr(
-      std::string tag, std::ostream& os);  // NOLINT(*)
-  virtual void PrintStorageScope(const std::string& scope, std::ostream& os);  // NOLINT(*)
+  virtual void BindThreadIndex(const IterVar& iv);  // NOLINT(*)
+  virtual void PrintStorageScope(const std::string& scope, std::ostream& os);  // NOLINT(*)
   virtual void PrintStorageSync(const Call* op);  // NOLINT(*)
   // Binary vector op.
   virtual void PrintVecBinaryOp(

@@ -169,12 +167,12 @@ class CodeGenC :
       const std::string& target, const std::string& src, Type t) final;
   /*! \brief the storage scope of allocation */
   std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
+  /*! \brief the data type of allocated buffers */
+  std::unordered_map<const Variable*, Type> handle_data_type_;

  private:
   /*! \brief whether to print in SSA form */
   bool print_ssa_form_{false};
-  /*! \brief the data type of allocated buffers */
-  std::unordered_map<const Variable*, Type> handle_data_type_;
   /*! \brief set of volatile buf access */
   std::unordered_set<const Variable*> volatile_buf_;
 };
@@ -141,7 +141,9 @@ void CodeGenCUDA::PrintVecElemStore(

 void CodeGenCUDA::PrintStorageSync(const Call* op) {
   const std::string& sync = op->args[0].as<StringImm>()->value;
-  if (sync == "shared") {
+  if (sync == "warp") {
+    // Do nothing.
+  } else if (sync == "shared") {
     this->PrintIndent();
     this->stream << "__syncthreads();\n";
   } else if (sync == "global") {

@@ -182,7 +184,7 @@ void CodeGenCUDA::PrintStorageScope(
     const std::string& scope, std::ostream& os) {  // NOLINT(*)
   CHECK_NE(scope, "global");
   if (scope == "shared") {
-    os << "__shared__ ";
+    os << "__shared__";
   }
 }

@@ -203,6 +205,17 @@ void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
   }
 }

+void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) {  // NOLINT(*)
+  std::string v = PrintExpr(op->value);
+  os << "make_";
+  PrintType(op->type, os);
+  os << "(";
+  for (int i = 0; i < op->lanes; ++i) {
+    if (i != 0) os << ", ";
+    os << v;
+  }
+  os << ')';
+}

 }  // namespace codegen
 }  // namespace tvm
@@ -14,7 +14,7 @@
 namespace tvm {
 namespace codegen {

-class CodeGenCUDA : public CodeGenC {
+class CodeGenCUDA final : public CodeGenC {
  public:
   void Init(bool output_ssa);
   void AddFunction(LoweredFunc f);

@@ -31,6 +31,8 @@ class CodeGenCUDA : public CodeGenC {
   void PrintVecElemStore(
       const std::string& vec, Type t, int i, const std::string& value) final;

+  // overload visitor
+  void VisitExpr_(const Broadcast* op, std::ostream& os) final;  // NOLINT(*)
   void VisitStmt_(const Evaluate *op) final;

  private:
@@ -0,0 +1,203 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file codegen_metal.cc
+ */
+#include <tvm/runtime/config.h>
+#include <tvm/packed_func_ext.h>
+#include <vector>
+#include <string>
+#include "./codegen_metal.h"
+#include "../runtime/thread_storage_scope.h"
+
+namespace tvm {
+namespace codegen {
+
+void CodeGenMetal::InitFuncState(LoweredFunc f) {
+  CodeGenC::InitFuncState(f);
+  // analyze the buffer arguments
+  for (Var arg : f->args) {
+    if (arg.type().is_handle()) {
+      alloc_storage_scope_[arg.get()] = "global";
+    }
+  }
+}
+
+CodeGenMetal::CodeGenMetal() {
+  decl_stream << "#include <metal_stdlib>\n";
+  decl_stream << "using namespace metal;\n\n";
+  decl_stream << "union __TVMArgUnion {\n"
+              << " int v_int;\n"
+              << "};\n\n";
+}
+
+void CodeGenMetal::AddFunction(LoweredFunc f) {
+  // clear previous generated state.
+  this->InitFuncState(f);
+  // skip the first underscore, so SSA variable starts from _1
+  GetUniqueName("_");
+  // add to alloc buffer type.
+  for (const auto & kv : f->handle_data_type) {
+    RegisterHandleType(kv.first.get(), kv.second.type());
+  }
+  // Function header.
+  this->stream << "kernel void " << f->name << "(\n";
+  // Buffer arguments
+  size_t num_buffer = 0;
+  for (size_t i = 0; i < f->args.size(); ++i, ++num_buffer) {
+    Var v = f->args[i];
+    if (!v.type().is_handle()) break;
+    stream << "  ";
+    std::string vid = AllocVarID(v.get());
+    auto it = alloc_storage_scope_.find(v.get());
+    CHECK(it != alloc_storage_scope_.end());
+    PrintStorageScope(it->second, stream);
+    stream << ' ';
+    if (handle_data_type_.count(v.get())) {
+      PrintType(handle_data_type_.at(v.get()), stream);
+      stream << "*";
+    } else {
+      PrintType(v.type(), stream);
+    }
+    stream << ' ' << vid
+           << " [[ buffer(" << i << ") ]],\n";
+  }
+  // Setup normal arguments.
+  size_t nargs = f->args.size() - num_buffer;
+  std::string varg = GetUniqueName("arg");
+  if (nargs != 0) {
+    std::string arg_buf_type = f->name + "_args_t";
+    stream << "  constant " << arg_buf_type << "& " << varg
+           << " [[ buffer(" << num_buffer << ") ]],\n";
+    // declare the struct
+    decl_stream << "struct " << arg_buf_type << " {\n";
+    for (size_t i = num_buffer; i < f->args.size(); ++i) {
+      Var v = f->args[i];
+      CHECK(!v.type().is_handle());
+      std::string vid = AllocVarID(v.get());
+      std::ostringstream vref;
+      if (v.type().bits() == 32) {
+        decl_stream << "  ";
+        PrintType(v.type(), decl_stream);
+        decl_stream << " " << vid << ";\n";
+        vref << varg << "." << vid;
+      } else {
+        // For non 32bit type, ref through arg union.
+        decl_stream << "  __TVMArgUnion " << vid << ";\n";
+        vref << varg << "." << vid << ".v_";
+        PrintType(v.type(), vref);
+      }
+      var_idmap_[v.get()] = vref.str();
+    }
+    decl_stream << "};\n\n";
+  }
+  // Setup the thread group info.
+  CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx");
+  CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx");
+  int work_dim = 0;
+  for (IterVar iv : f->thread_axis) {
+    runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
+    work_dim = std::max(work_dim, scope.dim_index + 1);
+  }
+  if (work_dim != 0) {
+    // use ushort by default for now
+    stream << "  ";
+    PrintType(UInt(16, work_dim), stream);
+    stream << " blockIdx [[threadgroup_position_in_grid]],\n";
+    stream << "  ";
+    PrintType(UInt(16, work_dim), stream);
+    stream << " threadIdx [[thread_position_in_threadgroup]]\n";
+  }
+  // bind thread axis
+  for (IterVar iv : f->thread_axis) {
+    CHECK(!var_idmap_.count(iv->var.get()));
+    if (work_dim <= 1) {
+      var_idmap_[iv->var.get()] =
+          iv->thread_tag.substr(0, iv->thread_tag.length() - 2);
+    } else {
+      var_idmap_[iv->var.get()] = iv->thread_tag;
+    }
+  }
+  // the function scope.
+  stream << ") {\n";
+  int func_scope = this->BeginScope();
+  this->PrintStmt(f->body);
+  this->EndScope(func_scope);
+  this->PrintIndent();
+  this->stream << "}\n\n";
+}
+
+void CodeGenMetal::PrintType(Type t, std::ostream& os) const {  // NOLINT(*)
+  int lanes = t.lanes();
+  if (t.is_handle()) {
+    CHECK_EQ(lanes, 1)
+        << "do not yet support vector types";
+    os << "void*"; return;
+  }
+  bool fail = false;
+  if (t.is_float()) {
+    switch (t.bits()) {
+      case 16: os << "half"; break;
+      case 32: os << "float"; break;
+      default: fail = true; break;
+    }
+    if (!fail && lanes == 1) return;
+    if (!fail && (lanes >= 2 && lanes <= 4)) {
+      os << lanes; return;
+    }
+  } else if (t.is_uint() || t.is_int()) {
+    if (t.is_uint()) {
+      os << 'u';
+    }
+    if (t.bits() == 8 && t.lanes() == 4) {
+      // pack four 8-bit ints directly into an integer.
+      os << "int"; return;
+    }
+    switch (t.bits()) {
+      case 8: os << "char"; break;
+      case 16: os << "short"; break;
+      case 32: os << "int"; break;
+      case 1: os << "bool"; break;
+      default: fail = true; break;
+    }
+    if (!fail && lanes == 1) return;
+    if (!fail && (lanes >= 2 && lanes <= 4)) {
+      os << lanes; return;
+    }
+  }
+  LOG(FATAL) << "Cannot convert type " << t << " to Metal type";
+}
+
+void CodeGenMetal::PrintStorageSync(const Call* op) {
+  const std::string& sync = op->args[0].as<StringImm>()->value;
+  if (sync == "warp") {
+    this->PrintIndent();
+    this->stream << "simdgroup_barrier(mem_flags::mem_threadgroup);\n";
+  } else if (sync == "shared") {
+    this->PrintIndent();
+    this->stream << "threadgroup_barrier(mem_flags::mem_threadgroup);\n";
+  } else if (sync == "global") {
+    LOG(FATAL) << "global barrier not supported";
+  }
+}
+
+void CodeGenMetal::PrintStorageScope(
+    const std::string& scope, std::ostream& os) {  // NOLINT(*)
+  if (scope == "global") {
+    os << "device";
+  } else if (scope == "shared") {
+    os << "threadgroup";
+  }
+}
+
+void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) {  // NOLINT(*)
+  std::string v = PrintExpr(op->value);
+  PrintType(op->type, os);
+  os << "(";
+  for (int i = 0; i < op->lanes; ++i) {
+    if (i != 0) os << ", ";
+    os << v;
+  }
+  os << ')';
+}
+}  // namespace codegen
+}  // namespace tvm
@@ -0,0 +1,34 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file codegen_metal.h
+ * \brief Generate Metal device code.
+ */
+#ifndef TVM_CODEGEN_CODEGEN_METAL_H_
+#define TVM_CODEGEN_CODEGEN_METAL_H_
+
+#include <tvm/codegen.h>
+#include <tvm/packed_func_ext.h>
+#include <string>
+#include "./codegen_c.h"
+
+namespace tvm {
+namespace codegen {
+
+class CodeGenMetal final : public CodeGenC {
+ public:
+  CodeGenMetal();
+  void AddFunction(LoweredFunc f);
+  // override print thread tag.
+  void PrintArgUnionDecl();
+  void InitFuncState(LoweredFunc f) final;
+  void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)
+  void PrintStorageSync(const Call* op) final;  // NOLINT(*)
+  void PrintType(Type t, std::ostream& os) const final;  // NOLINT(*)
+
+  // overload visitor
+  void VisitExpr_(const Broadcast* op, std::ostream& os) final;  // NOLINT(*)
+};
+}  // namespace codegen
+}  // namespace tvm
+
+#endif  // TVM_CODEGEN_CODEGEN_METAL_H_
@@ -1,6 +1,6 @@
 /*!
  * Copyright (c) 2017 by Contributors
- * \file codegen_cuda.cc
+ * \file codegen_opencl.cc
  */
 #include <tvm/runtime/config.h>
 #include <tvm/packed_func_ext.h>

@@ -22,19 +22,20 @@ void CodeGenOpenCL::InitFuncState(LoweredFunc f) {
 }

 void CodeGenOpenCL::AddFunction(LoweredFunc f) {
-  this->stream << " __kernel ";
+  this->stream << "__kernel ";
   CodeGenC::AddFunction(f);
 }

-void CodeGenOpenCL::PrintThreadIndexExpr(
-    std::string tag, std::ostream& os) {  // NOLINT(*)
-  runtime::ThreadScope ts = runtime::ThreadScope::make(tag);
+void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) {
+  CHECK(!var_idmap_.count(iv->var.get()));
+  runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
+  std::ostringstream os;
   if (ts.rank == 1) {
     os << "get_local_id(" << ts.dim_index << ")";
   } else {
-    os << "get_global_id(" << ts.dim_index << ")"
-       << " / get_local_size(" << ts.dim_index << ")";
+    os << "get_group_id(" << ts.dim_index << ")";
   }
+  var_idmap_[iv->var.get()] = os.str();
 }

 void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const {  // NOLINT(*)

@@ -115,7 +116,9 @@ void CodeGenOpenCL::PrintVecStore(const Variable* buffer,

 void CodeGenOpenCL::PrintStorageSync(const Call* op) {
   const std::string& sync = op->args[0].as<StringImm>()->value;
-  if (sync == "shared") {
+  if (sync == "warp") {
+    LOG(FATAL) << "warp sync not supported in opencl";
+  } else if (sync == "shared") {
     this->PrintIndent();
     this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n";
   } else if (sync == "global") {

@@ -128,8 +131,20 @@ void CodeGenOpenCL::PrintStorageScope(
   if (scope == "global") {
     os << "__global";
   } else if (scope == "shared") {
-    os << "__local ";
+    os << "__local";
   }
 }

+void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) {  // NOLINT(*)
+  std::string v = PrintExpr(op->value);
+  os << '(';
+  PrintType(op->type, os);
+  os << ")(";
+  for (int i = 0; i < op->lanes; ++i) {
+    if (i != 0) os << ", ";
+    os << v;
+  }
+  os << ')';
+}
 }  // namespace codegen
 }  // namespace tvm
@@ -1,7 +1,7 @@
 /*!
  * Copyright (c) 2017 by Contributors
  * \file codegen_opencl.h
- * \brief Utility to generate opencl code
+ * \brief Generate OpenCL device code.
  */
 #ifndef TVM_CODEGEN_CODEGEN_OPENCL_H_
 #define TVM_CODEGEN_CODEGEN_OPENCL_H_

@@ -14,13 +14,12 @@
 namespace tvm {
 namespace codegen {

-class CodeGenOpenCL : public CodeGenC {
+class CodeGenOpenCL final : public CodeGenC {
  public:
   void AddFunction(LoweredFunc f);
   // override print thread tag.
   void InitFuncState(LoweredFunc f) final;
-  void PrintThreadIndexExpr(
-      std::string tag, std::ostream& os) final;  // NOLINT(*)
+  void BindThreadIndex(const IterVar& iv) final;  // NOLINT(*)
   void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)
   void PrintStorageSync(const Call* op) final;  // NOLINT(*)
   void PrintType(Type t, std::ostream& os) const final;  // NOLINT(*)

@@ -32,6 +31,8 @@ class CodeGenOpenCL : public CodeGenC {
   // the address of load/store
   void PrintVecAddr(const Variable* buffer, Type t,
                     Expr base, std::ostream& os);  // NOLINT(*)
+  // overload visitor
+  void VisitExpr_(const Broadcast* op, std::ostream& os) final;  // NOLINT(*)
 };

 }  // namespace codegen
@@ -102,6 +102,12 @@ class CodeGenSourceBase {
   int indent_{0};
 };

+/*!
+ * \brief Create a source module for viewing.
+ * \param code The code to be viewed.
+ * \param fmt The code format.
+ */
+runtime::Module SourceModuleCreate(std::string code, std::string fmt);
 }  // namespace codegen
 }  // namespace tvm
 #endif  // TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
@@ -0,0 +1,52 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file source_module.cc
+ * \brief Source code module, only for viewing
+ */
+#include <tvm/runtime/packed_func.h>
+#include "./codegen_source_base.h"
+
+namespace tvm {
+namespace codegen {
+
+using runtime::TVMArgs;
+using runtime::TVMRetValue;
+using runtime::PackedFunc;
+// Simulator function
+class SourceModuleNode : public runtime::ModuleNode {
+ public:
+  SourceModuleNode(std::string code,
+                   std::string fmt)
+      : code_(code), fmt_(fmt) {}
+  const char* type_key() const {
+    return "source";
+  }
+  void PreCompile(const std::string& name, TVMContext ctx) final {
+  }
+  PackedFunc GetFunction(
+      const std::string& name,
+      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+    LOG(FATAL) << "Source module cannot execute, to get executable module"
+               << " build TVM with \'" << fmt_ << "\' runtime support";
+    return PackedFunc();
+  }
+  void SaveToFile(const std::string& file_name,
+                  const std::string& format) final {
+    LOG(FATAL) << "not implemented";
+  }
+  std::string GetSource(const std::string& format) final {
+    return code_;
+  }
+
+ private:
+  std::string code_;
+  std::string fmt_;
+};
+
+runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
+  std::shared_ptr<SourceModuleNode> n =
+      std::make_shared<SourceModuleNode>(code, fmt);
+  return runtime::Module(n);
+}
+}  // namespace codegen
+}  // namespace tvm
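From Python, a source module produced by a runtime-less build is view-only: fetching the code works, while requesting an executable function hits the LOG(FATAL) above. A sketch, assuming the usual Module wrapper exposes get_source and that flist holds already-lowered functions:

import tvm
from tvm import codegen

m = codegen.build_module(flist, "metal")  # flist: assumed prepared earlier
print(m.get_source())                     # inspect the generated kernel source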
@@ -31,11 +31,7 @@ class VerilogModuleNode : public runtime::ModuleNode {
       const std::string& name,
       const std::shared_ptr<ModuleNode>& sptr_to_self) final {
     CHECK(sptr_to_self.get() == this);
-    if (name == runtime::symbol::tvm_entry_setdevice) {
-      return PackedFunc([](const TVMArgs& args, TVMRetValue* rv){});
-    }
-    CHECK(m_.fmap.count(name)) << "Cannot find function " << name << " in the module";
-
+    if (!m_.fmap.count(name)) return PackedFunc();
     auto f = [sptr_to_self, name, this](const runtime::TVMArgs& args, TVMRetValue* rv) {
       auto* fsim = runtime::Registry::Get("tvm_callback_verilog_simulator");
       CHECK(fsim != nullptr)
@@ -16,7 +16,7 @@ namespace tvm {
 namespace codegen {

 /*! \brief Simulated device ram */
-class VPIDeviceAPI : public runtime::DeviceAPI {
+class VPIDeviceAPI final : public runtime::DeviceAPI {
  public:
  VPIDeviceAPI() {
    static const size_t kAllocAlign = 32U;

@@ -44,6 +44,12 @@ class VPIDeviceAPI : public runtime::DeviceAPI {
    if (ptr + size >= ram_max_) return nullptr;
    return (char*)(&ram_[0]) + ptr;  // NOLINT(*)
  }
+  void SetDevice(int dev_id) final {}
+  void GetAttr(int dev_id, runtime::DeviceAttrKind kind, TVMRetValue* rv) final {
+    if (kind == runtime::kExist) {
+      *rv = 1;
+    }
+  }
  void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final {
    static const size_t kAllocAlign = 32U;
    // always align to 32 bytes at least.

@@ -80,16 +86,18 @@ class VPIDeviceAPI : public runtime::DeviceAPI {
    free_blocks_.insert({b.size, head});
  }
  void CopyDataFromTo(const void* from,
+                     size_t from_offset,
                      void* to,
+                     size_t to_offset,
                      size_t size,
                      TVMContext ctx_from,
                      TVMContext ctx_to,
                      TVMStreamHandle stream) final {
    if (static_cast<int>(ctx_from.device_type) == kVPI) {
-     from = RealAddr(from, size);
+     from = RealAddr(static_cast<const char*>(from) + from_offset, size);
    }
    if (static_cast<int>(ctx_to.device_type) == kVPI) {
-     to = RealAddr(to, size);
+     to = RealAddr(static_cast<char*>(to) + to_offset, size);
    }
    memcpy(to, from, size);
  }
@@ -156,7 +156,7 @@ class ThreadAllreduceBuilder : public IRMutator {
     seq.emplace_back(Store::make(
         shared_buf, value,
         BufIndex(reduce_index, group_index, reduce_extent)));
-    seq.emplace_back(SyncThread());
+    seq.emplace_back(SyncThread("shared"));
     seq.emplace_back(MakeBufAllreduce(
         combiner, value.type(), shared_buf,
         reduce_index, group_index, reduce_extent, threadx_extent));

@@ -202,7 +202,7 @@ class ThreadAllreduceBuilder : public IRMutator {
       reduce_align = reduce_align >> 1;
       Expr cond = reduce_index < (reduce_extent - reduce_align);
       seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
-      seq.emplace_back(SyncThread());
+      seq.emplace_back(SyncThread("shared"));
     }
     CHECK(threadx_extent >= 1 && warp_size_ >= 1);
     // normal synchronization

@@ -211,7 +211,7 @@ class ThreadAllreduceBuilder : public IRMutator {
       reduce_align = reduce_align >> 1;
       Expr cond = reduce_index < reduce_align;
       seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
-      seq.emplace_back(SyncThread());
+      seq.emplace_back(SyncThread("shared"));
     }
     // in warp synchronization.
     std::vector<Stmt> in_warp_seq;

@@ -219,6 +219,7 @@ class ThreadAllreduceBuilder : public IRMutator {
     while (reduce_align > 1) {
       reduce_align = reduce_align >> 1;
       in_warp_seq.emplace_back(freduce(reduce_align));
+      seq.emplace_back(SyncThread("warp"));
     }
     if (in_warp_seq.size() != 0) {
       Stmt warp_body = MergeSeq(in_warp_seq);

@@ -249,10 +250,10 @@ class ThreadAllreduceBuilder : public IRMutator {
     return ret;
   }
   // sync thread op.
-  static Stmt SyncThread() {
+  static Stmt SyncThread(const std::string& sync) {
     return Evaluate::make(
         Call::make(Int(32), intrinsic::tvm_storage_sync,
-                   {StringImm::make("shared")},
+                   {StringImm::make(sync)},
                    Call::Intrinsic));
   }
   // The local buffer index.
@@ -244,6 +244,12 @@ LoweredFunc MakeAPI(Stmt body,
         node, attr::device_context_id, device_id, nop));
     seq_init.push_back(AttrStmt::make(
         node, attr::device_context_type, device_type, nop));
+    Stmt set_device = IfThenElse::make(
+        device_type != kCPU, Evaluate::make(Call::make(
+            Int(32), intrinsic::tvm_call_packed,
+            {StringImm::make(runtime::symbol::tvm_set_device),
+             device_type, device_id}, Call::Intrinsic)));
+    body = Block::make(set_device, body);
   }
   n->body = MergeNest({seq_init, seq_check}, body);
   LoweredFunc f(n);
@@ -162,20 +162,6 @@ class HostDeviceSplitter : public IRMutator {
     std::shared_ptr<LoweredFuncNode> n =
         std::make_shared<LoweredFuncNode>(*f.operator->());
     n->body = this->Mutate(f->body);
-
-    if (f->is_packed_func && device_funcs_.size() != 0) {
-      // insert auto set device from device function.
-      Array<Expr> args = {StringImm::make(runtime::symbol::tvm_entry_setdevice)};
-      for (Var arg : f->args) {
-        args.push_back(arg);
-      }
-      n->body = Block::make(
-          Evaluate::make(Call::make(
-              Int(32), intrinsic::tvm_call_packed,
-              args, Call::Intrinsic)),
-          n->body);
-    }
-
     Array<LoweredFunc> ret{LoweredFunc(n)};
     for (LoweredFunc x : device_funcs_) {
       ret.push_back(x);

@@ -193,14 +179,21 @@ class HostDeviceSplitter : public IRMutator {
     m.visit_thread_extent_ = false;
     n->body = m.Mutate(body);
     n->name = os.str();
-    n->args = m.undefined_;
     n->thread_axis = m.thread_axis_;
-
-    // improve the handle data type
-    for (Var arg : n->args) {
-      auto it = handle_data_type_.find(arg.get());
-      if (it != handle_data_type_.end()) {
-        n->handle_data_type.Set(arg, it->second);
+    // Strictly order the arguments: Var pointers, positional arguments.
+    for (Var v : m.undefined_) {
+      if (v.type().is_handle()) {
+        n->args.push_back(v);
+        // mark handle data type.
+        auto it = handle_data_type_.find(v.get());
+        if (it != handle_data_type_.end()) {
+          n->handle_data_type.Set(v, it->second);
+        }
+      }
+    }
+    for (Var v : m.undefined_) {
+      if (!v.type().is_handle()) {
+        n->args.push_back(v);
       }
     }
     LoweredFunc f_device(n);

@@ -209,7 +202,6 @@ class HostDeviceSplitter : public IRMutator {
     for (Var arg : n->args) {
       call_args.push_back(arg);
     }
-
     for (Expr ext : m.thread_extent_) {
       call_args.push_back(ext);
     }
@@ -54,14 +54,17 @@ class StorageSyncPlanner : public IRVisitor {
     if (!in_device_env_) return;
     if (const Call* call = op->value.as<Call>()) {
       if (call->is_intrinsic(intrinsic::tvm_storage_sync)) {
-        StorageScope scope = StorageScope::make(call->args[0].as<StringImm>()->value);
-        if (scope.rank <= sync_scope_.rank) {
-          CHECK_EQ(curr_stmt_.access.size(), 0U);
-          curr_stmt_.access.emplace_back(
-              AccessEntry(nullptr, Expr(), kSync, scope));
-          // push to the scope
-          scope_.back().push_back(curr_stmt_);
-          curr_stmt_.access.clear();
+        const std::string& s = call->args[0].as<StringImm>()->value;
+        if (s != "warp") {
+          StorageScope scope = StorageScope::make(s);
+          if (scope.rank <= sync_scope_.rank) {
+            CHECK_EQ(curr_stmt_.access.size(), 0U);
+            curr_stmt_.access.emplace_back(
+                AccessEntry(nullptr, Expr(), kSync, scope));
+            // push to the scope
+            scope_.back().push_back(curr_stmt_);
+            curr_stmt_.access.clear();
+          }
         }
       }
     }
@@ -25,8 +25,11 @@ class DeviceAPIManager {
  public:
   static const int kMaxDeviceAPI = 32;
   // Get API
-  static DeviceAPI* Get(TVMContext ctx) {
-    return Global()->GetAPI(ctx.device_type);
+  static DeviceAPI* Get(const TVMContext& ctx) {
+    return Get(ctx.device_type);
+  }
+  static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
+    return Global()->GetAPI(dev_type, allow_missing);
   }

  private:

@@ -42,20 +45,25 @@ class DeviceAPIManager {
     return &inst;
   }
   // Get or initialize API.
-  DeviceAPI* GetAPI(DLDeviceType type) {
-    if (api_[type] != nullptr) return api_[type];
-    std::lock_guard<std::mutex> lock(mutex_);
-    if (api_[type] != nullptr) return api_[type];
-    std::string factory = "device_api." + DeviceName(type);
-    auto* f = Registry::Get(factory);
-    CHECK(f != nullptr)
-        << "Device API " << DeviceName(type) << " is not enabled.";
-    void* ptr = (*f)();
-    api_[type] = static_cast<DeviceAPI*>(ptr);
-    return api_[type];
-  }
+  DeviceAPI* GetAPI(int type, bool allow_missing);
 };

+DeviceAPI* DeviceAPIManager::GetAPI(int type, bool allow_missing) {
+  if (api_[type] != nullptr) return api_[type];
+  std::lock_guard<std::mutex> lock(mutex_);
+  if (api_[type] != nullptr) return api_[type];
+  std::string factory = "device_api." + DeviceName(type);
+  auto* f = Registry::Get(factory);
+  if (f == nullptr) {
+    CHECK(allow_missing)
+        << "Device API " << DeviceName(type) << " is not enabled.";
+    return nullptr;
+  }
+  void* ptr = (*f)();
+  api_[type] = static_cast<DeviceAPI*>(ptr);
+  return api_[type];
+}
+
 inline TVMArray* TVMArrayCreate_() {
   TVMArray* arr = new TVMArray();

@@ -352,8 +360,9 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
       << "Can not copy across different ctx types directly";
   }
   DeviceAPIManager::Get(ctx)->CopyDataFromTo(
-      from->data, to->data, from_size,
-      from->ctx, to->ctx, stream);
+      from->data, from->byte_offset,
+      to->data, to->byte_offset,
+      from_size, from->ctx, to->ctx, stream);
   API_END();
 }
@@ -362,3 +371,29 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
   DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
   API_END();
 }
+
+// set device api
+TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    int dev_type = args[0];
+    int dev_id = args[1];
+    DeviceAPIManager::Get(dev_type)->SetDevice(dev_id);
+  });
+
+// get device attribute api
+TVM_REGISTER_GLOBAL("_GetDeviceAttr")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    int dev_type = args[0];
+    int dev_id = args[1];
+    DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
+    if (kind == kExist) {
+      DeviceAPI* api = DeviceAPIManager::Get(dev_type, true);
+      if (api != nullptr) {
+        api->GetAttr(dev_id, kind, ret);
+      } else {
+        *ret = 0;
+      }
+    } else {
+      DeviceAPIManager::Get(dev_type)->GetAttr(dev_id, kind, ret);
+    }
+  });
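This registration is what makes the no-device case non-fatal: for kExist the manager is queried with allow_missing = true and the attribute simply comes back 0. From Python the probe degrades gracefully:

import tvm

# On a build without the Metal runtime this prints False instead of aborting.
print(tvm.metal(0).exist)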
@@ -11,8 +11,14 @@
 namespace tvm {
 namespace runtime {

-class CPUDeviceAPI : public DeviceAPI {
+class CPUDeviceAPI final : public DeviceAPI {
  public:
+  void SetDevice(int dev_id) final {}
+  void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final {
+    if (kind == kExist) {
+      *rv = 1;
+    }
+  }
   void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final {
     void* ptr;
 #if _MSC_VER

@@ -34,12 +40,16 @@ class CPUDeviceAPI : public DeviceAPI {
   }

   void CopyDataFromTo(const void* from,
+                      size_t from_offset,
                       void* to,
+                      size_t to_offset,
                       size_t size,
                       TVMContext ctx_from,
                       TVMContext ctx_to,
                       TVMStreamHandle stream) final {
-    memcpy(to, from, size);
+    memcpy(static_cast<char*>(to) + to_offset,
+           static_cast<const char*>(from) + from_offset,
+           size);
   }

   void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
@@ -15,8 +15,33 @@
 namespace tvm {
 namespace runtime {

-class CUDADeviceAPI : public DeviceAPI {
+class CUDADeviceAPI final : public DeviceAPI {
  public:
+  void SetDevice(int dev_id) final {
+    CUDA_CALL(cudaSetDevice(dev_id));
+  }
+  void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final {
+    int value;
+    switch (kind) {
+      case kExist:
+        value = (
+            cudaDeviceGetAttribute(
+                &value, cudaDevAttrMaxThreadsPerBlock, dev_id)
+            == cudaSuccess);
+        break;
+      case kMaxThreadsPerBlock: {
+        CUDA_CALL(cudaDeviceGetAttribute(
+            &value, cudaDevAttrMaxThreadsPerBlock, dev_id));
+        break;
+      }
+      case kWarpSize: {
+        CUDA_CALL(cudaDeviceGetAttribute(
+            &value, cudaDevAttrWarpSize, dev_id));
+        break;
+      }
+    }
+    *rv = value;
+  }
   void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final {
     CUDA_CALL(cudaSetDevice(ctx.device_id));
     CHECK_EQ(256 % alignment, 0U)

@@ -32,12 +57,16 @@ class CUDADeviceAPI : public DeviceAPI {
   }

   void CopyDataFromTo(const void* from,
+                      size_t from_offset,
                       void* to,
+                      size_t to_offset,
                       size_t size,
                       TVMContext ctx_from,
                       TVMContext ctx_to,
                       TVMStreamHandle stream) final {
     cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
+    from = static_cast<const char*>(from) + from_offset;
+    to = static_cast<char*>(to) + to_offset;
     if (ctx_from.device_type == kGPU && ctx_to.device_type == kGPU) {
       CUDA_CALL(cudaSetDevice(ctx_from.device_id));
       if (ctx_from.device_id == ctx_to.device_id) {
@@ -14,7 +14,7 @@
#include <string>
#include <mutex>
#include "./cuda_common.h"
#include "../void_addr_args.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"
@@ -214,41 +214,13 @@ class CUDAPrepGlobalBarrier {
  mutable std::array<CUdeviceptr, kMaxNumGPUs> pcache_;
};


void AutoSetCUDADevice(const TVMArgs& args, TVMRetValue* rv) {
  CHECK_EQ(args.size(), 3);
  TVMValue* values = static_cast<TVMValue*>(args[0].operator void*());
  int* type_codes = static_cast<int*>(args[1].operator void*());
  int num_args = args[2].operator int();

  int device_id = -1;
  for (int i = 0; i < num_args; ++i) {
    if (type_codes[i] == kArrayHandle) {
      TVMContext ctx = static_cast<TVMArray*>(values[i].v_handle)->ctx;
      CHECK_EQ(ctx.device_type, kGPU)
          << "All operands need to be GPU";
      if (device_id == -1) {
        device_id = ctx.device_id;
      } else {
        CHECK_EQ(device_id, ctx.device_id)
            << "Operands come from different devices";
      }
    }
  }
  CHECK_NE(device_id, -1)
      << "Cannot detect device id from list";
  CUDA_CALL(cudaSetDevice(device_id));
}

PackedFunc CUDAModuleNode::GetFunction(
    const std::string& name,
    const std::shared_ptr<ModuleNode>& sptr_to_self) {
  CHECK_EQ(sptr_to_self.get(), this);
  CHECK_NE(name, symbol::tvm_module_main)
      << "Device functions do not have main";
  if (name == symbol::tvm_entry_setdevice) {
    return PackedFunc(AutoSetCUDADevice);
  } else if (name == symbol::tvm_prepare_global_barrier) {
  if (name == symbol::tvm_prepare_global_barrier) {
    return PackedFunc(CUDAPrepGlobalBarrier(this, sptr_to_self));
  }
  auto it = fmap_.find(name);
@@ -256,7 +228,7 @@ PackedFunc CUDAModuleNode::GetFunction(
  const FunctionInfo& info = it->second;
  CUDAWrappedFunc f;
  f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
  return PackFromVoidAddrArgs(f, info.arg_types);
  return PackFuncVoidAddr(f, info.arg_types);
}

Module CUDAModuleCreate(
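
With AutoSetCUDADevice removed, device selection no longer goes through a per-module `tvm_entry_setdevice` function; it is routed through the `tvm_set_device` global registered in c_runtime_api.cc above. A hedged sketch of invoking that global directly (symbol and registration taken from this diff):

    #include <tvm/runtime/registry.h>

    // Sketch: select a device through the registered global.
    void SetDeviceViaRegistry(int dev_type, int dev_id) {
      const tvm::runtime::PackedFunc* f =
          tvm::runtime::Registry::Get(tvm::runtime::symbol::tvm_set_device);
      if (f != nullptr) {
        (*f)(dev_type, dev_id);
      }
    }
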
@@ -13,10 +13,29 @@
namespace tvm {
namespace runtime {

enum DeviceAttrKind : int {
  kExist = 0,
  kMaxThreadsPerBlock = 1,
  kWarpSize = 2
};

class DeviceAPI {
 public:
  /*! \brief virtual destructor */
  virtual ~DeviceAPI() {}
  /*!
   * \brief Set the environment device id to dev_id
   * \param dev_id The device id.
   */
  virtual void SetDevice(int dev_id) = 0;
  /*!
   * \brief Get attribute of specified device.
   * \param dev_id The device id
   * \param kind The result kind
   * \param rv The return value.
   */
  virtual void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) = 0;
  /*!
   * \brief Allocate a data space on device.
   * \param ctx The device context to perform operation.
@@ -36,13 +55,18 @@ class DeviceAPI {
   * \brief copy data from one place to another
   * \param from The source array.
   * \param from_offset The byte offset in the source.
   * \param to The target array.
   * \param to_offset The byte offset in the target.
   * \param size The size of the memory
   * \param ctx_from The source context
   * \param ctx_to The target context
   * \param stream Optional stream object.
   */
  virtual void CopyDataFromTo(const void* from,
                              size_t from_offset,
                              void* to,
                              size_t to_offset,
                              size_t size,
                              TVMContext ctx_from,
                              TVMContext ctx_to,
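
Offsets travel separately from the base pointers because several backends (OpenCL `cl_mem`, Metal `id<MTLBuffer>`) hand out opaque handles that cannot be advanced with host-side pointer arithmetic. A small sketch of a sub-buffer copy through this interface, assuming `api` and both contexts were obtained elsewhere:

    // Sketch: copy bytes [128, 384) of src into bytes [0, 256) of dst.
    void CopySlice(tvm::runtime::DeviceAPI* api,
                   const void* src, void* dst,
                   TVMContext ctx_src, TVMContext ctx_dst) {
      api->CopyDataFromTo(src, /*from_offset=*/128,
                          dst, /*to_offset=*/0,
                          /*size=*/256, ctx_src, ctx_dst, /*stream=*/nullptr);
    }
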
@@ -59,11 +83,12 @@ class DeviceAPI {
 * \brief The name of Device API factory.
 * \param type The device type.
 */
inline std::string DeviceName(DLDeviceType type) {
  switch (static_cast<int>(type)) {
inline std::string DeviceName(int type) {
  switch (type) {
    case kCPU: return "cpu";
    case kGPU: return "gpu";
    case kOpenCL: return "opencl";
    case kMetal: return "metal";
    case kVPI: return "vpi";
    default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
  }
@@ -30,7 +30,6 @@ struct FunctionInfo {
  void Save(dmlc::JSONWriter *writer) const;
  void Load(dmlc::JSONReader *reader);
};

}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_META_DATA_H_
@@ -0,0 +1,98 @@
/*!
 *  Copyright (c) 2017 by Contributors
 * \file metal_common.h
 * \brief Metal common header
 */
#ifndef TVM_RUNTIME_METAL_METAL_COMMON_H_
#define TVM_RUNTIME_METAL_METAL_COMMON_H_

#import <Metal/MTLBuffer.h>
#import <Metal/MTLCommandQueue.h>
#import <Metal/MTLCommandBuffer.h>
#import <Metal/MTLBlitCommandEncoder.h>
#import <Metal/MTLDevice.h>
#import <Metal/MTLLibrary.h>

#include <tvm/runtime/config.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <dmlc/logging.h>
#include <mutex>
#include <string>
#include <vector>
#include "../device_api.h"

namespace tvm {
namespace runtime {
namespace metal {
/*!
 * \brief Process global Metal workspace.
 */
class MetalWorkspace final : public DeviceAPI {
 public:
  // the devices
  std::vector<id<MTLDevice> > devices;
  // the queues
  std::vector<id<MTLCommandQueue> > queues;
  // Warp size constant
  std::vector<int> warp_size;
  // Whether it is initialized.
  bool initialized_{false};
  // the mutex for initialization
  std::mutex mutex;
  // Get command queue for given context.
  id<MTLCommandQueue> GetCommandQueue(TVMContext ctx) {
    CHECK_EQ(ctx.device_type, kMetal);
    CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
        << "Invalid Metal device_id=" << ctx.device_id;
    return queues[ctx.device_id];
  }
  // Get device for given context
  id<MTLDevice> GetDevice(TVMContext ctx) {
    CHECK_EQ(ctx.device_type, kMetal);
    CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < devices.size())
        << "Invalid Metal device_id=" << ctx.device_id;
    return devices[ctx.device_id];
  }
  // Lazily initialize the workspace; safe to call more than once.
  void Init();
  // override device API
  void SetDevice(int dev_id) final;
  void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final;
  void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final;
  void FreeDataSpace(TVMContext ctx, void* ptr) final;
  void CopyDataFromTo(const void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
                      size_t size,
                      TVMContext ctx_from,
                      TVMContext ctx_to,
                      TVMStreamHandle stream) final;
  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
  // get the global workspace
  static MetalWorkspace* Global();
};

/*! \brief Thread local workspace */
class MetalThreadEntry {
 public:
  /*! \brief The current context */
  TVMContext context;
  /*! \brief The shared buffer used for copy. */
  std::vector<id<MTLBuffer> > temp_buffer_;

  MetalThreadEntry() {
    context.device_id = 0;
    context.device_type = static_cast<DLDeviceType>(kMetal);
  }
  // Get a temp buffer of at least size bytes under ctx.
  id<MTLBuffer> GetTempBuffer(TVMContext ctx, size_t size);
  // get the thread-local entry
  static MetalThreadEntry* ThreadLocal();
};
}  // namespace metal
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_METAL_METAL_COMMON_H_
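
`MetalThreadEntry` carries the active context per thread, so `SetDevice` is just a thread-local write. A hedged sketch of the intended call sequence against the interface above (sizes and alignment are arbitrary here):

    // Sketch: pick Metal device 0 for this thread, then allocate and free.
    TVMContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(kMetal);
    ctx.device_id = 0;
    tvm::runtime::metal::MetalWorkspace* ws =
        tvm::runtime::metal::MetalWorkspace::Global();
    ws->Init();                                     // lazy, safe to call repeatedly
    ws->SetDevice(ctx.device_id);                   // recorded in MetalThreadEntry
    void* buf = ws->AllocDataSpace(ctx, 4096, 64);  // 4 KB, 64-byte alignment
    ws->FreeDataSpace(ctx, buf);
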
@@ -0,0 +1,240 @@
/*!
 *  Copyright (c) 2017 by Contributors
 * \file metal_device_api.mm
 */
#include "./metal_common.h"

#if TVM_METAL_RUNTIME
#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>

namespace tvm {
namespace runtime {
namespace metal {

MetalWorkspace* MetalWorkspace::Global() {
  static MetalWorkspace inst;
  return &inst;
}

void MetalWorkspace::GetAttr(
    int dev_id, DeviceAttrKind kind, TVMRetValue* rv) {
  this->Init();
  size_t index = static_cast<size_t>(dev_id);
  if (kind == kExist) {
    *rv = static_cast<int>(index < devices.size());
    return;
  }
  CHECK_LT(index, devices.size())
      << "Invalid device id " << index;
  switch (kind) {
    case kMaxThreadsPerBlock: {
      *rv = static_cast<int>(
          [devices[dev_id] maxThreadsPerThreadgroup].width);
      break;
    }
    case kWarpSize: {
      // Set warp size to 1 for safety reasons.
      *rv = 1;
      break;
    }
    case kExist: break;
  }
}

static const char* kDummyKernel = R"A0B0(
using namespace metal;
// Simple copy kernel
// Just to get threadExecutionWidth from current Metal API.
kernel void CopyKernel(
  device float* dst [[buffer(0)]],
  device float* src [[buffer(1)]],
  ushort2 gid[[thread_position_in_grid]]) {
  dst[gid.x] = src[gid.x];
}
)A0B0";

// Hack to get the warp size from the device.
// Note that in Metal, state.threadExecutionWidth can vary per kernel,
// maybe due to resource constraints, so it can be smaller than the
// warp size. To be safe, warp-aware optimization is turned off for
// now, but we keep this code around.
int GetWarpSize(id<MTLDevice> dev) {
  NSError* error_msg = nil;
  id<MTLLibrary> lib =
      [dev
        newLibraryWithSource:
          [NSString stringWithUTF8String:kDummyKernel]
        options:nil
        error:&error_msg];
  CHECK(lib != nil) << error_msg;
  id<MTLFunction> f =
      [lib
        newFunctionWithName:
          [NSString stringWithUTF8String:"CopyKernel"]];
  CHECK(f != nil);
  id<MTLComputePipelineState> state =
      [dev
        newComputePipelineStateWithFunction:f
        error:&error_msg];
  CHECK(state != nil) << error_msg;
  return state.threadExecutionWidth;
}

void MetalWorkspace::Init() {
  if (initialized_) return;
  std::lock_guard<std::mutex> lock(this->mutex);
  if (initialized_) return;
  initialized_ = true;
  if (devices.size() != 0) return;
  NSArray<id<MTLDevice>>* devs = MTLCopyAllDevices();
  for (size_t i = 0; i < devs.count; ++i) {
    id<MTLDevice> d = [devs objectAtIndex:i];
    devices.push_back(d);
    queues.push_back([d newCommandQueue]);
    LOG(INFO) << "Initializing Metal device " << i
              << ", name=" << d.name;
    warp_size.push_back(GetWarpSize(d));
  }
}

void MetalWorkspace::SetDevice(int dev_id) {
  MetalThreadEntry::ThreadLocal()->context.device_id = dev_id;
}

void* MetalWorkspace::AllocDataSpace(
    TVMContext ctx, size_t size, size_t alignment) {
  this->Init();
  id<MTLDevice> dev = GetDevice(ctx);
  // allocate buffer in GPU-only mode.
  id<MTLBuffer> buf = [
      dev newBufferWithLength:size
          options:MTLResourceStorageModePrivate];
  // hand the buffer to the C world with a +1 retain so ARC keeps it
  // alive until FreeDataSpace releases it.
  return (__bridge_retained void*)(buf);
}

void MetalWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
  // give ownership back to ARC, which releases the buffer.
  CFBridgingRelease(ptr);
}

void MetalWorkspace::CopyDataFromTo(const void* from,
                                    size_t from_offset,
                                    void* to,
                                    size_t to_offset,
                                    size_t size,
                                    TVMContext ctx_from,
                                    TVMContext ctx_to,
                                    TVMStreamHandle stream) {
  this->Init();
  CHECK(stream == nullptr);
  TVMContext ctx = ctx_from;
  if (ctx_from.device_type == kCPU) ctx = ctx_to;
  id<MTLCommandQueue> queue = GetCommandQueue(ctx);
  id<MTLCommandBuffer> cb = [queue commandBuffer];
  id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
  int from_dev_type = static_cast<int>(ctx_from.device_type);
  int to_dev_type = static_cast<int>(ctx_to.device_type);

  if (from_dev_type == kMetal && to_dev_type == kMetal) {
    CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
        << "Metal disallows cross device copy.";
    [encoder copyFromBuffer:(__bridge id<MTLBuffer>)(from)
             sourceOffset:from_offset
             toBuffer:(__bridge id<MTLBuffer>)(to)
             destinationOffset:to_offset
             size:size];
    [encoder endEncoding];
    [cb commit];
  } else if (from_dev_type == kMetal && to_dev_type == kCPU) {
    // stage through a shared local buffer before copying into host memory.
    id<MTLBuffer> from_buf = (__bridge id<MTLBuffer>)(from);
    if (from_buf.storageMode != MTLStorageModeShared) {
      id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
          ->GetTempBuffer(ctx_from, size);
      [encoder copyFromBuffer:from_buf
               sourceOffset:from_offset
               toBuffer:temp
               destinationOffset:0
               size:size];
      [encoder endEncoding];
      [cb commit];
      [cb waitUntilCompleted];
      memcpy(static_cast<char*>(to) + to_offset,
             static_cast<char*>([temp contents]),
             size);
    } else {
      memcpy(static_cast<char*>(to) + to_offset,
             static_cast<char*>([from_buf contents]) + from_offset,
             size);
    }
  } else if (from_dev_type == kCPU && to_dev_type == kMetal) {
    id<MTLBuffer> to_buf = (__bridge id<MTLBuffer>)(to);
    if (to_buf.storageMode != MTLStorageModeShared) {
      // stage through a shared buffer: a private buffer's contents
      // cannot be written directly from the host.
      id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
          ->GetTempBuffer(ctx_to, size);
      memcpy([temp contents],
             static_cast<const char*>(from) + from_offset,
             size);
      [encoder copyFromBuffer:temp
               sourceOffset:0
               toBuffer:to_buf
               destinationOffset:to_offset
               size:size];
      [encoder endEncoding];
      [cb commit];
    } else {
      memcpy(static_cast<char*>([to_buf contents]) + to_offset,
             static_cast<const char*>(from) + from_offset,
             size);
    }
  } else {
    LOG(FATAL) << "Expect copy from/to Metal or between Metal"
               << ", from=" << from_dev_type
               << ", to=" << to_dev_type;
  }
}

void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
  CHECK(stream == nullptr);
  // commit an empty command buffer and wait until it completes.
  id<MTLCommandQueue> queue = GetCommandQueue(ctx);
  id<MTLCommandBuffer> cb = [queue commandBuffer];
  [cb commit];
  [cb waitUntilCompleted];
}

id<MTLBuffer> MetalThreadEntry::GetTempBuffer(TVMContext ctx, size_t size) {
  if (temp_buffer_.size() <= static_cast<size_t>(ctx.device_id)) {
    temp_buffer_.resize(ctx.device_id + 1, nil);
  }
  if (temp_buffer_[ctx.device_id] == nil ||
      temp_buffer_[ctx.device_id].length < size) {
    id<MTLDevice> dev = MetalWorkspace::Global()->GetDevice(ctx);
    temp_buffer_[ctx.device_id] = [
        dev newBufferWithLength:size
            options:MTLResourceStorageModeShared];
  }
  return temp_buffer_[ctx.device_id];
}

typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;

MetalThreadEntry* MetalThreadEntry::ThreadLocal() {
  return MetalThreadStore::Get();
}

TVM_REGISTER_GLOBAL("device_api.metal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    DeviceAPI* ptr = MetalWorkspace::Global();
    *rv = static_cast<void*>(ptr);
  });

}  // namespace metal
}  // namespace runtime
}  // namespace tvm

#endif  // TVM_METAL_RUNTIME
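
`Init()` uses double-checked locking so the common already-initialized path stays lock-free; note the guard has to be a named local, otherwise an unnamed temporary would release the mutex within the same statement. The pattern in isolation (a sketch, not TVM code):

    #include <mutex>

    class LazyResource {
     public:
      void Init() {
        if (initialized_) return;               // fast path, no lock taken
        std::lock_guard<std::mutex> lock(mutex_);
        if (initialized_) return;               // re-check under the lock
        // ... expensive one-time setup goes here ...
        initialized_ = true;
      }
     private:
      // a std::atomic<bool> would make the unlocked read strictly race-free
      bool initialized_{false};
      std::mutex mutex_;
    };
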
@@ -0,0 +1,36 @@
/*!
 *  Copyright (c) 2017 by Contributors
 * \file metal_module.h
 * \brief Execution handling of Metal kernels
 */
#ifndef TVM_RUNTIME_METAL_METAL_MODULE_H_
#define TVM_RUNTIME_METAL_METAL_MODULE_H_

#include <tvm/runtime/config.h>
#include <tvm/runtime/packed_func.h>
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "../meta_data.h"

namespace tvm {
namespace runtime {
/*! \brief Maximum number of GPUs supported in MetalModule. */
static constexpr const int kMetalMaxNumDevice = 32;

/*!
 * \brief create a metal module from data.
 *
 * \param data The data content.
 * \param fmt The format of the data, can be "metal" or "metallib"
 * \param fmap The function information map of each function.
 * \param source Optional, source file
 */
Module MetalModuleCreate(
    std::string data,
    std::string fmt,
    std::unordered_map<std::string, FunctionInfo> fmap,
    std::string source);
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_METAL_METAL_MODULE_H_
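
A hedged sketch of how codegen hands compiled Metal source to this entry point; the kernel string and the `FunctionInfo` table are placeholders that the compiler fills in practice:

    // Sketch: wrap Metal source text into a runtime module.
    std::unordered_map<std::string, tvm::runtime::FunctionInfo> fmap;  // from codegen
    std::string src = /* generated Metal source */ "";
    tvm::runtime::Module m =
        tvm::runtime::MetalModuleCreate(src, "metal", fmap, src);
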
@@ -0,0 +1,273 @@
/*!
 *  Copyright (c) 2017 by Contributors
 * \file metal_module.cc
 */
#include "./metal_module.h"

#if TVM_METAL_RUNTIME

#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <array>
#include <string>
#include <mutex>
#include "./metal_common.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"

namespace tvm {
namespace runtime {

// Module to support thread-safe multi-GPU execution.
// The runtime will contain a per-device module table.
// The modules will be lazily loaded.
class MetalModuleNode final : public runtime::ModuleNode {
 public:
  explicit MetalModuleNode(std::string data,
                           std::string fmt,
                           std::unordered_map<std::string, FunctionInfo> fmap,
                           std::string source)
      : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {
  }
  const char* type_key() const final {
    return "metal";
  }

  void PreCompile(const std::string& name, TVMContext ctx) final {
    GetPipelineState(ctx.device_id, name);
  }

  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final;

  void SaveToFile(const std::string& file_name,
                  const std::string& format) final {
    std::string fmt = GetFileFormat(file_name, format);
    CHECK_EQ(fmt, fmt_)
        << "Can only save to format=" << fmt_;
    std::string meta_file = GetMetaFilePath(file_name);
    SaveMetaDataToFile(meta_file, fmap_);
    SaveBinaryToFile(file_name, data_);
  }

  std::string GetSource(const std::string& format) final {
    if (format == fmt_) return data_;
    if (source_.length() != 0) {
      return source_;
    } else if (fmt_ == "metal") {
      return data_;
    } else {
      return "";
    }
  }
  // get a compute pipeline state for the given device id and function name
  id<MTLComputePipelineState> GetPipelineState(
      size_t device_id, const std::string& func_name) {
    metal::MetalWorkspace* w = metal::MetalWorkspace::Global();
    CHECK_LT(device_id, w->devices.size());
    // start lock scope.
    std::lock_guard<std::mutex> lock(mutex_);
    if (finfo_.size() <= device_id) {
      finfo_.resize(device_id + 1, DeviceEntry());
    }
    DeviceEntry& e = finfo_[device_id];
    auto it = e.smap.find(func_name);
    if (it != e.smap.end()) return it->second;
    // compile
    NSError* err_msg = nil;
    if (e.lib == nil) {
      if (fmt_ == "metal") {
        e.lib = [
            w->devices[device_id]
            newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()]
            options:nil
            error:&err_msg];
        if (err_msg != nil || e.lib == nil) {
          LOG(FATAL) << "Fail to compile metal lib:"
                     << [[err_msg localizedDescription] UTF8String];
        }
      } else {
        // Build from the binary library.
        auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL);
        auto data = dispatch_data_create(
            data_.c_str(), data_.length(), q, ^{});
        e.lib = [
            w->devices[device_id]
            newLibraryWithData:data
            error:&err_msg];
        if (err_msg != nil || e.lib == nil) {
          LOG(FATAL) << "Fail to compile metal lib:"
                     << [[err_msg localizedDescription] UTF8String];
        }
      }
    }
    id<MTLFunction> f = [
        e.lib
        newFunctionWithName:
          [NSString stringWithUTF8String:func_name.c_str()]];
    CHECK(f != nil) << "cannot find function " << func_name;
    id<MTLComputePipelineState> state =
        [w->devices[device_id]
          newComputePipelineStateWithFunction:f
          error:&err_msg];
    CHECK(state != nil)
        << "cannot get state for function " << func_name
        << [[err_msg localizedDescription] UTF8String];
    // state.threadExecutionWidth can change dynamically according
    // to the resource constraints in the kernel, so it does not strictly
    // hold. Turn off warp-aware optimization for now.
    // CHECK_EQ(state.threadExecutionWidth, w->warp_size[device_id]);
    e.smap[func_name] = state;
    return state;
  }

 private:
  // device specific entry
  struct DeviceEntry {
    // library
    id<MTLLibrary> lib = nil;
    // state cache;
    std::unordered_map<std::string, id<MTLComputePipelineState> > smap;
  };
  // the binary data
  std::string data_;
  // The format
  std::string fmt_;
  // function information table.
  std::unordered_map<std::string, FunctionInfo> fmap_;
  // The source
  std::string source_;
  // per-device compiled entries.
  std::vector<DeviceEntry> finfo_;
  // internal mutex when updating the module
  std::mutex mutex_;
};

// a wrapped function class to get a packed func.
class MetalWrappedFunc {
 public:
  // initialize the Metal function.
  void Init(MetalModuleNode* m,
            std::shared_ptr<ModuleNode> sptr,
            const std::string& func_name,
            size_t num_buffer_args,
            size_t num_pack_args,
            const std::vector<std::string>& thread_axis_tags) {
    w_ = metal::MetalWorkspace::Global();
    m_ = m;
    sptr_ = sptr;
    func_name_ = func_name;
    num_buffer_args_ = num_buffer_args;
    num_pack_args_ = num_pack_args;
    std::fill(scache_.begin(), scache_.end(), (id<MTLComputePipelineState>)nil);
    thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
    metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
    int dev_id = t->context.device_id;
    scache_[dev_id] = m->GetPipelineState(dev_id, func_name);
  }
  // invoke the function with void arguments
  void operator()(TVMArgs args,
                  TVMRetValue* rv,
                  const ArgUnion* pack_args) const {
    metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
    int device_id = t->context.device_id;
    if (scache_[device_id] == nil) {
      scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
    }
    ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
    id<MTLCommandQueue> queue = w_->GetCommandQueue(t->context);
    id<MTLCommandBuffer> cb = [queue commandBuffer];
    id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
    [encoder setComputePipelineState:scache_[device_id]];
    for (size_t i = 0; i < num_buffer_args_; ++i) {
      void* buf = args[i];
      [encoder setBuffer:(__bridge id<MTLBuffer>)(buf) offset:0 atIndex:i];
    }
    if (num_pack_args_ != 0) {
      [encoder setBytes:pack_args
               length:num_pack_args_ * sizeof(ArgUnion)
               atIndex:num_buffer_args_];
    }
    // launch
    MTLSize dimGrid = MTLSizeMake(
        wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
    MTLSize dimBlock = MTLSizeMake(
        wl.block_dim(0), wl.block_dim(1), wl.block_dim(2));
    [encoder dispatchThreadgroups:dimGrid
             threadsPerThreadgroup:dimBlock];
    [encoder endEncoding];
    [cb commit];
  }

 private:
  // Reference to global workspace.
  metal::MetalWorkspace* w_;
  // internal module
  MetalModuleNode* m_;
  // the resource holder
  std::shared_ptr<ModuleNode> sptr_;
  // The name of the function.
  std::string func_name_;
  // Number of buffer arguments
  size_t num_buffer_args_;
  // number of packed arguments.
  size_t num_pack_args_;
  // Device state cache per device.
  // mark as mutable, to enable lazy initialization
  mutable std::array<id<MTLComputePipelineState>, kMetalMaxNumDevice> scache_;
  // thread axis configuration
  ThreadAxisConfig thread_axis_cfg_;
};

PackedFunc MetalModuleNode::GetFunction(
    const std::string& name,
    const std::shared_ptr<ModuleNode>& sptr_to_self) {
  CHECK_EQ(sptr_to_self.get(), this);
  CHECK_NE(name, symbol::tvm_module_main)
      << "Device functions do not have main";
  auto it = fmap_.find(name);
  if (it == fmap_.end()) return PackedFunc();
  const FunctionInfo& info = it->second;
  MetalWrappedFunc f;
  size_t num_buffer_args = NumBufferArgs(info.arg_types);
  f.Init(this, sptr_to_self, name,
         num_buffer_args, info.arg_types.size() - num_buffer_args,
         info.thread_axis_tags);
  return PackFuncNonBufferArg(f, info.arg_types);
}

Module MetalModuleCreate(
    std::string data,
    std::string fmt,
    std::unordered_map<std::string, FunctionInfo> fmap,
    std::string source) {
  metal::MetalWorkspace* w = metal::MetalWorkspace::Global();
  w->Init();
  std::shared_ptr<MetalModuleNode> n =
      std::make_shared<MetalModuleNode>(data, fmt, fmap, source);
  return Module(n);
}

// Load module from file.
Module MetalModuleLoad(const std::string& file_name,
                       const std::string& format) {
  std::string data;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string fmt = GetFileFormat(file_name, format);
  std::string meta_file = GetMetaFilePath(file_name);
  LoadBinaryFromFile(file_name, &data);
  LoadMetaDataFromFile(meta_file, &fmap);
  return MetalModuleCreate(data, fmt, fmap, "");
}

TVM_REGISTER_GLOBAL("module.loadfile_metal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = MetalModuleLoad(args[0], args[1]);
  });
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_METAL_RUNTIME
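
The launch path above fixes the device-side calling convention: every DLTensor argument gets its own buffer binding via `setBuffer`, and the remaining scalars arrive as one `setBytes` blob of 32-bit `ArgUnion` slots at index `num_buffer_args_`. A hand-written sketch of a Metal kernel matching that convention for `C = A + B * scale + bias` (illustrative only, not actual codegen output):

    #include <metal_stdlib>
    using namespace metal;

    // Two packed 4-byte scalar slots, laid out like consecutive ArgUnion entries.
    struct Packed { float scale; float bias; };

    kernel void myadd(device const float* A [[buffer(0)]],
                      device const float* B [[buffer(1)]],
                      device float* C       [[buffer(2)]],
                      constant Packed& p    [[buffer(3)]],
                      uint gid [[thread_position_in_grid]]) {
      C[gid] = A[gid] + B[gid] * p.scale + p.bias;
    }
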
@@ -84,17 +84,25 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
}

bool RuntimeEnabled(const std::string& target) {
  std::string load_f_name;
  std::string f_name;
  if (target == "cpu") {
    return true;
  } else if (target == "cuda" || target == "gpu") {
    load_f_name = "module.loadfile_ptx";
    f_name = "device_api.gpu";
  } else if (target == "cl" || target == "opencl") {
    load_f_name = "module.loadfile_cl";
    f_name = "device_api.opencl";
  } else if (target == "mtl" || target == "metal") {
    f_name = "device_api.metal";
  } else if (target == "stackvm") {
    f_name = "codegen.build_stackvm";
  } else if (target == "llvm") {
    f_name = "codegen.build_llvm";
  } else if (target == "vpi" || target == "verilog") {
    f_name = "device_api.vpi";
  } else {
    LOG(FATAL) << "Unknown optional runtime " << target;
  }
  return runtime::Registry::Get(load_f_name) != nullptr;
  return runtime::Registry::Get(f_name) != nullptr;
}

TVM_REGISTER_GLOBAL("module._Enabled")
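
`RuntimeEnabled` decides availability purely by whether the backend registered its `device_api.*` (or `codegen.build_*`) global; this is what `tvm.module.enabled("metal")` reaches from Python. The same check from C++, as a sketch:

    #include <tvm/runtime/registry.h>

    bool HasMetalRuntime() {
      // present only when the Metal runtime was compiled in and
      // registered "device_api.metal" at static-initialization time
      return tvm::runtime::Registry::Get("device_api.metal") != nullptr;
    }
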
@@ -101,12 +101,14 @@ inline const char* CLGetErrorString(cl_int error) {
/*!
 * \brief Process global OpenCL workspace.
 */
class OpenCLWorkspace : public DeviceAPI {
class OpenCLWorkspace final : public DeviceAPI {
 public:
  // global platform id
  cl_platform_id platform_id;
  // global context of this process
  cl_context context{nullptr};
  // whether the workspace is initialized.
  bool initialized_{false};
  // the devices
  std::vector<cl_device_id> devices;
  // the queues
@@ -126,24 +128,25 @@ class OpenCLWorkspace : public DeviceAPI {
      OPENCL_CALL(clReleaseContext(context));
    }
  }
  // whether the workspace is initialized.
  inline bool initialized() const {
    return context != nullptr;
  }
  // Initialize the device.
  void Init();
  // get the queue of the context
  cl_command_queue GetQueue(TVMContext ctx) const {
  cl_command_queue GetQueue(TVMContext ctx) {
    CHECK_EQ(ctx.device_type, kOpenCL);
    CHECK(initialized())
        << "The OpenCL workspace is not initialized";
    this->Init();
    CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
        << "Invalid OpenCL device_id=" << ctx.device_id;
    return queues[ctx.device_id];
  }
  // override device API
  void SetDevice(int dev_id) final;
  void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final;
  void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final;
  void FreeDataSpace(TVMContext ctx, void* ptr) final;
  void CopyDataFromTo(const void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
                      size_t size,
                      TVMContext ctx_from,
                      TVMContext ctx_to,
@@ -18,8 +18,41 @@ OpenCLWorkspace* OpenCLWorkspace::Global() {
  return &inst;
}

void OpenCLWorkspace::SetDevice(int dev_id) {
  OpenCLThreadEntry::ThreadLocal()->context.device_id = dev_id;
}

void OpenCLWorkspace::GetAttr(
    int dev_id, DeviceAttrKind kind, TVMRetValue* rv) {
  this->Init();
  size_t index = static_cast<size_t>(dev_id);
  if (kind == kExist) {
    *rv = static_cast<int>(index < devices.size());
    return;
  }
  CHECK_LT(index, devices.size())
      << "Invalid device id " << index;
  size_t value;
  switch (kind) {
    case kMaxThreadsPerBlock: {
      OPENCL_CALL(clGetDeviceInfo(
          devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE,
          sizeof(size_t), &value, nullptr));
      *rv = static_cast<int64_t>(value);
      break;
    }
    case kWarpSize: {
      *rv = 1;
      break;
    }
    case kExist: break;
  }
}

void* OpenCLWorkspace::AllocDataSpace(
    TVMContext ctx, size_t size, size_t alignment) {
  this->Init();
  CHECK(context != nullptr) << "No OpenCL device";
  cl_int err_code;
  cl_mem mptr = clCreateBuffer(
      this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code);
@@ -33,30 +66,35 @@ void OpenCLWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
}

void OpenCLWorkspace::CopyDataFromTo(const void* from,
                                     size_t from_offset,
                                     void* to,
                                     size_t to_offset,
                                     size_t size,
                                     TVMContext ctx_from,
                                     TVMContext ctx_to,
                                     TVMStreamHandle stream) {
  this->Init();
  CHECK(stream == nullptr);
  if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kOpenCL) {
    OPENCL_CALL(clEnqueueCopyBuffer(
        this->GetQueue(ctx_to),
        static_cast<cl_mem>((void*)from),  // NOLINT(*)
        static_cast<cl_mem>(to),
        0, 0, size, 0, nullptr, nullptr));
        from_offset, to_offset, size, 0, nullptr, nullptr));
  } else if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kCPU) {
    OPENCL_CALL(clEnqueueReadBuffer(
        this->GetQueue(ctx_from),
        static_cast<cl_mem>((void*)from),  // NOLINT(*)
        CL_FALSE, 0, size, to,
        CL_FALSE, from_offset, size,
        static_cast<char*>(to) + to_offset,
        0, nullptr, nullptr));
    OPENCL_CALL(clFinish(this->GetQueue(ctx_from)));
  } else if (ctx_from.device_type == kCPU && ctx_to.device_type == kOpenCL) {
    OPENCL_CALL(clEnqueueWriteBuffer(
        this->GetQueue(ctx_to),
        static_cast<cl_mem>(to),
        CL_FALSE, 0, size, from,
        CL_FALSE, to_offset, size,
        static_cast<const char*>(from) + from_offset,
        0, nullptr, nullptr));
    OPENCL_CALL(clFinish(this->GetQueue(ctx_to)));
  } else {
@@ -97,8 +135,9 @@ std::string GetDeviceInfo(

std::vector<cl_platform_id> GetPlatformIDs() {
  cl_uint ret_size;
  OPENCL_CALL(clGetPlatformIDs(0, nullptr, &ret_size));
  cl_int code = clGetPlatformIDs(0, nullptr, &ret_size);
  std::vector<cl_platform_id> ret;
  if (code != CL_SUCCESS) return ret;
  ret.resize(ret_size);
  OPENCL_CALL(clGetPlatformIDs(ret_size, &ret[0], nullptr));
  return ret;
@@ -108,11 +147,12 @@ std::vector<cl_device_id> GetDeviceIDs(
    cl_platform_id pid, std::string device_type) {
  cl_device_type dtype = CL_DEVICE_TYPE_ALL;
  if (device_type == "cpu") dtype = CL_DEVICE_TYPE_CPU;
  if (device_type == "gpu") dtype = CL_DEVICE_TYPE_CPU;
  if (device_type == "gpu") dtype = CL_DEVICE_TYPE_GPU;
  if (device_type == "accelerator") dtype = CL_DEVICE_TYPE_ACCELERATOR;
  cl_uint ret_size;
  OPENCL_CALL(clGetDeviceIDs(pid, dtype, 0, nullptr, &ret_size));
  cl_int code = clGetDeviceIDs(pid, dtype, 0, nullptr, &ret_size);
  std::vector<cl_device_id> ret;
  if (code != CL_SUCCESS) return ret;
  ret.resize(ret_size);
  OPENCL_CALL(clGetDeviceIDs(pid, dtype, ret_size, &ret[0], nullptr));
  return ret;
@@ -127,70 +167,53 @@ bool MatchPlatformInfo(
  return param_value.find(value) != std::string::npos;
}

bool InitOpenCL(TVMArgs args, TVMRetValue* rv) {
  cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
  std::lock_guard<std::mutex> lock(w->mu);
  if (w->initialized()) return false;
  // matching conditions
  std::string platform_name, device_type;

  for (int i = 0; i < args.num_args; ++i) {
    std::string arg = args[i];
    size_t pos = arg.find_first_of('=');
    CHECK_NE(pos, std::string::npos)
        << "Arguments need to be key=value";
    std::string key = arg.substr(0, pos);
    std::string val = arg.substr(pos + 1, arg.length() - pos - 1);
    if (key == "platform_name") {
      platform_name = val;
    } else if (key == "device_type") {
      device_type = val;
    } else {
      LOG(FATAL) << "unknown DeviceInit option " << key;
    }
  }
void OpenCLWorkspace::Init() {
  if (initialized_) return;
  std::lock_guard<std::mutex> lock(this->mu);
  if (initialized_) return;
  initialized_ = true;
  if (context != nullptr) return;
  // matched platforms
  std::vector<cl_platform_id> platform_matched;
  for (cl_platform_id pid : cl::GetPlatformIDs()) {
    bool matched = true;
    if (!cl::MatchPlatformInfo(pid, CL_PLATFORM_NAME, platform_name)) matched = false;
    if (matched) platform_matched.push_back(pid);
  }
  std::vector<cl_platform_id> platform_matched = cl::GetPlatformIDs();
  if (platform_matched.size() == 0) {
    LOG(FATAL) << "No OpenCL platform matched given existing options ...";
    LOG(WARNING) << "No OpenCL platform matched given existing options ...";
    return;
  }
  if (platform_matched.size() > 1) {
    LOG(WARNING) << "Multiple OpenCL platforms matched, use the first one ...";
  }
  w->platform_id = platform_matched[0];

  this->platform_id = platform_matched[0];
  LOG(INFO) << "Initialize OpenCL platform \'"
            << cl::GetPlatformInfo(w->platform_id, CL_PLATFORM_NAME) << '\'';
            << cl::GetPlatformInfo(this->platform_id, CL_PLATFORM_NAME) << '\'';
  std::vector<cl_device_id> devices_matched =
      cl::GetDeviceIDs(w->platform_id, device_type);
  CHECK_GT(devices_matched.size(), 0U)
      << "No OpenCL device matched given the options";
  w->devices = devices_matched;
      cl::GetDeviceIDs(this->platform_id, "gpu");
  if (devices_matched.size() == 0) {
    LOG(WARNING) << "No OpenCL device matched given the options";
    return;
  }
  this->devices = devices_matched;
  cl_int err_code;
  w->context = clCreateContext(
      nullptr, w->devices.size(), &(w->devices[0]),
  this->context = clCreateContext(
      nullptr, this->devices.size(), &(this->devices[0]),
      nullptr, nullptr, &err_code);
  OPENCL_CHECK_ERROR(err_code);
  CHECK_EQ(w->queues.size(), 0U);
  for (size_t i = 0; i < w->devices.size(); ++i) {
    cl_device_id did = w->devices[i];
    w->queues.push_back(
        clCreateCommandQueue(w->context, did, 0, &err_code));
  CHECK_EQ(this->queues.size(), 0U);
  for (size_t i = 0; i < this->devices.size(); ++i) {
    cl_device_id did = this->devices[i];
    this->queues.push_back(
        clCreateCommandQueue(this->context, did, 0, &err_code));
    OPENCL_CHECK_ERROR(err_code);
    LOG(INFO) << "opencl(" << i
              << ")=\'" << cl::GetDeviceInfo(did, CL_DEVICE_NAME)
              << "\' cl_device_id=" << did;
  }
  return true;
}

TVM_REGISTER_GLOBAL("module.init_opencl")
.set_body(InitOpenCL);
bool InitOpenCL(TVMArgs args, TVMRetValue* rv) {
  cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
  w->Init();
  return true;
}

TVM_REGISTER_GLOBAL("device_api.opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
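
This hunk is the "no device available" fix from the commit message: `Init()` now warns and returns instead of aborting, so a `kExist` query reports absence rather than killing the process. A sketch of the resulting safe probe:

    // Sketch: probing OpenCL availability no longer aborts when nothing matched.
    cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
    w->Init();                        // logs a warning if no platform/device
    TVMRetValue rv;
    w->GetAttr(0, tvm::runtime::kExist, &rv);
    bool have_device = (rv.operator int() != 0);
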
@@ -11,7 +11,7 @@
#include <vector>
#include <string>
#include <unordered_map>
#include "../void_addr_args.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"
@@ -90,7 +90,8 @@ class OpenCLModuleNode : public ModuleNode {
  // Initialize the programs
  void InitProgram() {
    cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
    CHECK(w->initialized());
    w->Init();
    CHECK(w->context != nullptr) << "No OpenCL device";
    if (fmt_ == "cl") {
      const char* s = data_.c_str();
      size_t len = data_.length();
@@ -179,6 +180,7 @@ class OpenCLWrappedFunc {
            std::string func_name,
            std::vector<size_t> arg_size,
            const std::vector<std::string>& thread_axis_tags) {
    w_ = cl::OpenCLWorkspace::Global();
    m_ = m;
    sptr_ = sptr;
    entry_ = entry;
@@ -190,9 +192,7 @@ class OpenCLWrappedFunc {
  void operator()(TVMArgs args,
                  TVMRetValue* rv,
                  void** void_args) const {
    cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
    cl::OpenCLThreadEntry* t = cl::OpenCLThreadEntry::ThreadLocal();
    CHECK(w->initialized());
    // get the kernel from thread local kernel table.
    if (entry_.kernel_id >= t->kernel_table.size()) {
      t->kernel_table.resize(entry_.kernel_id + 1);
@@ -200,13 +200,13 @@ class OpenCLWrappedFunc {
    const auto& e = t->kernel_table[entry_.kernel_id];
    cl_kernel kernel = e.kernel;
    if (kernel == nullptr || e.version != entry_.version) {
      kernel = m_->InstallKernel(w, t, func_name_, entry_);
      kernel = m_->InstallKernel(w_, t, func_name_, entry_);
    }
    // setup arguments.
    for (cl_uint i = 0; i < arg_size_.size(); ++i) {
      OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], void_args[i]));
    }
    cl_command_queue queue = w->GetQueue(t->context);
    cl_command_queue queue = w_->GetQueue(t->context);
    ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
    cl_uint work_dim = static_cast<cl_uint>(thread_axis_cfg_.work_dim());
    for (cl_uint i = 0; i < work_dim; ++i) {
@@ -221,6 +221,8 @@ class OpenCLWrappedFunc {
  }

 private:
  // global workspace.
  cl::OpenCLWorkspace* w_;
  // The module
  OpenCLModuleNode* m_;
  // resource handle
@@ -235,45 +237,12 @@ class OpenCLWrappedFunc {
  ThreadAxisConfig thread_axis_cfg_;
};

/*!
 * \brief Automatically detect and set the OpenCL device.
 * \param args The arguments.
 */
void AutoSetOpenCLDevice(const TVMArgs& args, TVMRetValue* rv) {
  CHECK_EQ(args.size(), 3);
  TVMValue* values = static_cast<TVMValue*>(args[0].operator void*());
  int* type_codes = static_cast<int*>(args[1].operator void*());
  int num_args = args[2].operator int();

  // TODO(tqchen): merge this with CUDA logic.
  int device_id = -1;
  for (int i = 0; i < num_args; ++i) {
    if (type_codes[i] == kArrayHandle) {
      TVMContext ctx = static_cast<TVMArray*>(values[i].v_handle)->ctx;
      CHECK_EQ(ctx.device_type, kOpenCL)
          << "All operands need to be OpenCL";
      if (device_id == -1) {
        device_id = ctx.device_id;
      } else {
        CHECK_EQ(device_id, ctx.device_id)
            << "Operands come from different devices";
      }
    }
  }
  CHECK_NE(device_id, -1)
      << "Cannot detect device id from list";
  cl::OpenCLThreadEntry::ThreadLocal()->context.device_id = device_id;
}

PackedFunc OpenCLModuleNode::GetFunction(
    const std::string& name,
    const std::shared_ptr<ModuleNode>& sptr_to_self) {
  CHECK_EQ(sptr_to_self.get(), this);
  CHECK_NE(name, symbol::tvm_module_main)
      << "Device functions do not have main";
  if (name == symbol::tvm_entry_setdevice) {
    return PackedFunc(AutoSetOpenCLDevice);
  }
  auto it = fmap_.find(name);
  if (it == fmap_.end()) return PackedFunc();
  const FunctionInfo& info = it->second;
@@ -289,7 +258,7 @@ PackedFunc OpenCLModuleNode::GetFunction(
  // initialize the wrapped func.
  f.Init(this, sptr_to_self, kid_map_.at(name),
         name, arg_size, info.thread_axis_tags);
  return PackFromVoidAddrArgs(f, info.arg_types);
  return PackFuncVoidAddr(f, info.arg_types);
}

Module OpenCLModuleCreate(
@@ -18,7 +18,7 @@ namespace runtime {
/*!
 * \brief create an OpenCL module from data.
 *
 * \param data The module data, can be ptx, cubin
 * \param data The module data.
 * \param fmt The format of the data, can be "clbin", "cl"
 * \param fmap The function information map of each function.
 */
@@ -0,0 +1,233 @@
/*!
 *  Copyright (c) 2017 by Contributors
 * \file pack_args.h
 * \brief Utility to pack TVMArgs to other type-erased function calling conventions.
 *
 * Two type-erased function signatures are supported.
 * - cuda_style(void** args, int num_args);
 *   - Pack everything by address
 * - metal_style(void** buffers, int num_buffers,
 *               union_32bit args[N], int num_args);
 *   - Pack buffers by address, pack the remaining parameters into a 32-bit union buffer.
 */
#ifndef TVM_RUNTIME_PACK_ARGS_H_
#define TVM_RUNTIME_PACK_ARGS_H_

#include <tvm/runtime/c_runtime_api.h>
#include <vector>

namespace tvm {
namespace runtime {
/*!
 * \brief argument union type of 32bit.
 * Choose 32 bit because most GPU APIs do not work well with 64 bit.
 */
union ArgUnion {
  int32_t v_int32;
  uint32_t v_uint32;
  float v_float32;
};
/*!
 * \brief Create a packed function from void addr types.
 *
 * \param f with signature (TVMArgs args, TVMRetValue* rv, void* void_args)
 * \param arg_types The argument types of the function
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template<typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
/*!
 * \brief Create a packed function that only packs the non-buffer arguments.
 *
 * \param f with signature (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args)
 * \param arg_types The argument types of the function
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template<typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types);
/*!
 * \brief Extract the number of buffer arguments from the argument types.
 * \param arg_types The argument types.
 * \return number of buffer arguments
 */
inline size_t NumBufferArgs(const std::vector<TVMType>& arg_types);

// implementation details
namespace detail {
template<typename T, int kSize>
class TempArray {
 public:
  explicit TempArray(int size) {}
  T* data() {
    return data_;
  }
 private:
  T data_[kSize];
};
template<typename T>
class TempArray<T, 0> {
 public:
  explicit TempArray(int size) : data_(size) {}
  T* data() {
    return data_.data();
  }
 private:
  std::vector<T> data_;
};

/*! \brief conversion code used in void arg. */
enum ArgConvertCode {
  INT64_TO_INT64,
  INT64_TO_INT32,
  INT64_TO_UINT32,
  FLOAT64_TO_FLOAT32,
  FLOAT64_TO_FLOAT64,
  HANDLE_TO_HANDLE
};

inline ArgConvertCode GetArgConvertCode(TVMType t) {
  CHECK_EQ(t.lanes, 1U)
      << "Cannot pass vector type argument to device function for now";
  if (t.code == kInt) {
    if (t.bits == 64U) return INT64_TO_INT64;
    if (t.bits == 32U) return INT64_TO_INT32;
  } else if (t.code == kUInt) {
    if (t.bits == 32U) return INT64_TO_UINT32;
  } else if (t.code == kFloat) {
    if (t.bits == 64U) return FLOAT64_TO_FLOAT64;
    if (t.bits == 32U) return FLOAT64_TO_FLOAT32;
  } else if (t.code == kHandle) {
    return HANDLE_TO_HANDLE;
  }
  LOG(FATAL) << "Cannot handle " << t << " as device function argument";
  return HANDLE_TO_HANDLE;
}

template<int N, typename F>
inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& codes) {
  int num_args = static_cast<int>(codes.size());
  auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
    TempArray<void*, N> addr_(num_args);
    TempArray<ArgUnion, N> holder_(num_args);
    void** addr = addr_.data();
    ArgUnion* holder = holder_.data();
    for (int i = 0; i < num_args; ++i) {
      switch (codes[i]) {
        case INT64_TO_INT64:
        case FLOAT64_TO_FLOAT64:
        case HANDLE_TO_HANDLE: {
          addr[i] = (void*)&(args.values[i]);  // NOLINT(*)
          break;
        }
        case INT64_TO_INT32: {
          holder[i].v_int32 = static_cast<int32_t>(args.values[i].v_int64);
          addr[i] = &(holder[i]);
          break;
        }
        case INT64_TO_UINT32: {
          holder[i].v_uint32 = static_cast<uint32_t>(args.values[i].v_int64);
          addr[i] = &(holder[i]);
          break;
        }
        case FLOAT64_TO_FLOAT32: {
          holder[i].v_float32 = static_cast<float>(args.values[i].v_float64);
          addr[i] = &(holder[i]);
          break;
        }
      }
    }
    f(args, ret, addr);
  };
  return PackedFunc(ret);
}

template<int N, typename F>
inline PackedFunc PackFuncNonBufferArg_(
    F f, int base, const std::vector<ArgConvertCode>& codes) {
  int num_args = static_cast<int>(codes.size());
  auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) {
    TempArray<ArgUnion, N> holder_(num_args);
    ArgUnion* holder = holder_.data();
    for (int i = 0; i < num_args; ++i) {
      switch (codes[i]) {
        case INT64_TO_INT64:
        case FLOAT64_TO_FLOAT64: {
          LOG(FATAL) << "Do not support 64bit argument to device function"; break;
        }
        case INT64_TO_INT32: {
          holder[i].v_int32 = static_cast<int32_t>(args.values[base + i].v_int64);
          break;
        }
        case INT64_TO_UINT32: {
          holder[i].v_uint32 = static_cast<uint32_t>(args.values[base + i].v_int64);
          break;
        }
        case FLOAT64_TO_FLOAT32: {
          holder[i].v_float32 = static_cast<float>(args.values[base + i].v_float64);
          break;
        }
        case HANDLE_TO_HANDLE: {
          LOG(FATAL) << "not reached"; break;
        }
      }
    }
    f(args, ret, holder);
  };
  return PackedFunc(ret);
}
}  // namespace detail

template<typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types) {
  std::vector<detail::ArgConvertCode> codes(arg_types.size());
  for (size_t i = 0; i < arg_types.size(); ++i) {
    codes[i] = detail::GetArgConvertCode(arg_types[i]);
  }
  size_t num_void_args = arg_types.size();
  // specialization
  if (num_void_args <= 4) {
    return detail::PackFuncVoidAddr_<4>(f, codes);
  } else if (num_void_args <= 8) {
    return detail::PackFuncVoidAddr_<8>(f, codes);
  } else {
    return detail::PackFuncVoidAddr_<0>(f, codes);
  }
}

inline size_t NumBufferArgs(const std::vector<TVMType>& arg_types) {
  size_t base = arg_types.size();
  for (size_t i = 0; i < arg_types.size(); ++i) {
    if (arg_types[i].code != kHandle) {
      base = i; break;
    }
  }
  for (size_t i = base; i < arg_types.size(); ++i) {
    CHECK(arg_types[i].code != kHandle)
        << "Device function buffer arguments need to come first";
  }
  return base;
}

template<typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types) {
  size_t num_buffer = NumBufferArgs(arg_types);
  std::vector<detail::ArgConvertCode> codes;
  for (size_t i = num_buffer; i < arg_types.size(); ++i) {
    codes.push_back(detail::GetArgConvertCode(arg_types[i]));
  }
  int base = static_cast<int>(num_buffer);
  size_t nargs = codes.size();
  // specialization
  if (nargs <= 4) {
    return detail::PackFuncNonBufferArg_<4>(f, base, codes);
  } else {
    return detail::PackFuncNonBufferArg_<0>(f, base, codes);
  }
}
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_PACK_ARGS_H_
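
For a device function whose signature is (buffer A, buffer B, buffer C, float scale, float bias), `NumBufferArgs` returns 3 and `PackFuncNonBufferArg` converts the two floats into `ArgUnion` slots, which is exactly what `MetalWrappedFunc` forwards via `setBytes`. A small sketch (field layout assumed from `TVMType` in c_runtime_api.h):

    // Sketch: split a signature into buffer args and packed scalar args.
    TVMType Handle64() { TVMType t; t.code = kHandle; t.bits = 64; t.lanes = 1; return t; }
    TVMType Float32()  { TVMType t; t.code = kFloat;  t.bits = 32; t.lanes = 1; return t; }

    std::vector<TVMType> sig = {
        Handle64(), Handle64(), Handle64(),  // A, B, C: device buffers
        Float32(), Float32()};               // scale, bias -> ArgUnion slots
    size_t num_buffer = tvm::runtime::NumBufferArgs(sig);  // == 3
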
@ -1,164 +0,0 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file void_addr_args.h
|
||||
* \brief Utility to convert TVMArgs to void* array type-erasure function call.
|
||||
*
|
||||
* Array of argument address is a typical way of type-erasure for functions.
|
||||
* The function signiture looks like function(void** args, int num_args);
|
||||
* Where args takes the address of each input.
|
||||
*/
|
||||
#ifndef TVM_RUNTIME_VOID_ADDR_ARGS_H_
|
||||
#define TVM_RUNTIME_VOID_ADDR_ARGS_H_
|
||||
|
||||
#include <tvm/runtime/c_runtime_api.h>
|
||||
#include <vector>
|
||||
|
||||
namespace tvm {
|
||||
namespace runtime {
|
||||
|
||||
/*!
|
||||
* \brief Create a packed function from void addr types
|
||||
* \param f with signiture (TVMArgs args, TVMRetValue* rv, void* void_args)
|
||||
* \param arg_types The arguments that wish to get from
|
||||
* \tparam T the function type
|
||||
*
|
||||
* \return The wrapped packed function.
|
||||
*/
|
||||
template<typename F>
|
||||
inline PackedFunc PackFromVoidAddrArgs(
|
||||
    F f, const std::vector<TVMType>& arg_types);

// implementation details
namespace detail {
/*!
 * \brief Holder for void-addr argument data content,
 *  used when a narrowing conversion is needed.
 */
union VoidArgHolder {
  int32_t v_int32;
  uint32_t v_uint32;
  float v_float32;
};

template<int MAX_NARG>
class VoidAddrArray {
 public:
  // Fixed-capacity variant: storage lives on the stack, no per-call allocation.
  explicit VoidAddrArray(int num_args) {
  }
  void** addr() {
    return addr_;
  }
  VoidArgHolder* holder() {
    return holder_;
  }

 private:
  void* addr_[MAX_NARG];
  VoidArgHolder holder_[MAX_NARG];
};

template<>
class VoidAddrArray<0> {
 public:
  // Fallback for long signatures: heap-allocated storage sized at runtime.
  explicit VoidAddrArray(int num_args)
      : addr_(num_args), holder_(num_args) {
  }
  void** addr() {
    return addr_.data();
  }
  VoidArgHolder* holder() {
    return holder_.data();
  }

 private:
  std::vector<void*> addr_;
  std::vector<VoidArgHolder> holder_;
};

/*! \brief conversion code used in void arg. */
enum VoidArgConvertCode {
  INT64_TO_INT64,
  INT64_TO_INT32,
  INT64_TO_UINT32,
  FLOAT64_TO_FLOAT32,
  FLOAT64_TO_FLOAT64,
  HANDLE_TO_HANDLE
};

template<int N, typename F>
inline PackedFunc PackFromVoidAddrArgs_(
    F f, const std::vector<VoidArgConvertCode>& codes) {
  int num_args = static_cast<int>(codes.size());
  auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
    VoidAddrArray<N> temp(num_args);
    void** addr = temp.addr();
    VoidArgHolder* holder = temp.holder();
    for (int i = 0; i < num_args; ++i) {
      switch (codes[i]) {
        case INT64_TO_INT64:
        case FLOAT64_TO_FLOAT64:
        case HANDLE_TO_HANDLE: {
          // No conversion needed: point directly at the 64-bit TVMValue slot.
          addr[i] = (void*)&(args.values[i]);  // NOLINT(*)
          break;
        }
        case INT64_TO_INT32: {
          holder[i].v_int32 = static_cast<int32_t>(args.values[i].v_int64);
          addr[i] = &(holder[i]);
          break;
        }
        case INT64_TO_UINT32: {
          holder[i].v_uint32 = static_cast<uint32_t>(args.values[i].v_int64);
          addr[i] = &(holder[i]);
          break;
        }
        case FLOAT64_TO_FLOAT32: {
          holder[i].v_float32 = static_cast<float>(args.values[i].v_float64);
          addr[i] = &(holder[i]);
          break;
        }
      }
    }
    f(args, ret, addr);
  };
  return PackedFunc(ret);
}

inline VoidArgConvertCode GetVoidArgConvertCode(TVMType t) {
  CHECK_EQ(t.lanes, 1U);
  if (t.code == kInt) {
    if (t.bits == 64U) return INT64_TO_INT64;
    if (t.bits == 32U) return INT64_TO_INT32;
  } else if (t.code == kUInt) {
    if (t.bits == 32U) return INT64_TO_UINT32;
  } else if (t.code == kFloat) {
    if (t.bits == 64U) return FLOAT64_TO_FLOAT64;
    if (t.bits == 32U) return FLOAT64_TO_FLOAT32;
  } else if (t.code == kHandle) {
    return HANDLE_TO_HANDLE;
  }
  LOG(FATAL) << "Cannot handle " << t;
  return HANDLE_TO_HANDLE;
}
}  // namespace detail

template<typename F>
inline PackedFunc PackFromVoidAddrArgs(
    F f, const std::vector<TVMType>& arg_types) {
  std::vector<detail::VoidArgConvertCode> codes(arg_types.size());
  for (size_t i = 0; i < arg_types.size(); ++i) {
    codes[i] = detail::GetVoidArgConvertCode(arg_types[i]);
  }
  size_t num_void_args = arg_types.size();
  // specialization for short signatures to avoid heap allocation
  if (num_void_args <= 4) {
    return detail::PackFromVoidAddrArgs_<4>(f, codes);
  } else if (num_void_args <= 8) {
    return detail::PackFromVoidAddrArgs_<8>(f, codes);
  } else {
    return detail::PackFromVoidAddrArgs_<0>(f, codes);
  }
}
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_VOID_ADDR_ARGS_H_
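The conversion table above is easiest to see from the user side: every scalar crosses the PackedFunc boundary as a 64-bit TVMValue (int64 or float64), and the wrapper narrows it to the width the kernel signature declares. A hedged, self-contained Python sketch of that behaviour, assuming an LLVM-enabled build (the names `scale`, `f`, `a`, `b` are illustrative; the pattern mirrors the bias/scale change to test_add below):

import tvm
import numpy as np

n = 16
A = tvm.placeholder((n,), name='A')
scale = tvm.var("scale", dtype="float32")   # kernel declares a float32 scalar
B = tvm.compute(A.shape, lambda *i: A(*i) * scale, name='B')
s = tvm.create_schedule(B.op)

if tvm.module.enabled("llvm"):
    f = tvm.build(s, [A, B, scale], "llvm")
    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.ones(n, dtype=A.dtype), ctx)
    b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
    # 0.5 travels as a float64 TVMValue and is narrowed to float32, the same
    # conversion the FLOAT64_TO_FLOAT32 case performs for launcher-based backends.
    f(a, b, 0.5)
    np.testing.assert_allclose(b.asnumpy(), 0.5 * np.ones(n))
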
@@ -17,6 +17,7 @@ def lower(s, args, name="mydot"):
     stmt = tvm.ir_pass.CanonicalSimplify(stmt)
     stmt = tvm.ir_pass.Simplify(stmt)
     fapi = tvm.ir_pass.MakeAPI(stmt, name, arg_list, 0)
+    fapi = tvm.ir_pass.LowerPackedCall(fapi)
     return fapi

@@ -35,7 +36,7 @@ def test_dot():
     fapi = lower(s, [A, B, C])

     def verify(target):
-        if not tvm.codegen.enabled(target):
+        if not tvm.module.enabled(target):
             print("Target %s is not enabled" % target)
             return
         f = tvm.codegen.build_module(fapi, target)

@@ -1,5 +1,6 @@
 import tvm
 import numpy as np
+import time

 def test_exp():
     # graph

@@ -8,21 +9,21 @@ def test_exp():
     B = tvm.compute(A.shape, lambda *i: tvm.exp(A(*i)), name='B')
     s = tvm.create_schedule(B.op)
     # create iter var and assign them tags.
-    num_thread = 64
+    num_thread = 8
     bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
     s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
     s[B].bind(tx, tvm.thread_axis("threadIdx.x"))

     # one line to build the function.
     def check_device(device, host="stackvm"):
-        if not tvm.codegen.enabled(host):
+        if not tvm.module.enabled(host):
             return
-        if not tvm.codegen.enabled(device):
+        if not tvm.module.enabled(device):
             return
         fexp = tvm.build(s, [A, B],
                          device, host,
                          name="myexp")
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         # launch the kernel.
         n = 1024
         a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)

@@ -31,8 +32,6 @@ def test_exp():
         np.testing.assert_allclose(
             b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)

-    if tvm.module.enabled("opencl"):
-        tvm.module.init_opencl()
     check_device("cuda", "llvm")
     check_device("opencl")

@@ -46,7 +45,7 @@ def test_log_llvm():
     # create iter var and assign them tags.
     bx, tx = s[B].split(B.op.axis[0], factor=32)
     # one line to build the function.
-    if not tvm.codegen.enabled("llvm"):
+    if not tvm.module.enabled("llvm"):
         return

     flog = tvm.build(s, [A, B],

@@ -60,17 +59,26 @@ def test_log_llvm():
     np.testing.assert_allclose(
         b.asnumpy(), np.log(a.asnumpy()), rtol=1e-5)

+from tvm.contrib import nvcc_compiler
+
+@tvm.register_func
+def tvm_callback_cuda_compile(code):
+    print(code)
+    ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_52"])
+    return ptx
+
 def test_add():
     # graph
     n = tvm.convert(1024)
     A = tvm.placeholder((n,), name='A')
     B = tvm.placeholder((n,), name='B')
-    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    bias = tvm.var("bias", dtype="float32")
+    scale = tvm.var("scale", dtype="float32")
+    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name='C')
     # schedule
     s = tvm.create_schedule(C.op)
     # create iter var and assign them tags.
-    num_thread = 256
+    num_thread = 32
     bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
     tx, x = s[C].split(x, nparts=num_thread)
     _, x = s[C].split(x, factor=4)

@@ -80,26 +88,28 @@ def test_add():

     # one line to build the function.
     def check_device(device):
-        if not tvm.codegen.enabled(device):
+        if not tvm.module.enabled(device):
             print("skip because %s is not enabled.." % device)
             return
-        fadd = tvm.build(s, [A, B, C],
+        fadd = tvm.build(s, [A, B, C, bias, scale],
                          device,
                          name="myadd")
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         print(fadd.imported_modules[0].get_source())
         # launch the kernel.
         n = 1024
         a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
-        fadd(a, b, c)
+        vbias = np.random.uniform()
+        vscale = np.random.uniform()
+        fadd(a, b, c, vbias, vscale)
         np.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + b.asnumpy())
+            c.asnumpy(), a.asnumpy() + b.asnumpy() * vscale + vbias, rtol=1e-6)

-    if tvm.module.enabled("opencl"):
-        tvm.module.init_opencl()
-    check_device("cuda")
     check_device("opencl")
+    check_device("metal")
+    check_device("cuda")


 if __name__ == "__main__":

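A pattern worth calling out before the remaining test diffs: the old per-backend context selection (`tvm.gpu(0) if device == "cuda" else tvm.cl(0)`) is replaced everywhere by the string-keyed lookup `tvm.context(device, 0)`. A small sketch of the equivalence, which the ndarray test near the end of this commit asserts directly:

import tvm

# tvm.context resolves a device string to the same context object the
# per-backend constructors return, keeping test code backend-agnostic.
assert tvm.context("cuda", 0) == tvm.gpu(0)
assert tvm.context("opencl", 0) == tvm.opencl(0)
assert tvm.context("metal", 0) == tvm.metal(0)
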
@@ -1,6 +1,13 @@
 import tvm
+from tvm.contrib import nvcc_compiler
+from tvm.contrib import metal_compiler
 import numpy as np
 import time

+#@tvm.register_func
+def tvm_callback_metal_compile(code):
+    lib = metal_compiler.compile_source(code)
+    return lib
+
 def test_gemm():
     # graph

@@ -63,15 +70,14 @@ def test_gemm():
     s = s.normalize()

     # one line to build the function.
-    def check_device(device, host="stackvm"):
-        if not tvm.codegen.enabled(host):
-            return
-        if not tvm.codegen.enabled(device):
+    def check_device(device):
+        if not tvm.module.enabled(device):
             print("skip because %s is not enabled.." % device)
             return

-        f = tvm.build(s, [A, B, C], device, host,
+        f = tvm.build(s, [A, B, C], device,
                       max_auto_unroll_step=max_auto_unroll_step)
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         # launch the kernel.
         n = nn
         m = n

@@ -81,15 +87,20 @@ def test_gemm():
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
-        for i in range(4):
-            f(a, b, c)
+        f(a, b, c)
+        ctx.sync()
+        tbegin = time.time()
+        f(a, b, c)
+        tpush = time.time()
+        ctx.sync()
+        tend = time.time()
+        print("launch=%g sec, exec=%g sec" % (tpush - tbegin, tend - tbegin))
         np.testing.assert_allclose(
             c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)

-    check_device("cuda")
-    if tvm.module.enabled("opencl"):
-        tvm.module.init_opencl()
+    check_device("metal")
     check_device("opencl")
+    check_device("cuda")

 if __name__ == "__main__":
     test_gemm()

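The `tvm_callback_metal_compile` hook above is checked in but left unregistered (its decorator is commented out). When registered, `tvm.build` routes the generated Metal source through it, so an offline compiler can replace the runtime's default source compilation. A sketch of the enabled form, assuming a macOS host with the Metal toolchain available:

import tvm
from tvm.contrib import metal_compiler

@tvm.register_func
def tvm_callback_metal_compile(code):
    # Compile the generated Metal source ahead of time and return the
    # library blob for the Metal runtime module to load.
    return metal_compiler.compile_source(code)
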
@@ -19,16 +19,16 @@ def test_reduce_prims():

     # one line to build the function.
     def check_device(device, host="stackvm"):
-        if not tvm.codegen.enabled(host):
+        if not tvm.module.enabled(host):
             return
-        if not tvm.codegen.enabled(device):
+        if not tvm.module.enabled(device):
             print("skip because %s is not enabled.." % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         freduce = tvm.build(s,
                             args=[A, B],
                             target=device, target_host=host,
                             name="myreduce")
         print(freduce.imported_modules[0].get_source())
         # launch the kernel.
         n = 1028
         m = 129

@@ -41,9 +41,7 @@ def test_reduce_prims():
         res[:2] = 0
         np.testing.assert_allclose(npy, res, rtol=1e-4)

-    if tvm.module.enabled("opencl"):
-        tvm.module.init_opencl()
-
+    check_device("metal")
     check_device("cuda")
     check_device("opencl")
     test_prim(tvm.sum, np.sum)

@@ -64,7 +62,7 @@ def test_rfactor():
     s[BF].parallel(BF.op.axis[0])
     # one line to build the function.
     def check_target(target="llvm"):
-        if not tvm.codegen.enabled(target):
+        if not tvm.module.enabled(target):
             return
         ctx = tvm.cpu(0)
         fapi = tvm.lower(s, args=[A, B])

@@ -105,15 +103,14 @@ def test_rfactor_threads():

     # one line to build the function.
     def check_target(device, host="stackvm"):
-        if not tvm.codegen.enabled(device):
+        if not tvm.module.enabled(device):
             print("skip because %s is not enabled.." % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         fapi = tvm.lower(s, args=[A, B])
         fapi2 = tvm.ir_pass.LowerThreadAllreduce(fapi, 32)
         fsum = tvm.build(fapi,
                          target=device,
                          name="mysum")
         print(fsum.imported_modules[0].get_source())
         # launch the kernel.
         n = nn
         m = mm

@@ -125,9 +122,8 @@ def test_rfactor_threads():
         np.testing.assert_allclose(
             b.asnumpy(), res, rtol=1e-4)

-    if tvm.module.enabled("opencl"):
-        tvm.module.init_opencl()
     check_target("cuda")
+    check_target("metal")
     check_target("opencl")

 if __name__ == "__main__":

@@ -23,15 +23,14 @@ def test_scan():
     s[s_update].bind(xi, thread_x)

     # one line to build the function.
-    def check_device(device, host="stackvm"):
-        if not tvm.codegen.enabled(host):
-            return
-        if not tvm.codegen.enabled(device):
+    def check_device(device):
+        if not tvm.module.enabled(device):
             print("skip because %s is not enabled.." % device)
             return
         fscan = tvm.build(s, [X, res],
-                          device, host,
+                          device,
                           name="myscan")
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         # launch the kernel.
         n = 1024
         m = 10

@@ -42,10 +41,8 @@ def test_scan():
         np.testing.assert_allclose(
             b.asnumpy(), np.cumsum(a_np, axis=0))

-    if tvm.module.enabled("opencl"):
-        tvm.module.init_opencl()
-
     check_device("cuda")
+    check_device("metal")
     check_device("opencl")

@@ -99,9 +99,9 @@ def test_gemm():

     # correctness
     def check_device(device, host="stackvm"):
-        if not tvm.codegen.enabled(host):
+        if not tvm.module.enabled(host):
             return
-        if not tvm.codegen.enabled(device):
+        if not tvm.module.enabled(device):
             return
         f = tvm.build(s, [A, B, C], device, host,
                       max_auto_unroll_step=max_auto_unroll_step)

@@ -29,17 +29,16 @@ def test_add_pipeline():
     fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])

     def check_target(device, host="stackvm"):
-        if not tvm.codegen.enabled(host):
+        if not tvm.module.enabled(host):
             return
-        if not tvm.codegen.enabled(device):
+        if not tvm.module.enabled(device):
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         mhost = tvm.codegen.build_module(fsplits[0], host)
         mdev = tvm.codegen.build_module(fsplits[1:], device)
         mhost.import_module(mdev)
         code = mdev.get_source()
         f = mhost.entry_func

         # launch the kernel.
         n = 1027
         a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)

@@ -50,11 +49,11 @@ def test_add_pipeline():
         c.asnumpy(), a.asnumpy() + b.asnumpy())

     def check_module_save(device, host="stackvm"):
-        if not tvm.codegen.enabled(host):
+        if not tvm.module.enabled(host):
             return
-        if not tvm.codegen.enabled(device):
+        if not tvm.module.enabled(device):
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         fmt = "ptx" if device == "cuda" else "cl"
         mhost = tvm.codegen.build_module(fsplits[0], host)
         mdev = tvm.codegen.build_module(fsplits[1:], device)

@@ -19,7 +19,7 @@ def test_add_pipeline():
     s = tvm.create_schedule(C.op)

     def check_llvm():
-        if not tvm.codegen.enabled("llvm"):
+        if not tvm.module.enabled("llvm"):
             return
         # build and invoke the kernel.
         f = tvm.build(s, [A, C], "llvm")

@@ -51,7 +51,7 @@ def test_pack_buffer_simple():


     def check_target(target):
-        if not tvm.codegen.enabled(target):
+        if not tvm.module.enabled(target):
             return
         # build and invoke the kernel.
         f = tvm.build(s, [A, C], target)

@@ -81,7 +81,7 @@ def test_pack_buffer_intermediate():
     s = tvm.create_schedule(C.op)

     def check_target(target):
-        if not tvm.codegen.enabled(target):
+        if not tvm.module.enabled(target):
             return
         # build and invoke the kernel.
         f = tvm.build(s, [A, C], target)

@@ -12,7 +12,7 @@ def test_llvm_add_pipeline():
     s[C].parallel(xo)
     s[C].vectorize(xi)
     def check_llvm():
-        if not tvm.codegen.enabled("llvm"):
+        if not tvm.module.enabled("llvm"):
             return
         # build and invoke the kernel.
         f = tvm.build(s, [A, B, C], "llvm")

@@ -30,7 +30,7 @@ def test_llvm_add_pipeline():

 def test_llvm_flip_pipeline():
     def check_llvm(nn, base):
-        if not tvm.codegen.enabled("llvm"):
+        if not tvm.module.enabled("llvm"):
             return
         n = tvm.convert(nn)
         A = tvm.placeholder((n + base), name='A')

@@ -57,7 +57,7 @@ def test_llvm_flip_pipeline():

 def test_llvm_madd_pipeline():
     def check_llvm(nn, base, stride):
-        if not tvm.codegen.enabled("llvm"):
+        if not tvm.module.enabled("llvm"):
             return
         n = tvm.convert(nn)
         A = tvm.placeholder((n + base, stride), name='A')

@@ -89,7 +89,7 @@ def test_llvm_temp_space():
     s = tvm.create_schedule(C.op)

     def check_llvm():
-        if not tvm.codegen.enabled("llvm"):
+        if not tvm.module.enabled("llvm"):
             return
         # build and invoke the kernel.
         f = tvm.build(s, [A, C], "llvm")

@@ -3,7 +3,7 @@ import numpy as np

 def run_jit(fapi, check):
     for target in ["llvm", "stackvm"]:
-        if not tvm.codegen.enabled(target):
+        if not tvm.module.enabled(target):
             continue
         f = tvm.codegen.build_module(fapi, target)
         s = f.get_source()

@@ -22,7 +22,7 @@ print("Finish runtime checking...")
 """

 def test_dso_module_load():
-    if not tvm.codegen.enabled("llvm"):
+    if not tvm.module.enabled("llvm"):
         return
     dtype = 'int64'
     temp = util.tempdir()

@@ -38,6 +38,7 @@ def test_dso_module_load():
                       tvm.make.Load(dtype, Ab.data, i) + 1,
                       i + 1))
     fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0)
+    fapi = tvm.ir_pass.LowerPackedCall(fapi)
     m = tvm.codegen.build_module(fapi, "llvm")
     for name in names:
         m.save(name)

@@ -2,14 +2,14 @@ import tvm
 import numpy as np

 def enabled_ctx_list():
-    if tvm.module.enabled("opencl"):
-        tvm.module.init_opencl()
-
     ctx_list = [('cpu', tvm.cpu(0)),
                 ('gpu', tvm.gpu(0)),
                 ('cl', tvm.opencl(0)),
-                ('cpu', tvm.vpi(0))]
-    ctx_list = [x[1] for x in ctx_list if tvm.module.enabled(x[0])]
+                ('metal', tvm.metal(0)),
+                ('vpi', tvm.vpi(0))]
+    for k, v in ctx_list:
+        assert tvm.context(k, 0) == v
+    ctx_list = [x[1] for x in ctx_list if x[1].exist]
     return ctx_list

 ENABLED_CTX_LIST = enabled_ctx_list()

@@ -29,7 +29,8 @@ def test_nd_create():
     np.testing.assert_equal(x, y.asnumpy())
     np.testing.assert_equal(x, z.asnumpy())
     # no need here, just to test usability
-    tvm.nd.sync(ctx)
+    ctx.sync()


 if __name__ == "__main__":
     test_nd_create()

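The rewritten `enabled_ctx_list` separates two questions the old code conflated: `tvm.module.enabled(name)` asks whether the runtime was compiled with that backend, while `ctx.exist` asks whether a physical device is actually attached, which is what keeps the tests safe on machines without a GPU or Metal device. A guard sketch in that spirit (the helper name is illustrative):

import tvm

def metal_ready():
    if not tvm.module.enabled("metal"):   # built with ENABLE_METAL=1?
        return False
    return tvm.metal(0).exist             # and a device is really present?

if metal_ready():
    print("Metal device available:", tvm.metal(0))
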
@@ -18,16 +18,17 @@ if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then
 fi

 cp make/config.mk config.mk
-echo "USE_CUDA=0" >> config.mk
+echo "ENABLE_CUDA=0" >> config.mk

 if [ ${TRAVIS_OS_NAME} == "osx" ]; then
-    echo "USE_OPENCL=1" >> config.mk
+    echo "ENABLE_OPENCL=1" >> config.mk
+    echo "ENABLE_METAL=1" >> config.mk
 else
     # use g++-4.8 for linux
     if [ ${CXX} == "g++" ]; then
         export CXX=g++-4.8
     fi
-    echo "USE_OPENCL=0" >> config.mk
+    echo "ENABLE_OPENCL=0" >> config.mk
 fi

 if [ ${TASK} == "verilog_test" ] || [ ${TASK} == "all_test" ]; then

@@ -40,13 +40,12 @@ def test_add_pipeline():
     fapi = lower(s, [A, B, C], "myadd")
     fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
-    print(fsplits[1].body)
-    print("------")
+    fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])

     def check_target(device, host="stackvm"):
-        if not tvm.codegen.enabled(host):
+        if not tvm.module.enabled(host):
             return
-        if not tvm.codegen.enabled(device):
+        if not tvm.module.enabled(device):
             return
         ctx = tvm.vpi(0)
         mhost = tvm.codegen.build_module(fsplits[0], host)

@@ -228,7 +228,6 @@ np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 # The following code blocks generate opencl code, create an array on the opencl
 # device, and verify the correctness of the code.
 #
-tvm.module.init_opencl()
 fadd_cl = tvm.build(s, [A, B, C], "opencl", name="myadd")
 print("------opencl code------")
 print(fadd_cl.imported_modules[0].get_source())

@@ -65,7 +65,6 @@ print(fcuda.imported_modules[0].get_source())
 # Note that the generated code works for both CUDA and OpenCL.
 # The same tvm.exp can also be used for float64 data types.
 #
-tvm.module.init_opencl()
 fopencl = tvm.build(s, [A, B], "opencl", name="myexp")
 print(fopencl.imported_modules[0].get_source())