[CODEGEN/RUNTIME] Metal support, runtime improvement. (#111)

* [CODEGEN/RUNTIME] Metal support, runtime improvement.

* Fix case when no device is available
Tianqi Chen 2017-05-02 10:14:45 -07:00 committed by GitHub
Parent 9ba40dc0fe
Commit 706f9b6f7e
69 changed files with 1939 additions and 623 deletions

View file

@@ -18,7 +18,9 @@ all: lib/libtvm.so lib/libtvm_runtime.so lib/libtvm.a
LIB_HALIDE_IR = HalideIR/lib/libHalideIR.a
SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc)
METAL_SRC = $(wildcard src/runtime/metal/*.mm)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
METAL_OBJ = $(patsubst src/%.mm, build/%.o, $(METAL_SRC))
ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
RUNTIME_SRC = $(wildcard src/runtime/*.cc src/runtime/*/*.cc)
@@ -29,6 +31,7 @@ ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
-Iinclude -Idlpack/include -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
export OBJCFLAGS= -fobjc-arc
ifdef CUDA_PATH
NVCC=$(CUDA_PATH)/bin/nvcc
@@ -36,7 +39,7 @@ ifdef CUDA_PATH
LDFLAGS += -L$(CUDA_PATH)/lib64
endif
ifeq ($(USE_CUDA), 1)
ifeq ($(ENABLE_CUDA), 1)
CFLAGS += -DTVM_CUDA_RUNTIME=1
LDFLAGS += -lcuda -lcudart -lnvrtc
else
@@ -45,9 +48,10 @@ endif
FRAMEWORKS=
ifeq ($(USE_OPENCL), 1)
UNAME_S := $(shell uname -s)
ifeq ($(ENABLE_OPENCL), 1)
CFLAGS += -DTVM_OPENCL_RUNTIME=1
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S), Darwin)
FRAMEWORKS += -framework OpenCL
else
@@ -57,10 +61,20 @@ else
CFLAGS += -DTVM_OPENCL_RUNTIME=0
endif
ifeq ($(ENABLE_METAL), 1)
CFLAGS += -DTVM_METAL_RUNTIME=1
LDFLAGS += -lObjc
ALL_DEP += $(METAL_OBJ)
RUNTIME_DEP += $(METAL_OBJ)
FRAMEWORKS += -framework Metal -framework Foundation
else
CFLAGS += -DTVM_METAL_RUNTIME=0
endif
# llvm configuration
LLVM_CONFIG=llvm-config
ifeq ($(USE_LLVM), 1)
ifeq ($(ENABLE_LLVM), 1)
LLVM_VERSION=$(shell $(LLVM_CONFIG) --version| cut -b 1,3)
LLVM_INCLUDE=$(filter -I%, $(shell $(LLVM_CONFIG) --cxxflags))
LDFLAGS += $(shell $(LLVM_CONFIG) --ldflags --libs --system-libs)
@@ -87,6 +101,11 @@ build/%.o: src/%.cc
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) -c $(CFLAGS) -c $< -o $@
build/%.o: src/%.mm
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) $(OBJCFLAGS) -c $(CFLAGS) -c $< -o $@
lib/libtvm.so: $(ALL_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
@@ -105,7 +124,7 @@ LIBHALIDEIR:
+ cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR)
cpplint:
python2 dmlc-core/scripts/lint.py tvm cpp include src verilog
python dmlc-core/scripts/lint.py tvm cpp include src verilog
pylint:
pylint python/tvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc

View file

@@ -12,4 +12,5 @@ tvm.ndarray
.. autofunction:: tvm.cpu
.. autofunction:: tvm.gpu
.. autofunction:: tvm.opencl
.. autofunction:: tvm.metal
.. autofunction:: tvm.ndarray.array

View file

@@ -31,13 +31,6 @@ using runtime::TVMRetValue;
*/
runtime::Module Build(const Array<LoweredFunc>& funcs,
const std::string& target);
/*!
* \param target The target to be queried.
* \return Whether target is enabled.
*/
bool TargetEnabled(const std::string& target);
} // namespace codegen
} // namespace tvm

View file

@@ -41,6 +41,8 @@ typedef int64_t tvm_index_t;
/*! \brief Extension device types in TVM */
typedef enum {
/*! \brief Metal buffer. */
kMetal = 8,
/*! \brief Simulated on board RAM */
kVPI = 9
} TVMDeviceExtType;
@@ -360,7 +362,7 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
TVM_DLL int TVMFuncListGlobalNames(int *out_size,
const char*** out_array);
// Array related apis for quick proptying
// Array related apis for quick prototyping
/*!
* \brief Allocate a nd-array's memory,
* including space of shape, of given spec.

View file

@@ -20,4 +20,11 @@
#define TVM_OPENCL_RUNTIME 0
#endif
/*!
* \brief Whether to use the Metal runtime
*/
#ifndef TVM_METAL_RUNTIME
#define TVM_METAL_RUNTIME 0
#endif
#endif // TVM_RUNTIME_CONFIG_H_

View file

@@ -145,8 +145,8 @@ class ModuleNode {
namespace symbol {
/*! \brief Global variable to store module context. */
constexpr const char* tvm_module_ctx = "__tvm_module_ctx";
/*! \brief Local function to set the device during API entry. */
constexpr const char* tvm_entry_setdevice = "__tvm_entry_setdevice";
/*! \brief global function to set device */
constexpr const char* tvm_set_device = "__tvm_set_device";
/*! \brief Auxiliary counter to global barrier. */
constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
/*! \brief Prepare the global barrier before kernels that use the global barrier. */

View file

@@ -34,16 +34,19 @@ ADD_CFLAGS =
# matrix computation libraries for CPU/GPU
#---------------------------------------------
# whether use CUDA during compile
USE_CUDA = 1
# whether enable CUDA during compile
ENABLE_CUDA = 1
# whether use OpenCL during compile
USE_OPENCL = 0
# whether enable OpenCL during compile
ENABLE_OPENCL = 0
# whether enable Metal during compile
ENABLE_METAL = 0
# whether build with LLVM support
# This requires llvm-config to be in your PATH
# Requires LLVM version >= 4.0
USE_LLVM = 0
ENABLE_LLVM = 0
# add the path to CUDA library to link and compile flag
# if you have already add them to environment variable.

View file

@@ -16,7 +16,7 @@ from . import node
from . import ir_builder
from . import ndarray as nd
from .ndarray import cpu, gpu, opencl, cl, vpi
from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi
from ._ffi.function import Function
from ._ffi.base import TVMError, __version__

View file

@@ -4,7 +4,8 @@
from __future__ import absolute_import
import ctypes
import numpy as np
from .base import _LIB, check_call, c_array
from .base import _LIB, check_call, c_array, string_types
from .. import _api_internal
tvm_shape_index_t = ctypes.c_int64
@@ -63,22 +64,62 @@ class TVMType(ctypes.Structure):
def __ne__(self, other):
return not self.__eq__(other)
class TVMContext(ctypes.Structure):
"""TVM context strucure."""
_fields_ = [("device_id", ctypes.c_int),
("device_type", ctypes.c_int)]
MASK2STR = {
1 : 'cpu',
2 : 'gpu',
4 : 'opencl',
8 : 'metal',
9 : 'vpi'
}
def __init__(self, device_id, device_type):
STR2MASK = {
'cpu': 1,
'gpu': 2,
'cuda': 2,
'cl': 4,
'opencl': 4,
'metal': 8,
'vpi': 9
}
def __init__(self, device_type, device_id):
super(TVMContext, self).__init__()
self.device_id = device_id
self.device_type = device_type
@property
def exist(self):
"""Whether this device exist."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 0) != 0
@property
def max_threads_per_block(self):
"""Maximum number of threads on each block."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 1)
@property
def warp_size(self):
"""Number of threads that executes in concurrent."""
return _api_internal._GetDeviceAttr(
self.device_type, self.device_id, 2)
def sync(self):
"""Synchronize until jobs finished at the context."""
check_call(_LIB.TVMSynchronize(self, None))
def __eq__(self, other):
return (isinstance(other, TVMContext) and
self.device_id == other.device_id and
self.device_type == other.device_type)
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
return "%s(%d)" % (
TVMContext.MASK2STR[self.device_type], self.device_id)
@@ -97,48 +138,38 @@ class TVMArray(ctypes.Structure):
TVMArrayHandle = ctypes.POINTER(TVMArray)
def cpu(dev_id=0):
"""Construct a CPU device
def context(dev_type, dev_id=0):
"""Construct a TVM context with given device type and id.
Parameters
----------
dev_type: int or str
The device type mask or name of the device.
dev_id : int, optional
The integer device id
Returns
-------
ctx: TVMContext
The corresponding context.
Examples
--------
A context can be constructed from the string
representation of its device type.
.. code-block:: python
assert tvm.context("cpu", 1) == tvm.cpu(1)
assert tvm.context("gpu", 0) == tvm.gpu(0)
assert tvm.context("cuda", 0) == tvm.gpu(0)
"""
return TVMContext(dev_id, 1)
def gpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
"""
return TVMContext(dev_id, 2)
def opencl(dev_id=0):
"""Construct a OpenCL device
Parameters
----------
dev_id : int, optional
The integer device id
"""
return TVMContext(dev_id, 4)
def vpi(dev_id=0):
"""Construct a VPI simulated device
Parameters
----------
dev_id : int, optional
The integer device id
"""
return TVMContext(dev_id, 9)
if isinstance(dev_type, string_types):
if dev_type not in TVMContext.STR2MASK:
raise ValueError("Unknown device type %s" % dev_type)
dev_type = TVMContext.STR2MASK[dev_type]
return TVMContext(dev_type, dev_id)
def numpyasarray(np_data):
@@ -154,10 +185,11 @@ def numpyasarray(np_data):
arr.dtype = TVMType(np.dtype(data.dtype).name)
arr.ndim = data.ndim
# CPU device
arr.ctx = cpu(0)
arr.ctx = context(1, 0)
return arr, shape
def empty(shape, dtype="float32", ctx=cpu(0)):
def empty(shape, dtype="float32", ctx=context(1, 0)):
"""Create an empty array given shape and device
Parameters
@@ -185,17 +217,6 @@ def empty(shape, dtype="float32", ctx=cpu(0)):
return _CLASS_NDARRAY(handle)
def sync(ctx):
"""Synchronize all the context
Parameters
----------
ctx : TVMContext
The context to be synced
"""
check_call(_LIB.TVMSynchronize(ctx, None))
class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime."""
__slots__ = ["handle", "is_view"]
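The reworked TVMContext above is the user-facing core of this change. A minimal sketch of how it can be exercised (hypothetical session; assumes the corresponding runtime was enabled at build time):

import tvm

ctx = tvm.context("metal", 0)    # by name; same as tvm.metal(0) or tvm.mtl(0)
assert ctx == tvm.context(8, 0)  # or by device type mask
if ctx.exist:                    # probes the device through _GetDeviceAttr
    print(ctx.max_threads_per_block, ctx.warp_size)
    ctx.sync()                   # replaces the removed module-level sync(ctx)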

View file

@@ -10,6 +10,7 @@ from . import schedule
from . import expr
from . import ir_pass
from . import collections
from . import module
from . import codegen
@@ -149,7 +150,7 @@ def build(sch,
fsplits[0] = ir_pass.LowerPackedCall(fsplits[0])
if len(fsplits) > 1:
if not target_host:
target_host = "llvm" if codegen.enabled("llvm") else "stackvm"
target_host = "llvm" if module.enabled("llvm") else "stackvm"
mhost = codegen.build_module(fsplits[0], target_host)
if target:
mdev = codegen.build_module(fsplits[1:], target)
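With codegen.enabled removed in the next file, runtime-support checks now go through the new module namespace, as in the changed line above; a minimal sketch (assuming the module submodule is importable directly):

from tvm import module

target_host = "llvm" if module.enabled("llvm") else "stackvm"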

View file

@@ -19,21 +19,4 @@ def build_module(lowered_func, target):
"""
return _Build(lowered_func, target)
def enabled(target):
"""Whether target is enabled for codegen.
Parameters
----------
target : str
The target module type.
Returns
-------
enabled : boolean
Whether the target module is enabled.
"""
return _Enabled(target)
_init_api("tvm.codegen")

View file

@@ -24,6 +24,8 @@ def create_shared(path_target, objects,
"""
cmd = [cc]
cmd += ["-shared"]
if sys.platform == "darwin":
cmd += ["-undefined", "dynamic_lookup"]
cmd += ["-o", path_target]
cmd += objects
if options:

View file

@@ -0,0 +1,50 @@
# pylint: disable=invalid-name
"""Utility to invoke metal compiler in the CLI system"""
from __future__ import absolute_import as _abs
import sys
import subprocess
from . import util
def compile_source(code, path_target=None):
"""Compile metal with CLI tool from env.
Parameters
----------
code : str
The metal code.
path_target : str, optional
Output file.
Returns
-------
metallib : bytearray
The bytearray of the metallib
"""
temp = util.tempdir()
temp_code = temp.relpath("my_lib.metal")
temp_ir = temp.relpath("my_lib.air")
temp_target = temp.relpath("my_lib.metallib")
with open(temp_code, "w") as out_file:
out_file.write(code)
file_target = path_target if path_target else temp_target
cmd1 = ["xcrun", "-sdk", "macosx", "metal", "-O3"]
cmd1 += [temp_code, "-o", temp_ir]
cmd2 = ["xcrun", "-sdk", "macosx", "metallib"]
cmd2 += [temp_ir, "-o", file_target]
proc = subprocess.Popen(
' '.join(cmd1) + ";" + ' '.join(cmd2),
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
sys.stderr.write("Compilation error:\n")
sys.stderr.write(out)
sys.stderr.flush()
libbin = None
else:
libbin = bytearray(open(file_target, "rb").read())
return libbin
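A hypothetical invocation of the helper above, compiling a trivial kernel to a metallib blob (macOS with the Xcode command-line tools only; the import path for this new file is assumed):

from tvm.contrib import metal

binary = metal.compile_source("kernel void dummy() {}")
assert binary is not None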

View file

@@ -8,11 +8,9 @@ from __future__ import absolute_import as _abs
import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import cpu, gpu, opencl, vpi, empty, sync
from ._ffi.ndarray import context, empty
from ._ffi.ndarray import _set_class_ndarray
cl = opencl
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
@@ -27,6 +25,64 @@ class NDArray(NDArrayBase):
pass
def cpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
"""
return TVMContext(1, dev_id)
def gpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
"""
return TVMContext(2, dev_id)
def opencl(dev_id=0):
"""Construct a OpenCL device
Parameters
----------
dev_id : int, optional
The integer device id
"""
return TVMContext(4, dev_id)
def metal(dev_id=0):
"""Construct a metal device
Parameters
----------
dev_id : int, optional
The integer device id
"""
return TVMContext(8, dev_id)
def vpi(dev_id=0):
"""Construct a VPI simulated device
Parameters
----------
dev_id : int, optional
The integer device id
"""
return TVMContext(9, dev_id)
cl = opencl
mtl = metal
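A hypothetical round-trip through the constructors above and the array() helper that follows (requires a Metal-capable build and device):

import numpy as np
import tvm

x = np.random.uniform(size=16).astype("float32")
ctx = tvm.metal(0)
if ctx.exist:
    a = tvm.nd.array(x, ctx)                    # host -> device copy
    np.testing.assert_allclose(a.asnumpy(), x)  # device -> host copy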
def array(arr, ctx=cpu(0)):
"""Create an array from source arr.

View file

@@ -20,10 +20,5 @@ TVM_REGISTER_API("codegen._Build")
*ret = Build(args[0], args[1]);
}
});
TVM_REGISTER_API("codegen._Enabled")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TargetEnabled(args[0]);
});
} // namespace codegen
} // namespace tvm

View file

@@ -0,0 +1,34 @@
/*!
* Copyright (c) 2017 by Contributors
* Common build utilities
* \file build_common.h
*/
#ifndef TVM_CODEGEN_BUILD_COMMON_H_
#define TVM_CODEGEN_BUILD_COMMON_H_
#include <tvm/codegen.h>
#include <unordered_map>
#include <string>
#include "../runtime/meta_data.h"
namespace tvm {
namespace codegen {
// Extract function information from device function.
inline std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const Array<LoweredFunc>& funcs) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (LoweredFunc f : funcs) {
runtime::FunctionInfo info;
for (size_t i = 0; i < f->args.size(); ++i) {
info.arg_types.push_back(Type2TVMType(f->args[i].type()));
}
for (size_t i = 0; i < f->thread_axis.size(); ++i) {
info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag);
}
fmap[f->name] = info;
}
return fmap;
}
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_BUILD_COMMON_H_

View file

@@ -6,11 +6,10 @@
#include <tvm/base.h>
#include <tvm/runtime/config.h>
#include "./codegen_cuda.h"
#include "./build_common.h"
#if TVM_CUDA_RUNTIME
#include <nvrtc.h>
#include "../runtime/meta_data.h"
#include "../runtime/cuda/cuda_common.h"
#include "../runtime/cuda/cuda_module.h"
@@ -71,19 +70,7 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
} else {
ptx = NVRTCCompile(code);
}
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (LoweredFunc f : funcs) {
runtime::FunctionInfo info;
for (size_t i = 0; i < f->args.size(); ++i) {
info.arg_types.push_back(Type2TVMType(f->args[i].type()));
}
for (size_t i = 0; i < f->thread_axis.size(); ++i) {
info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag);
}
fmap[f->name] = info;
}
return CUDAModuleCreate(ptx, fmt, fmap, code);
return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(funcs), code);
}
TVM_REGISTER_API("codegen.build_cuda")

View file

@@ -0,0 +1,47 @@
/*!
* Copyright (c) 2017 by Contributors
* Build metal modules from source.
* \file build_metal.cc
*/
#include <tvm/base.h>
#include <tvm/runtime/config.h>
#include "./codegen_metal.h"
#include "./build_common.h"
#if TVM_METAL_RUNTIME
#include "../runtime/metal/metal_module.h"
#endif // TVM_METAL_RUNTIME
namespace tvm {
namespace codegen {
runtime::Module BuildMetal(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenMetal cg;
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
cg.AddFunction(f);
}
std::string code = cg.Finish();
#if TVM_METAL_RUNTIME
std::string fmt = "metal";
std::string source = "";
if (const auto* f = Registry::Get("tvm_callback_metal_compile")) {
source = code;
code = (*f)(code).operator std::string();
fmt = "metallib";
}
return MetalModuleCreate(code, fmt, ExtractFuncInfo(funcs), source);
#else
LOG(WARNING) << "Metal runtime not enabled, return a source module...";
return SourceModuleCreate(code, "metal");
#endif // TVM_METAL_RUNTIME
}
TVM_REGISTER_API("codegen.build_metal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildMetal(args[0]);
});
} // namespace codegen
} // namespace tvm
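BuildMetal looks up a global callback named tvm_callback_metal_compile and, when present, ships a compiled metallib instead of source. A sketch of wiring in the CLI compiler from the contrib helper added earlier (hypothetical import path; register_func is the standard FFI registration helper):

import tvm
from tvm.contrib import metal

@tvm.register_func("tvm_callback_metal_compile")
def compile_metal(code):
    return metal.compile_source(code)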

View file

@@ -6,39 +6,29 @@
#include <tvm/base.h>
#include <tvm/runtime/config.h>
#include "./codegen_opencl.h"
#include "./build_common.h"
#if TVM_OPENCL_RUNTIME
#include "../runtime/meta_data.h"
#include "../runtime/opencl/opencl_common.h"
#include "../runtime/opencl/opencl_module.h"
#endif // TVM_OPENCL_RUNTIME
namespace tvm {
namespace codegen {
runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
std::ostringstream os;
bool output_ssa = false;
CodeGenOpenCL cg;
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
cg.AddFunction(f);
}
std::string code = cg.Finish();
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (LoweredFunc f : funcs) {
runtime::FunctionInfo info;
for (size_t i = 0; i < f->args.size(); ++i) {
info.arg_types.push_back(Type2TVMType(f->args[i].type()));
}
for (size_t i = 0; i < f->thread_axis.size(); ++i) {
info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag);
}
fmap[f->name] = info;
}
return OpenCLModuleCreate(code, "cl", fmap);
#if TVM_OPENCL_RUNTIME
return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs));
#else
LOG(WARNING) << "OpenCL runtime not enabled, return a source module...";
return SourceModuleCreate(code, "cl");
#endif // TVM_OPENCL_RUNTIME
}
TVM_REGISTER_API("codegen.build_opencl")
@@ -47,4 +37,3 @@ TVM_REGISTER_API("codegen.build_opencl")
});
} // namespace codegen
} // namespace tvm
#endif // TVM_OPENCL_RUNTIME

View file

@@ -32,10 +32,5 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
return m;
}
bool TargetEnabled(const std::string& target) {
std::string build_f_name = "codegen.build_" + target;
return runtime::Registry::Get(build_f_name) != nullptr;
}
} // namespace codegen
} // namespace tvm

View file

@@ -24,7 +24,6 @@ void CodeGenC::InitFuncState(LoweredFunc f) {
void CodeGenC::AddFunction(LoweredFunc f) {
// clear previous generated state.
this->InitFuncState(f);
// skip the first underscore, so SSA variable starts from _1
GetUniqueName("_");
// add to alloc buffer type.
@@ -41,8 +40,8 @@ void CodeGenC::AddFunction(LoweredFunc f) {
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
stream << ' ';
}
stream << ' ';
}
if (handle_data_type_.count(v.get())) {
PrintType(handle_data_type_.at(v.get()), stream);
@@ -246,9 +245,9 @@ void CodeGenC::PrintVecStore(const Variable* buffer,
stream << ref << " = " << value << ";\n";
}
void CodeGenC::PrintThreadIndexExpr(
std::string thread_tag, std::ostream& os) { // NOLINT(*)
os << thread_tag;
void CodeGenC::BindThreadIndex(const IterVar& iv) {
CHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] = iv->thread_tag;
}
void CodeGenC::PrintStorageSync(const Call* op) { // NOLINT(*)
@@ -674,6 +673,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) {
const Variable* buffer = op->buffer_var.as<Variable>();
std::string scope = alloc_storage_scope_.at(buffer);
PrintStorageScope(scope, stream);
stream << ' ';
PrintType(op->type, stream);
stream << ' '<< vid << '['
<< constant_size << "];\n";
@@ -687,13 +687,7 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) {
if (!var_idmap_.count(iv->var.get())) {
this->PrintIndent();
PrintType(iv->var.type(), stream);
stream << ' '
<< AllocVarID(iv->var.get())
<< " = ";
PrintThreadIndexExpr(iv->thread_tag, stream);
stream << ";\n";
BindThreadIndex(iv);
}
}
} else if (op->attr_key == ir::attr::storage_scope) {
@@ -740,7 +734,11 @@ void CodeGenC::VisitStmt_(const For* op) {
void CodeGenC::VisitStmt_(const IfThenElse* op) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
stream << "if (" << cond << ") {\n";
if (cond[0] == '(' && cond[cond.length() - 1] == ')') {
stream << "if " << cond << " {\n";
} else {
stream << "if (" << cond << ") {\n";
}
int then_scope = BeginScope();
PrintStmt(op->then_case);
this->EndScope(then_scope);

View file

@@ -121,12 +121,10 @@ class CodeGenC :
virtual void PrintType(Type t, std::ostream& os) const; // NOLINT(*)
/*!
* \brief Print expr representing the thread tag
* \param tag The tag in the thread.
* \param os The stream to output to
* \param iv The thread index to be bound.
*/
virtual void PrintThreadIndexExpr(
std::string tag, std::ostream& os); // NOLINT(*)
virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*)
virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
virtual void PrintStorageSync(const Call* op); // NOLINT(*)
// Binary vector op.
virtual void PrintVecBinaryOp(
@@ -169,12 +167,12 @@ class CodeGenC :
const std::string& target, const std::string& src, Type t) final;
/*! \brief the storage scope of allocation */
std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, Type> handle_data_type_;
private:
/*! \brief whether to print in SSA form */
bool print_ssa_form_{false};
/*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, Type> handle_data_type_;
/*! \brief set of volatile buf access */
std::unordered_set<const Variable*> volatile_buf_;
};

View file

@@ -141,7 +141,9 @@ void CodeGenCUDA::PrintVecElemStore(
void CodeGenCUDA::PrintStorageSync(const Call* op) {
const std::string& sync = op->args[0].as<StringImm>()->value;
if (sync == "shared") {
if (sync == "warp") {
// DO nothing.
} else if (sync == "shared") {
this->PrintIndent();
this->stream << "__syncthreads();\n";
} else if (sync == "global") {
@@ -182,7 +184,7 @@ void CodeGenCUDA::PrintStorageScope(
const std::string& scope, std::ostream& os) { // NOLINT(*)
CHECK_NE(scope, "global");
if (scope == "shared") {
os << "__shared__ ";
os << "__shared__";
}
}
@@ -203,6 +205,17 @@ void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
}
}
void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->type, os);
os << "(";
for (int i = 0; i < op->lanes; ++i) {
if (i != 0) os << ", ";
os << v;
}
os << ')';
}
} // namespace codegen
} // namespace tvm
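For a 4-lane float broadcast, the new visitor prints make_float4(v, v, v, v); the Metal and OpenCL printers added later in this commit emit the equivalent float4(v, v, v, v) and (float4)(v, v, v, v) constructor forms.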

View file

@@ -14,7 +14,7 @@
namespace tvm {
namespace codegen {
class CodeGenCUDA : public CodeGenC {
class CodeGenCUDA final : public CodeGenC {
public:
void Init(bool output_ssa);
void AddFunction(LoweredFunc f);
@@ -31,6 +31,8 @@ class CodeGenCUDA : public CodeGenC {
void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value) final;
// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitStmt_(const Evaluate *op) final;
private:

View file

@@ -0,0 +1,203 @@
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_metal.cc
*/
#include <tvm/runtime/config.h>
#include <tvm/packed_func_ext.h>
#include <vector>
#include <string>
#include "./codegen_metal.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace codegen {
void CodeGenMetal::InitFuncState(LoweredFunc f) {
CodeGenC::InitFuncState(f);
// analyze the data;
for (Var arg : f->args) {
if (arg.type().is_handle()) {
alloc_storage_scope_[arg.get()] = "global";
}
}
}
CodeGenMetal::CodeGenMetal() {
decl_stream << "#include <metal_stdlib>\n";
decl_stream << "using namespace metal;\n\n";
decl_stream << "union __TVMArgUnion {\n"
<< " int v_int;\n"
<< "};\n\n";
}
void CodeGenMetal::AddFunction(LoweredFunc f) {
// clear previous generated state.
this->InitFuncState(f);
// skip the first underscore, so SSA variable starts from _1
GetUniqueName("_");
// add to alloc buffer type.
for (const auto & kv : f->handle_data_type) {
RegisterHandleType(kv.first.get(), kv.second.type());
}
// Function header.
this->stream << "kernel void " << f->name << "(\n";
// Buffer arguments
size_t num_buffer = 0;
for (size_t i = 0; i < f->args.size(); ++i, ++num_buffer) {
Var v = f->args[i];
if (!v.type().is_handle()) break;
stream << " ";
std::string vid = AllocVarID(v.get());
auto it = alloc_storage_scope_.find(v.get());
CHECK(it != alloc_storage_scope_.end());
PrintStorageScope(it->second, stream);
stream << ' ';
if (handle_data_type_.count(v.get())) {
PrintType(handle_data_type_.at(v.get()), stream);
stream << "*";
} else {
PrintType(v.type(), stream);
}
stream << ' ' << vid
<< " [[ buffer(" << i << ") ]],\n";
}
// Setup normal arguments.
size_t nargs = f->args.size() - num_buffer;
std::string varg = GetUniqueName("arg");
if (nargs != 0) {
std::string arg_buf_type = f->name + "_args_t";
stream << " constant " << arg_buf_type << "& " << varg
<< " [[ buffer(" << num_buffer << ") ]],\n";
// declare the struct
decl_stream << "struct " << arg_buf_type << " {\n";
for (size_t i = num_buffer; i < f->args.size(); ++i) {
Var v = f->args[i];
CHECK(!v.type().is_handle());
std::string vid = AllocVarID(v.get());
std::ostringstream vref;
if (v.type().bits() == 32) {
decl_stream << " ";
PrintType(v.type(), decl_stream);
decl_stream << " " << vid << ";\n";
vref << varg << "." << vid;
} else {
// For non-32bit types, reference through the arg union.
decl_stream << " __TVMArgUnion " << vid << ";\n";
vref << varg << "." << vid << ".v_";
PrintType(v.type(), vref);
}
var_idmap_[v.get()] = vref.str();
}
decl_stream << "};\n\n";
}
// Setup the thread group info.
CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx");
CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx");
int work_dim = 0;
for (IterVar iv : f->thread_axis) {
runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
work_dim = std::max(work_dim, scope.dim_index + 1);
}
if (work_dim != 0) {
// use ushort by default for now
stream << " ";
PrintType(UInt(16, work_dim), stream);
stream << " blockIdx [[threadgroup_position_in_grid]],\n";
stream << " ";
PrintType(UInt(16, work_dim), stream);
stream << " threadIdx [[thread_position_in_threadgroup]]\n";
}
// bind thread axis
for (IterVar iv : f->thread_axis) {
CHECK(!var_idmap_.count(iv->var.get()));
if (work_dim <= 1) {
var_idmap_[iv->var.get()] =
iv->thread_tag.substr(0, iv->thread_tag.length() - 2);
} else {
var_idmap_[iv->var.get()] = iv->thread_tag;
}
}
// the function scope.
stream << ") {\n";
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
}
void CodeGenMetal::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
CHECK_EQ(lanes, 1)
<< "do not yet support vector types";
os << "void*"; return;
}
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
case 16: os << "half"; break;
case 32: os << "float"; break;
default: fail = true; break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes; return;
}
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
os << 'u';
}
if (t.bits() == 8 && t.lanes() == 4) {
// pack four 8-bit ints directly into one integer.
os << "int"; return;
}
switch (t.bits()) {
case 8: os << "char"; break;
case 16: os << "short"; break;
case 32: os << "int"; break;
case 1: os << "bool"; break;
default: fail = true; break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes; return;
}
}
LOG(FATAL) << "Cannot convert type " << t << " to Metal type";
}
void CodeGenMetal::PrintStorageSync(const Call* op) {
const std::string& sync = op->args[0].as<StringImm>()->value;
if (sync == "warp") {
this->PrintIndent();
this->stream << "simdgroup_barrier(mem_flags::mem_threadgroup);\n";
} else if (sync == "shared") {
this->PrintIndent();
this->stream << "threadgroup_barrier(mem_flags::mem_threadgroup);\n";
} else if (sync == "global") {
LOG(FATAL) << "global barrier not supported";
}
}
void CodeGenMetal::PrintStorageScope(
const std::string& scope, std::ostream& os) { // NOLINT(*)
if (scope == "global") {
os << "device";
} else if (scope == "shared") {
os << "threadgroup";
}
}
void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
PrintType(op->type, os);
os << "(";
for (int i = 0; i < op->lanes; ++i) {
if (i != 0) os << ", ";
os << v;
}
os << ')';
}
} // namespace codegen
} // namespace tvm

View file

@@ -0,0 +1,34 @@
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_metal.h
* \brief Generate Metal device code.
*/
#ifndef TVM_CODEGEN_CODEGEN_METAL_H_
#define TVM_CODEGEN_CODEGEN_METAL_H_
#include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <string>
#include "./codegen_c.h"
namespace tvm {
namespace codegen {
class CodeGenMetal final : public CodeGenC {
public:
CodeGenMetal();
void AddFunction(LoweredFunc f);
// override print thread tag.
void PrintArgUnionDecl();
void InitFuncState(LoweredFunc f) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
};
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_METAL_H_

View file

@@ -1,6 +1,6 @@
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_cuda.cc
* \file codegen_opencl.cc
*/
#include <tvm/runtime/config.h>
#include <tvm/packed_func_ext.h>
@@ -22,19 +22,20 @@ void CodeGenOpenCL::InitFuncState(LoweredFunc f) {
}
void CodeGenOpenCL::AddFunction(LoweredFunc f) {
this->stream << " __kernel ";
this->stream << "__kernel ";
CodeGenC::AddFunction(f);
}
void CodeGenOpenCL::PrintThreadIndexExpr(
std::string tag, std::ostream& os) { // NOLINT(*)
runtime::ThreadScope ts = runtime::ThreadScope::make(tag);
void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) {
CHECK(!var_idmap_.count(iv->var.get()));
runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
std::ostringstream os;
if (ts.rank == 1) {
os << "get_local_id(" << ts.dim_index << ")";
} else {
os << "get_global_id(" << ts.dim_index << ")"
<< " / get_local_size(" << ts.dim_index << ")";
os << "get_group_id(" << ts.dim_index << ")";
}
var_idmap_[iv->var.get()] = os.str();
}
void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
@@ -115,7 +116,9 @@ void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
void CodeGenOpenCL::PrintStorageSync(const Call* op) {
const std::string& sync = op->args[0].as<StringImm>()->value;
if (sync == "shared") {
if (sync == "warp") {
LOG(FATAL) << "warp sync not supported in opencl";
} else if (sync == "shared") {
this->PrintIndent();
this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n";
} else if (sync == "global") {
@@ -128,8 +131,20 @@ void CodeGenOpenCL::PrintStorageScope(
if (scope == "global") {
os << "__global";
} else if (scope == "shared") {
os << "__local ";
os << "__local";
}
}
void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
os << '(';
PrintType(op->type, os);
os << ")(";
for (int i = 0; i < op->lanes; ++i) {
if (i != 0) os << ", ";
os << v;
}
os << ')';
}
} // namespace codegen
} // namespace tvm

View file

@@ -1,7 +1,7 @@
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_opencl.h
* \brief Utility to generate opencl code
* \brief Generate OpenCL device code.
*/
#ifndef TVM_CODEGEN_CODEGEN_OPENCL_H_
#define TVM_CODEGEN_CODEGEN_OPENCL_H_
@@ -14,13 +14,12 @@
namespace tvm {
namespace codegen {
class CodeGenOpenCL : public CodeGenC {
class CodeGenOpenCL final : public CodeGenC {
public:
void AddFunction(LoweredFunc f);
// override print thread tag.
void InitFuncState(LoweredFunc f) final;
void PrintThreadIndexExpr(
std::string tag, std::ostream& os) final; // NOLINT(*)
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
@@ -32,6 +31,8 @@ class CodeGenOpenCL : public CodeGenC {
// the address of load/store
void PrintVecAddr(const Variable* buffer, Type t,
Expr base, std::ostream& os); // NOLINT(*)
// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
};
} // namespace codegen

View file

@@ -102,6 +102,12 @@ class CodeGenSourceBase {
int indent_{0};
};
/*!
* \brief Create a source module for viewing.
* \param code The code to be viewed.
* \param fmt The code format.
*/
runtime::Module SourceModuleCreate(std::string code, std::string fmt);
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_

View file

@@ -0,0 +1,52 @@
/*!
* Copyright (c) 2017 by Contributors
* \file source_module.cc
* \brief Source code module, only for viewing
*/
#include <tvm/runtime/packed_func.h>
#include "./codegen_source_base.h"
namespace tvm {
namespace codegen {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
// Simulator function
class SourceModuleNode : public runtime::ModuleNode {
public:
SourceModuleNode(std::string code,
std::string fmt)
: code_(code), fmt_(fmt) {}
const char* type_key() const {
return "source";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
LOG(FATAL) << "Source module cannot be executed; to get an executable module,"
<< " build TVM with \'" << fmt_ << "\' runtime support";
return PackedFunc();
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "not implemented";
}
std::string GetSource(const std::string& format) final {
return code_;
}
private:
std::string code_;
std::string fmt_;
};
runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
std::shared_ptr<SourceModuleNode> n =
std::make_shared<SourceModuleNode>(code, fmt);
return runtime::Module(n);
}
} // namespace codegen
} // namespace tvm
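With this fallback, building for a backend that was compiled out no longer hard-fails: codegen hands back a 'source' module whose GetSource() still yields the generated kernel code for inspection, and only an attempt to execute it triggers the fatal error above.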

View file

@@ -31,11 +31,7 @@ class VerilogModuleNode : public runtime::ModuleNode {
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
CHECK(sptr_to_self.get() == this);
if (name == runtime::symbol::tvm_entry_setdevice) {
return PackedFunc([](const TVMArgs& args, TVMRetValue* rv){});
}
CHECK(m_.fmap.count(name)) << "Cannot find function " << name << " in the module";
if (!m_.fmap.count(name)) return PackedFunc();
auto f = [sptr_to_self, name, this](const runtime::TVMArgs& args, TVMRetValue* rv) {
auto* fsim = runtime::Registry::Get("tvm_callback_verilog_simulator");
CHECK(fsim != nullptr)

View file

@@ -16,7 +16,7 @@ namespace tvm {
namespace codegen {
/*! \brief Simulated device ram */
class VPIDeviceAPI : public runtime::DeviceAPI {
class VPIDeviceAPI final : public runtime::DeviceAPI {
public:
VPIDeviceAPI() {
static const size_t kAllocAlign = 32U;
@@ -44,6 +44,12 @@ class VPIDeviceAPI : public runtime::DeviceAPI {
if (ptr + size >= ram_max_) return nullptr;
return (char*)(&ram_[0]) + ptr; // NOLINT(*)
}
void SetDevice(int dev_id) final {}
void GetAttr(int dev_id, runtime::DeviceAttrKind kind, TVMRetValue* rv) final {
if (kind == runtime::kExist) {
*rv = 1;
}
}
void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final {
static const size_t kAllocAlign = 32U;
// always align to 32 bytes at least.
@@ -80,16 +86,18 @@ class VPIDeviceAPI : public runtime::DeviceAPI {
free_blocks_.insert({b.size, head});
}
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) final {
if (static_cast<int>(ctx_from.device_type) == kVPI) {
from = RealAddr(from, size);
from = RealAddr(static_cast<const char*>(from) + from_offset, size);
}
if (static_cast<int>(ctx_to.device_type) == kVPI) {
to = RealAddr(to, size);
to = RealAddr(static_cast<char*>(to) + to_offset, size);
}
memcpy(to, from, size);
}

View file

@@ -156,7 +156,7 @@ class ThreadAllreduceBuilder : public IRMutator {
seq.emplace_back(Store::make(
shared_buf, value,
BufIndex(reduce_index, group_index, reduce_extent)));
seq.emplace_back(SyncThread());
seq.emplace_back(SyncThread("shared"));
seq.emplace_back(MakeBufAllreduce(
combiner, value.type(), shared_buf,
reduce_index, group_index, reduce_extent, threadx_extent));
@@ -202,7 +202,7 @@ class ThreadAllreduceBuilder : public IRMutator {
reduce_align = reduce_align >> 1;
Expr cond = reduce_index < (reduce_extent - reduce_align);
seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread());
seq.emplace_back(SyncThread("shared"));
}
CHECK(threadx_extent >= 1 && warp_size_ >= 1);
// normal synchronization
@@ -211,7 +211,7 @@ class ThreadAllreduceBuilder : public IRMutator {
reduce_align = reduce_align >> 1;
Expr cond = reduce_index < reduce_align;
seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread());
seq.emplace_back(SyncThread("shared"));
}
// in warp synchronization.
std::vector<Stmt> in_warp_seq;
@@ -219,6 +219,7 @@ class ThreadAllreduceBuilder : public IRMutator {
while (reduce_align > 1) {
reduce_align = reduce_align >> 1;
in_warp_seq.emplace_back(freduce(reduce_align));
seq.emplace_back(SyncThread("warp"));
}
if (in_warp_seq.size() != 0) {
Stmt warp_body = MergeSeq(in_warp_seq);
@@ -249,10 +250,10 @@ class ThreadAllreduceBuilder : public IRMutator {
return ret;
}
// sync thread op.
static Stmt SyncThread() {
static Stmt SyncThread(const std::string& sync) {
return Evaluate::make(
Call::make(Int(32), intrinsic::tvm_storage_sync,
{StringImm::make("shared")},
{StringImm::make(sync)},
Call::Intrinsic));
}
// The local buffer index.

View file

@@ -244,6 +244,12 @@ LoweredFunc MakeAPI(Stmt body,
node, attr::device_context_id, device_id, nop));
seq_init.push_back(AttrStmt::make(
node, attr::device_context_type, device_type, nop));
Stmt set_device = IfThenElse::make(
device_type != kCPU, Evaluate::make(Call::make(
Int(32), intrinsic::tvm_call_packed,
{StringImm::make(runtime::symbol::tvm_set_device),
device_type, device_id}, Call::Intrinsic)));
body = Block::make(set_device, body);
}
n->body = MergeNest({seq_init, seq_check}, body);
LoweredFunc f(n);
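In effect, every packed entry function now begins with a guarded call equivalent to if (device_type != kCPU) __tvm_set_device(device_type, device_id), which replaces the per-module __tvm_entry_setdevice mechanism removed below.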

View file

@@ -162,20 +162,6 @@ class HostDeviceSplitter : public IRMutator {
std::shared_ptr<LoweredFuncNode> n =
std::make_shared<LoweredFuncNode>(*f.operator->());
n->body = this->Mutate(f->body);
if (f->is_packed_func && device_funcs_.size() != 0) {
// insert auto set device from device function.
Array<Expr> args = {StringImm::make(runtime::symbol::tvm_entry_setdevice)};
for (Var arg : f->args) {
args.push_back(arg);
}
n->body = Block::make(
Evaluate::make(Call::make(
Int(32), intrinsic::tvm_call_packed,
args, Call::Intrinsic)),
n->body);
}
Array<LoweredFunc> ret{LoweredFunc(n)};
for (LoweredFunc x : device_funcs_) {
ret.push_back(x);
@@ -193,14 +179,21 @@ class HostDeviceSplitter : public IRMutator {
m.visit_thread_extent_ = false;
n->body = m.Mutate(body);
n->name = os.str();
n->args = m.undefined_;
n->thread_axis = m.thread_axis_;
// improve the handle data type
for (Var arg : n->args) {
auto it = handle_data_type_.find(arg.get());
if (it != handle_data_type_.end()) {
n->handle_data_type.Set(arg, it->second);
// Strictly order the arguments: Var pointers, positional arguments.
for (Var v : m.undefined_) {
if (v.type().is_handle()) {
n->args.push_back(v);
// mark handle data type.
auto it = handle_data_type_.find(v.get());
if (it != handle_data_type_.end()) {
n->handle_data_type.Set(v, it->second);
}
}
}
for (Var v : m.undefined_) {
if (!v.type().is_handle()) {
n->args.push_back(v);
}
}
LoweredFunc f_device(n);
@@ -209,7 +202,6 @@ class HostDeviceSplitter : public IRMutator {
for (Var arg : n->args) {
call_args.push_back(arg);
}
for (Expr ext : m.thread_extent_) {
call_args.push_back(ext);
}

View file

@@ -54,14 +54,17 @@ class StorageSyncPlanner : public IRVisitor {
if (!in_device_env_) return;
if (const Call* call = op->value.as<Call>()) {
if (call->is_intrinsic(intrinsic::tvm_storage_sync)) {
StorageScope scope = StorageScope::make(call->args[0].as<StringImm>()->value);
if (scope.rank <= sync_scope_.rank) {
CHECK_EQ(curr_stmt_.access.size(), 0U);
curr_stmt_.access.emplace_back(
AccessEntry(nullptr, Expr(), kSync, scope));
// push to the scope
scope_.back().push_back(curr_stmt_);
curr_stmt_.access.clear();
const std::string& s = call->args[0].as<StringImm>()->value;
if (s != "warp") {
StorageScope scope = StorageScope::make(s);
if (scope.rank <= sync_scope_.rank) {
CHECK_EQ(curr_stmt_.access.size(), 0U);
curr_stmt_.access.emplace_back(
AccessEntry(nullptr, Expr(), kSync, scope));
// push to the scope
scope_.back().push_back(curr_stmt_);
curr_stmt_.access.clear();
}
}
}
}

View file

@@ -25,8 +25,11 @@ class DeviceAPIManager {
public:
static const int kMaxDeviceAPI = 32;
// Get API
static DeviceAPI* Get(TVMContext ctx) {
return Global()->GetAPI(ctx.device_type);
static DeviceAPI* Get(const TVMContext& ctx) {
return Get(ctx.device_type);
}
static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
return Global()->GetAPI(dev_type, allow_missing);
}
private:
@@ -42,20 +45,25 @@ class DeviceAPIManager {
return &inst;
}
// Get or initialize API.
DeviceAPI* GetAPI(DLDeviceType type) {
if (api_[type] != nullptr) return api_[type];
std::lock_guard<std::mutex> lock(mutex_);
if (api_[type] != nullptr) return api_[type];
std::string factory = "device_api." + DeviceName(type);
auto* f = Registry::Get(factory);
CHECK(f != nullptr)
<< "Device API " << DeviceName(type) << " is not enabled.";
void* ptr = (*f)();
api_[type] = static_cast<DeviceAPI*>(ptr);
return api_[type];
}
DeviceAPI* GetAPI(int type, bool allow_missing);
};
DeviceAPI* DeviceAPIManager::GetAPI(int type, bool allow_missing) {
if (api_[type] != nullptr) return api_[type];
std::lock_guard<std::mutex> lock(mutex_);
if (api_[type] != nullptr) return api_[type];
std::string factory = "device_api." + DeviceName(type);
auto* f = Registry::Get(factory);
if (f == nullptr) {
CHECK(allow_missing)
<< "Device API " << DeviceName(type) << " is not enabled.";
return nullptr;
}
void* ptr = (*f)();
api_[type] = static_cast<DeviceAPI*>(ptr);
return api_[type];
}
inline TVMArray* TVMArrayCreate_() {
TVMArray* arr = new TVMArray();
@@ -352,8 +360,9 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
<< "Can not copy across different ctx types directly";
}
DeviceAPIManager::Get(ctx)->CopyDataFromTo(
from->data, to->data, from_size,
from->ctx, to->ctx, stream);
from->data, from->byte_offset,
to->data, to->byte_offset,
from_size, from->ctx, to->ctx, stream);
API_END();
}
@@ -362,3 +371,29 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
API_END();
}
// set device api
TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
.set_body([](TVMArgs args, TVMRetValue *ret) {
int dev_type = args[0];
int dev_id = args[1];
DeviceAPIManager::Get(dev_type)->SetDevice(dev_id);
});
// get device attribute api
TVM_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
int dev_type = args[0];
int dev_id = args[1];
DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
if (kind == kExist) {
DeviceAPI* api = DeviceAPIManager::Get(dev_type, true);
if (api != nullptr) {
api->GetAttr(dev_id, kind, ret);
} else {
*ret = 0;
}
} else {
DeviceAPIManager::Get(dev_type)->GetAttr(dev_id, kind, ret);
}
});
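These registrations back the Python-side helpers (ctx.exist and friends go through _GetDeviceAttr); a hypothetical sketch of poking them directly over the FFI (assuming get_global_func is exposed as usual):

import tvm

fset = tvm.get_global_func("__tvm_set_device")
fset(1, 0)             # device_type=1 (cpu): a no-op, but exercises the path
fattr = tvm.get_global_func("_GetDeviceAttr")
print(fattr(8, 0, 0))  # does Metal device 0 exist? (attribute 0 = kExist)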

View file

@@ -11,8 +11,14 @@
namespace tvm {
namespace runtime {
class CPUDeviceAPI : public DeviceAPI {
class CPUDeviceAPI final : public DeviceAPI {
public:
void SetDevice(int dev_id) final {}
void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final {
if (kind == kExist) {
*rv = 1;
}
}
void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final {
void* ptr;
#if _MSC_VER
@@ -34,12 +40,16 @@ class CPUDeviceAPI : public DeviceAPI {
}
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) final {
memcpy(to, from, size);
memcpy(static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset,
size);
}
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {

View file

@@ -15,8 +15,33 @@
namespace tvm {
namespace runtime {
class CUDADeviceAPI : public DeviceAPI {
class CUDADeviceAPI final : public DeviceAPI {
public:
void SetDevice(int dev_id) final {
CUDA_CALL(cudaSetDevice(dev_id));
}
void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final {
int value;
switch (kind) {
case kExist:
value = (
cudaDeviceGetAttribute(
&value, cudaDevAttrMaxThreadsPerBlock, dev_id)
== cudaSuccess);
break;
case kMaxThreadsPerBlock: {
CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrMaxThreadsPerBlock, dev_id));
break;
}
case kWarpSize: {
CUDA_CALL(cudaDeviceGetAttribute(
&value, cudaDevAttrWarpSize, dev_id));
break;
}
}
*rv = value;
}
void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final {
CUDA_CALL(cudaSetDevice(ctx.device_id));
CHECK_EQ(256 % alignment, 0U)
@@ -32,12 +57,16 @@ class CUDADeviceAPI : public DeviceAPI {
}
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) final {
cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
from = static_cast<const char*>(from) + from_offset;
to = static_cast<char*>(to) + to_offset;
if (ctx_from.device_type == kGPU && ctx_to.device_type == kGPU) {
CUDA_CALL(cudaSetDevice(ctx_from.device_id));
if (ctx_from.device_id == ctx_to.device_id) {

View file

@@ -14,7 +14,7 @@
#include <string>
#include <mutex>
#include "./cuda_common.h"
#include "../void_addr_args.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"
@@ -214,41 +214,13 @@ class CUDAPrepGlobalBarrier {
mutable std::array<CUdeviceptr, kMaxNumGPUs> pcache_;
};
void AutoSetCUDADevice(const TVMArgs& args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 3);
TVMValue* values = static_cast<TVMValue*>(args[0].operator void*());
int* type_codes = static_cast<int*>(args[1].operator void*());
int num_args = args[2].operator int();
int device_id = -1;
for (int i = 0; i < num_args; ++i) {
if (type_codes[i] == kArrayHandle) {
TVMContext ctx = static_cast<TVMArray*>(values[i].v_handle)->ctx;
CHECK_EQ(ctx.device_type, kGPU)
<< "All operands need to be GPU";
if (device_id == -1) {
device_id = ctx.device_id;
} else {
CHECK_EQ(device_id, ctx.device_id)
<< "Operands come from different devices";
}
}
}
CHECK_NE(device_id, -1)
<< "Cannot detect device id from list";
CUDA_CALL(cudaSetDevice(device_id));
}
PackedFunc CUDAModuleNode::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main)
<< "Device functions do not have main";
if (name == symbol::tvm_entry_setdevice) {
return PackedFunc(AutoSetCUDADevice);
} else if (name == symbol::tvm_prepare_global_barrier) {
if (name == symbol::tvm_prepare_global_barrier) {
return PackedFunc(CUDAPrepGlobalBarrier(this, sptr_to_self));
}
auto it = fmap_.find(name);
@@ -256,7 +228,7 @@ PackedFunc CUDAModuleNode::GetFunction(
const FunctionInfo& info = it->second;
CUDAWrappedFunc f;
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
return PackFromVoidAddrArgs(f, info.arg_types);
return PackFuncVoidAddr(f, info.arg_types);
}
Module CUDAModuleCreate(

View file

@@ -13,10 +13,29 @@
namespace tvm {
namespace runtime {
enum DeviceAttrKind : int {
kExist = 0,
kMaxThreadsPerBlock = 1,
kWarpSize = 2
};
class DeviceAPI {
public:
/*! \brief virtual destructor */
virtual ~DeviceAPI() {}
/*!
* \brief Set the environment device id to dev_id
* \param dev_id The device id.
*/
virtual void SetDevice(int dev_id) = 0;
/*!
* \brief Get attribute of specified device.
* \param dev_id The device id
* \param kind The result kind
* \param rv The return value.
*/
virtual void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) = 0;
/*!
* \brief Allocate a data space on device.
* \param ctx The device context to perform operation.
@@ -36,13 +55,18 @@ class DeviceAPI {
* \brief copy data from one place to another
* \param dev The device to perform operation.
* \param from The source array.
* \param from_offset The byte offset in the source array.
* \param to The target array.
* \param to_offset The byte offset in the target array.
* \param size The size of the memory
* \param ctx_from The source context
* \param ctx_to The target context
* \param stream Optional stream object.
*/
virtual void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
@@ -59,11 +83,12 @@ class DeviceAPI {
* \brief The name of Device API factory.
* \param type The device type.
*/
inline std::string DeviceName(DLDeviceType type) {
switch (static_cast<int>(type)) {
inline std::string DeviceName(int type) {
switch (type) {
case kCPU: return "cpu";
case kGPU: return "gpu";
case kOpenCL: return "opencl";
case kMetal: return "metal";
case kVPI: return "vpi";
default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
}

View file

@@ -30,7 +30,6 @@ struct FunctionInfo {
void Save(dmlc::JSONWriter *writer) const;
void Load(dmlc::JSONReader *reader);
};
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_META_DATA_H_

View file

@@ -0,0 +1,98 @@
/*!
* Copyright (c) 2017 by Contributors
* \file metal_common.h
* \brief Metal common header
*/
#ifndef TVM_RUNTIME_METAL_METAL_COMMON_H_
#define TVM_RUNTIME_METAL_METAL_COMMON_H_
#import <Metal/MTLBuffer.h>
#import <Metal/MTLCommandQueue.h>
#import <Metal/MTLCommandBuffer.h>
#import <Metal/MTLBlitCommandEncoder.h>
#import <Metal/MTLDevice.h>
#import <Metal/MTLLibrary.h>
#include <tvm/runtime/config.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <dmlc/logging.h>
#include <mutex>
#include <string>
#include <vector>
#include "../device_api.h"
namespace tvm {
namespace runtime {
namespace metal {
/*!
* \brief Process global Metal workspace.
*/
class MetalWorkspace final : public DeviceAPI {
public:
// the devices
std::vector<id<MTLDevice> > devices;
// the queues
std::vector<id<MTLCommandQueue> > queues;
// Warp size constant
std::vector<int> warp_size;
// Whether it is initialized.
bool initialized_{false};
// the mutex for initialization
std::mutex mutex;
// Get command queue for given context.
id<MTLCommandQueue> GetCommandQueue(TVMContext ctx) {
CHECK_EQ(ctx.device_type, kMetal);
CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
<< "Invalid Metal device_id=" << ctx.device_id;
return queues[ctx.device_id];
}
// Get device for given context
id<MTLDevice> GetDevice(TVMContext ctx) {
CHECK_EQ(ctx.device_type, kMetal);
CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < devices.size())
<< "Invalid Metal device_id=" << ctx.device_id;
return devices[ctx.device_id];
}
// Initialize workspace
// Safe to call multiple times; only the first call takes effect.
void Init();
// override device API
void SetDevice(int dev_id) final;
void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final;
void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final;
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
// get the global workspace
static MetalWorkspace* Global();
};
/*! \brief Thread local workspace */
class MetalThreadEntry {
public:
/*! \brief The current context */
TVMContext context;
/*! \brief The shared buffer used for copy. */
std::vector<id<MTLBuffer> > temp_buffer_;
MetalThreadEntry() {
context.device_id = 0;
context.device_type = static_cast<DLDeviceType>(kMetal);
}
// Get a temp buffer of at least the given size for ctx.
id<MTLBuffer> GetTempBuffer(TVMContext ctx, size_t size);
// get the global workspace
static MetalThreadEntry* ThreadLocal();
};
} // namespace metal
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_METAL_METAL_COMMON_H_

View file

@@ -0,0 +1,240 @@
/*!
* Copyright (c) 2017 by Contributors
* \file metal_device_api.mm
*/
#include "./metal_common.h"
#if TVM_METAL_RUNTIME
#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>
namespace tvm {
namespace runtime {
namespace metal {
MetalWorkspace* MetalWorkspace::Global() {
static MetalWorkspace inst;
return &inst;
}
void MetalWorkspace::GetAttr(
int dev_id, DeviceAttrKind kind, TVMRetValue* rv) {
this->Init();
size_t index = static_cast<size_t>(dev_id);
if (kind == kExist) {
*rv = int(index < devices.size());
return;
}
CHECK_LT(index, devices.size())
<< "Invalid device id " << index;
switch (kind) {
case kMaxThreadsPerBlock: {
*rv = static_cast<int>(
[devices[dev_id] maxThreadsPerThreadgroup].width);
break;
}
case kWarpSize: {
// Set warp size to 1 for safety reasons.
*rv = 1;
break;
}
case kExist: break;
}
}
static const char* kDummyKernel = R"A0B0(
using namespace metal;
// Simple copy kernel
// Just to get threadExecutionWidth from current Metal API.
kernel void CopyKernel(
device float* dst [[buffer(0)]],
device float* src [[buffer(1)]],
ushort2 gid[[thread_position_in_grid]]) {
dst[gid.x] = src[gid.x];
}
)A0B0";
// Hack to get the warp size from the device.
// Note that in Metal,
// state.threadExecutionWidth can vary per kernel,
// maybe due to resource constraints,
// so state.threadExecutionWidth can be smaller than the warp size.
// To be safe, turn off warp-aware optimization for now,
// but keep this code around.
int GetWarpSize(id<MTLDevice> dev) {
NSError* error_msg = nil;
id<MTLLibrary> lib =
[dev
newLibraryWithSource:
[NSString stringWithUTF8String:kDummyKernel]
options:nil
error:&error_msg];
CHECK(lib != nil) << error_msg;
id<MTLFunction> f =
[lib
newFunctionWithName:
[NSString stringWithUTF8String:"CopyKernel"]];
CHECK(f != nil);
id<MTLComputePipelineState> state =
[dev
newComputePipelineStateWithFunction:f
error:&error_msg];
CHECK(state != nil) << error_msg;
return state.threadExecutionWidth;
}
void MetalWorkspace::Init() {
if (initialized_) return;
std::lock_guard<std::mutex> lock(this->mutex);
if (initialized_) return;
initialized_ = true;
if (devices.size() != 0) return;
NSArray<id<MTLDevice>>* devs = MTLCopyAllDevices();
for (size_t i = 0; i < devs.count; ++i) {
id<MTLDevice> d = [devs objectAtIndex:i];
devices.push_back(d);
queues.push_back([d newCommandQueue]);
LOG(INFO) << "Intializing Metal device " << i
<< ", name=" << d.name;
warp_size.push_back(GetWarpSize(d));
}
}
void MetalWorkspace::SetDevice(int dev_id) {
MetalThreadEntry::ThreadLocal()->context.device_id = dev_id;
}
void* MetalWorkspace::AllocDataSpace(
TVMContext ctx, size_t size, size_t alignment) {
this->Init();
id<MTLDevice> dev = GetDevice(ctx);
// allocate buffer in GPU-only mode.
id<MTLBuffer> buf = [
dev newBufferWithLength:size
options:MTLResourceStorageModePrivate];
// retain under ARC to keep the buffer alive until it is explicitly released.
return (__bridge_retained void*)(buf);
}
void MetalWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
// release the ptr.
CFBridgingRelease(ptr);
}
void MetalWorkspace::CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) {
this->Init();
CHECK(stream == nullptr);
TVMContext ctx = ctx_from;
if (ctx_from.device_type == kCPU) ctx = ctx_to;
id<MTLCommandQueue> queue = GetCommandQueue(ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
int from_dev_type = static_cast<int>(ctx_from.device_type);
int to_dev_type = static_cast<int>(ctx_to.device_type);
if (from_dev_type == kMetal && to_dev_type == kMetal) {
CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
<< "Metal disallow cross device copy.";
[encoder copyFromBuffer:(__bridge id<MTLBuffer>)(from)
sourceOffset:from_offset
toBuffer:(__bridge id<MTLBuffer>)(to)
destinationOffset:to_offset
size:size];
[encoder endEncoding];
[cb commit];
} else if (from_dev_type == kMetal && to_dev_type == kCPU) {
// Stage through a CPU-visible buffer when the source is not host-accessible.
id<MTLBuffer> from_buf = (__bridge id<MTLBuffer>)(from);
if (from_buf.storageMode != MTLStorageModeShared) {
id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
->GetTempBuffer(ctx_from, size);
[encoder copyFromBuffer:from_buf
sourceOffset:from_offset
toBuffer:temp
destinationOffset:0
size:size];
[encoder endEncoding];
[cb commit];
[cb waitUntilCompleted];
memcpy(static_cast<char*>(to) + to_offset,
static_cast<char*>([temp contents]),
size);
} else {
memcpy(static_cast<char*>(to) + to_offset,
static_cast<char*>([from_buf contents]) + from_offset,
size);
}
} else if (from_dev_type == kCPU && to_dev_type == kMetal) {
id<MTLBuffer> to_buf = (__bridge id<MTLBuffer>)(to);
if (to_buf.storageMode != MTLStorageModeShared) {
id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
->GetTempBuffer(ctx_to, size);
memcpy([temp contents],
static_cast<const char*>(from) + from_offset,
size);
[encoder copyFromBuffer:temp
sourceOffset:0
toBuffer:to_buf
destinationOffset:to_offset
size:size];
[encoder endEncoding];
[cb commit];
} else {
memcpy(static_cast<char*>([to_buf contents]) + to_offset,
static_cast<const char*>(from) + from_offset,
size);
}
} else {
LOG(FATAL) << "Expect copy from/to Metal or between Metal"
<< ", from=" << from_dev_type
<< ", to=" << to_dev_type;
}
}
void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
CHECK(stream == nullptr);
// commit an empty command buffer and wait until it completes.
id<MTLCommandQueue> queue = GetCommandQueue(ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
[cb commit];
[cb waitUntilCompleted];
}
id<MTLBuffer> MetalThreadEntry::GetTempBuffer(TVMContext ctx, size_t size) {
if (temp_buffer_.size() <= static_cast<size_t>(ctx.device_id)) {
temp_buffer_.resize(ctx.device_id + 1, nil);
}
if (temp_buffer_[ctx.device_id] == nil ||
temp_buffer_[ctx.device_id].length < size) {
id<MTLDevice> dev = MetalWorkspace::Global()->GetDevice(ctx);
temp_buffer_[ctx.device_id] = [
dev newBufferWithLength:size
options:MTLResourceStorageModeShared];
}
return temp_buffer_[ctx.device_id];
}
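// The per-device staging buffer only grows and is reused across copies,
// so repeated host<->device transfers do not reallocate a shared buffer
// each time.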
typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;
MetalThreadEntry* MetalThreadEntry::ThreadLocal() {
return MetalThreadStore::Get();
}
TVM_REGISTER_GLOBAL("device_api.metal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = MetalWorkspace::Global();
*rv = static_cast<void*>(ptr);
});
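// Consumer-side sketch (an assumption based on the registry pattern, not
// code from this commit): the generic device API layer can recover the
// workspace with roughly
//   const PackedFunc* f = runtime::Registry::Get("device_api.metal");
//   void* handle = (*f)();
//   DeviceAPI* api = static_cast<DeviceAPI*>(handle);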
} // namespace metal
} // namespace runtime
} // namespace tvm
#endif // TVM_METAL_RUNTIME


@ -0,0 +1,36 @@
/*!
* Copyright (c) 2017 by Contributors
* \file metal_module.h
* \brief Execution handling of Metal kernels
*/
#ifndef TVM_RUNTIME_METAL_METAL_MODULE_H_
#define TVM_RUNTIME_METAL_METAL_MODULE_H_
#include <tvm/runtime/config.h>
#include <tvm/runtime/packed_func.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "../meta_data.h"
namespace tvm {
namespace runtime {
/*! \brief Maximum number of GPUs supported in MetalModule. */
static constexpr const int kMetalMaxNumDevice = 32;
/*!
* \brief create a metal module from data.
*
* \param data The data content.
* \param fmt The format of the data, can be "metal" or "metallib"
* \param fmap The function information map for each function.
* \param source Optional, the source code.
* \return The created Metal module.
*/
Module MetalModuleCreate(
std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string source);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_METAL_METAL_MODULE_H_
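A minimal usage sketch of the declaration above (a sketch only, not code from this commit; the variable names are hypothetical and fmap would normally be produced by codegen):

// Sketch only, assuming the declaration above.
#include <string>
#include <unordered_map>
#include <tvm/runtime/packed_func.h>
#include "metal_module.h"

tvm::runtime::Module MakeMetalModule(
    const std::string& metal_src,
    const std::unordered_map<std::string, tvm::runtime::FunctionInfo>& fmap) {
  // "metal" marks textual Metal source; "metallib" would mark a precompiled blob.
  return tvm::runtime::MetalModuleCreate(metal_src, "metal", fmap, metal_src);
}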


@ -0,0 +1,273 @@
/*!
* Copyright (c) 2017 by Contributors
* \file metal_module.cc
*/
#include "./metal_module.h"
#if TVM_METAL_RUNTIME
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <array>
#include <string>
#include <mutex>
#include "./metal_common.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"
namespace tvm {
namespace runtime {
// Module to support thread-safe multi-GPU execution.
// A MTLLibrary is compiled per device.
// The runtime will contain a per-device module table.
// The modules will be lazily loaded.
class MetalModuleNode final : public runtime::ModuleNode {
public:
explicit MetalModuleNode(std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string source)
: data_(data), fmt_(fmt), fmap_(fmap), source_(source) {
}
const char* type_key() const final {
return "metal";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
GetPipelineState(ctx.device_id, name);
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
void SaveToFile(const std::string& file_name,
const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
CHECK_EQ(fmt, fmt_)
<< "Can only save to format=" << fmt_;
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
}
std::string GetSource(const std::string& format) final {
if (format == fmt_) return data_;
if (source_.length() != 0) {
return source_;
} else if (fmt_ == "metal") {
return data_;
} else {
return "";
}
}
// Get the MTLComputePipelineState for func_name on device_id.
id<MTLComputePipelineState> GetPipelineState(
size_t device_id, const std::string& func_name) {
metal::MetalWorkspace* w = metal::MetalWorkspace::Global();
CHECK_LT(device_id, w->devices.size());
// start lock scope.
std::lock_guard<std::mutex> lock(mutex_);
if (finfo_.size() <= device_id) {
finfo_.resize(device_id + 1, DeviceEntry());
}
DeviceEntry& e = finfo_[device_id];
auto it = e.smap.find(func_name);
if (it != e.smap.end()) return it->second;
// compile
NSError* err_msg = nil;
if (e.lib == nil) {
if (fmt_ == "metal") {
e.lib = [
w->devices[device_id]
newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()]
options:nil
error:&err_msg];
if (err_msg != nil || e.lib == nil) {
LOG(FATAL) << "Fail to compile metal lib:"
<< [[err_msg localizedDescription] UTF8String];
}
} else {
// Build from library.
auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL);
auto data = dispatch_data_create(
data_.c_str(), data_.length(), q, ^{});
e.lib = [
w->devices[device_id]
newLibraryWithData:data
error:&err_msg];
if (err_msg != nil || e.lib == nil) {
LOG(FATAL) << "Fail to compile metal lib:"
<< [[err_msg localizedDescription] UTF8String];
}
}
}
id<MTLFunction> f = [
e.lib
newFunctionWithName:
[NSString stringWithUTF8String:func_name.c_str()]];
CHECK(f != nil) << "cannot find function " << func_name;
id<MTLComputePipelineState> state =
[w->devices[device_id]
newComputePipelineStateWithFunction:f
error:&err_msg];
CHECK(state != nil)
<< "cannot get pipeline state for function " << func_name << ": "
<< [[err_msg localizedDescription] UTF8String];
// The state.threadExecutionWidth can change dynamically according
// to the resource constraints in the kernel, so the equality below does
// not strictly hold. Turn off warp-aware optimization for now.
// CHECK_EQ(state.threadExecutionWidth, w->warp_size[device_id]);
e.smap[func_name] = state;
return state;
}
private:
// device specific entry
struct DeviceEntry {
// library
id<MTLLibrary> lib = nil;
// state cache;
std::unordered_map<std::string, id<MTLComputePipelineState> > smap;
};
// the binary data
std::string data_;
// The format
std::string fmt_;
// function information table.
std::unordered_map<std::string, FunctionInfo> fmap_;
// The source
std::string source_;
// function information.
std::vector<DeviceEntry> finfo_;
// internal mutex when updating the module
std::mutex mutex_;
};
// A wrapped function class to get a packed func.
class MetalWrappedFunc {
public:
// initialize the METAL function.
void Init(MetalModuleNode* m,
std::shared_ptr<ModuleNode> sptr,
const std::string& func_name,
size_t num_buffer_args,
size_t num_pack_args,
const std::vector<std::string>& thread_axis_tags) {
w_ = metal::MetalWorkspace::Global();
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
num_buffer_args_ = num_buffer_args;
num_pack_args_ = num_pack_args;
std::fill(scache_.begin(), scache_.end(), (id<MTLComputePipelineState>)nil);
thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
int dev_id = t->context.device_id;
scache_[dev_id] = m->GetPipelineState(dev_id, func_name);
}
// invoke the function with void arguments
void operator()(TVMArgs args,
TVMRetValue* rv,
const ArgUnion* pack_args) const {
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
int device_id = t->context.device_id;
if (scache_[device_id] == nil) {
scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
}
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
id<MTLCommandQueue> queue = w_->GetCommandQueue(t->context);
id<MTLCommandBuffer> cb = [queue commandBuffer];
id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
[encoder setComputePipelineState:scache_[device_id]];
for (size_t i = 0; i < num_buffer_args_; ++i) {
void* buf = args[i];
[encoder setBuffer:(__bridge id<MTLBuffer>)(buf) offset:0 atIndex:i];
}
if (num_pack_args_ != 0) {
[encoder setBytes:pack_args
length:num_pack_args_ * sizeof(ArgUnion)
atIndex:num_buffer_args_];
}
// launch
MTLSize dimGrid = MTLSizeMake(
wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
MTLSize dimBlock = MTLSizeMake(
wl.block_dim(0), wl.block_dim(1), wl.block_dim(2));
[encoder dispatchThreadgroups: dimGrid
threadsPerThreadgroup: dimBlock];
[encoder endEncoding];
[cb commit];
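// The command buffer is committed asynchronously; completion is observed
// later, e.g. via MetalWorkspace::StreamSync, which commits an empty
// command buffer on the same queue and waits for it.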
}
private:
// Reference to global workspace.
metal::MetalWorkspace* w_;
// internal module
MetalModuleNode* m_;
// the resource holder
std::shared_ptr<ModuleNode> sptr_;
// The name of the function.
std::string func_name_;
// Number of buffer arguments
size_t num_buffer_args_;
// number of packed arguments.
size_t num_pack_args_;
// Device state cache per device.
// mark as mutable, to enable lazy initialization
mutable std::array<id<MTLComputePipelineState>, kMetalMaxNumDevice> scache_;
// thread axis configuration
ThreadAxisConfig thread_axis_cfg_;
};
PackedFunc MetalModuleNode::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main)
<< "Device function do not have main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
MetalWrappedFunc f;
size_t num_buffer_args = NumBufferArgs(info.arg_types);
f.Init(this, sptr_to_self, name,
num_buffer_args, info.arg_types.size() - num_buffer_args,
info.thread_axis_tags);
return PackFuncNonBufferArg(f, info.arg_types);
}
Module MetalModuleCreate(
std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string source) {
metal::MetalWorkspace* w = metal::MetalWorkspace::Global();
w->Init();
std::shared_ptr<MetalModuleNode> n =
std::make_shared<MetalModuleNode>(data, fmt, fmap, source);
return Module(n);
}
// Load module from file.
Module MetalModuleLoad(const std::string& file_name,
const std::string& format) {
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
LoadBinaryFromFile(file_name, &data);
LoadMetaDataFromFile(meta_file, &fmap);
return MetalModuleCreate(data, fmt, fmap, "");
}
TVM_REGISTER_GLOBAL("module.loadfile_metal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = MetalModuleLoad(args[0], args[1]);
});
} // namespace runtime
} // namespace tvm
#endif // TVM_METAL_RUNTIME


@ -84,17 +84,25 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
}
bool RuntimeEnabled(const std::string& target) {
std::string load_f_name;
std::string f_name;
if (target == "cpu") {
return true;
} else if (target == "cuda" || target == "gpu") {
load_f_name = "module.loadfile_ptx";
f_name = "device_api.gpu";
} else if (target == "cl" || target == "opencl") {
load_f_name = "module.loadfile_cl";
f_name = "device_api.opencl";
} else if (target == "mtl" || target == "metal") {
f_name = "device_api.metal";
} else if (target == "stackvm") {
f_name = "codegen.build_stackvm";
} else if (target == "llvm") {
f_name = "codegen.build_llvm";
} else if (target == "vpi" || target == "verilog") {
f_name = "device_api.vpi";
} else {
LOG(FATAL) << "Unknown optional runtime " << target;
}
return runtime::Registry::Get(load_f_name) != nullptr;
return runtime::Registry::Get(f_name) != nullptr;
}
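// Illustration: RuntimeEnabled("metal") is true exactly when the Metal
// runtime was compiled in, since only then does the
// TVM_REGISTER_GLOBAL("device_api.metal") registration shown earlier run
// at static initialization and make the registry lookup succeed.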
TVM_REGISTER_GLOBAL("module._Enabled")


@ -101,12 +101,14 @@ inline const char* CLGetErrorString(cl_int error) {
/*!
* \brief Process global OpenCL workspace.
*/
class OpenCLWorkspace : public DeviceAPI {
class OpenCLWorkspace final : public DeviceAPI {
public:
// global platform id
cl_platform_id platform_id;
// global context of this process
cl_context context{nullptr};
// whether the workspace is initialized.
bool initialized_{false};
// the devices
std::vector<cl_device_id> devices;
// the queues
@ -126,24 +128,25 @@ class OpenCLWorkspace : public DeviceAPI {
OPENCL_CALL(clReleaseContext(context));
}
}
// whether the workspace is initialized.
inline bool initialized() const {
return context != nullptr;
}
// Initialize the device.
void Init();
// get the queue of the context
cl_command_queue GetQueue(TVMContext ctx) const {
cl_command_queue GetQueue(TVMContext ctx) {
CHECK_EQ(ctx.device_type, kOpenCL);
CHECK(initialized())
<< "The OpenCL is not initialized";
this->Init();
CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
<< "Invalid OpenCL device_id=" << ctx.device_id;
return queues[ctx.device_id];
}
// override device API
void SetDevice(int dev_id) final;
void GetAttr(int dev_id, DeviceAttrKind kind, TVMRetValue* rv) final;
void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final;
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,


@ -18,8 +18,41 @@ OpenCLWorkspace* OpenCLWorkspace::Global() {
return &inst;
}
void OpenCLWorkspace::SetDevice(int dev_id) {
OpenCLThreadEntry::ThreadLocal()->context.device_id = dev_id;
}
void OpenCLWorkspace::GetAttr(
int dev_id, DeviceAttrKind kind, TVMRetValue* rv) {
this->Init();
size_t index = static_cast<size_t>(dev_id);
if (kind == kExist) {
*rv = static_cast<int>(index < devices.size());
return;
}
CHECK_LT(index, devices.size())
<< "Invalid device id " << index;
size_t value;
switch (kind) {
case kMaxThreadsPerBlock: {
OPENCL_CALL(clGetDeviceInfo(
devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE,
sizeof(size_t), &value, nullptr));
*rv = static_cast<int64_t>(value);
break;
}
case kWarpSize: {
*rv = 1;
break;
}
case kExist: break;
}
}
void* OpenCLWorkspace::AllocDataSpace(
TVMContext ctx, size_t size, size_t alignment) {
this->Init();
CHECK(context != nullptr) << "No OpenCL device";
cl_int err_code;
cl_mem mptr = clCreateBuffer(
this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code);
@ -33,30 +66,35 @@ void OpenCLWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
}
void OpenCLWorkspace::CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) {
this->Init();
CHECK(stream == nullptr);
if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kOpenCL) {
OPENCL_CALL(clEnqueueCopyBuffer(
this->GetQueue(ctx_to),
static_cast<cl_mem>((void*)from), // NOLINT(*)
static_cast<cl_mem>(to),
0, 0, size, 0, nullptr, nullptr));
from_offset, to_offset, size, 0, nullptr, nullptr));
} else if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kCPU) {
OPENCL_CALL(clEnqueueReadBuffer(
this->GetQueue(ctx_from),
static_cast<cl_mem>((void*)from), // NOLINT(*)
CL_FALSE, 0, size, to,
CL_FALSE, from_offset, size,
static_cast<char*>(to) + to_offset,
0, nullptr, nullptr));
OPENCL_CALL(clFinish(this->GetQueue(ctx_from)));
} else if (ctx_from.device_type == kCPU && ctx_to.device_type == kOpenCL) {
OPENCL_CALL(clEnqueueWriteBuffer(
this->GetQueue(ctx_to),
static_cast<cl_mem>(to),
CL_FALSE, 0, size, from,
CL_FALSE, to_offset, size,
static_cast<const char*>(from) + from_offset,
0, nullptr, nullptr));
OPENCL_CALL(clFinish(this->GetQueue(ctx_to)));
} else {
@ -97,8 +135,9 @@ std::string GetDeviceInfo(
std::vector<cl_platform_id> GetPlatformIDs() {
cl_uint ret_size;
OPENCL_CALL(clGetPlatformIDs(0, nullptr, &ret_size));
cl_int code = clGetPlatformIDs(0, nullptr, &ret_size);
std::vector<cl_platform_id> ret;
if (code != CL_SUCCESS) return ret;
ret.resize(ret_size);
OPENCL_CALL(clGetPlatformIDs(ret_size, &ret[0], nullptr));
return ret;
@ -108,11 +147,12 @@ std::vector<cl_device_id> GetDeviceIDs(
cl_platform_id pid, std::string device_type) {
cl_device_type dtype = CL_DEVICE_TYPE_ALL;
if (device_type == "cpu") dtype = CL_DEVICE_TYPE_CPU;
if (device_type == "gpu") dtype = CL_DEVICE_TYPE_CPU;
if (device_type == "gpu") dtype = CL_DEVICE_TYPE_GPU;
if (device_type == "accelerator") dtype = CL_DEVICE_TYPE_ACCELERATOR;
cl_uint ret_size;
OPENCL_CALL(clGetDeviceIDs(pid, dtype, 0, nullptr, &ret_size));
cl_int code = clGetDeviceIDs(pid, dtype, 0, nullptr, &ret_size);
std::vector<cl_device_id> ret;
if (code != CL_SUCCESS) return ret;
ret.resize(ret_size);
OPENCL_CALL(clGetDeviceIDs(pid, dtype, ret_size, &ret[0], nullptr));
return ret;
@ -127,70 +167,53 @@ bool MatchPlatformInfo(
return param_value.find(value) != std::string::npos;
}
bool InitOpenCL(TVMArgs args, TVMRetValue* rv) {
cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
std::lock_guard<std::mutex>(w->mu);
if (w->initialized()) return false;
// matching conditions
std::string platform_name, device_type;
for (int i = 0; i < args.num_args; ++i) {
std::string arg = args[i];
size_t pos = arg.find_first_of('=');
CHECK_EQ(pos, std::string::npos)
<< "Argumentes need to be key=value";
std::string key = arg.substr(0, pos);
std::string val = arg.substr(pos + 1, arg.length() - pos - 1);
if (key == "platform_name") {
platform_name = val;
} else if (key == "device_type") {
device_type = val;
} else {
LOG(FATAL) << "unknown DeviceInit option " << key;
}
}
void OpenCLWorkspace::Init() {
if (initialized_) return;
std::lock_guard<std::mutex> lock(this->mu);
if (initialized_) return;
initialized_ = true;
if (context != nullptr) return;
// matched platforms
std::vector<cl_platform_id> platform_matched;
for (cl_platform_id pid : cl::GetPlatformIDs()) {
bool matched = true;
if (!cl::MatchPlatformInfo(pid, CL_PLATFORM_NAME, platform_name)) matched = false;
if (matched) platform_matched.push_back(pid);
}
std::vector<cl_platform_id> platform_matched = cl::GetPlatformIDs();
if (platform_matched.size() == 0) {
LOG(FATAL) << "No OpenCL platform matched given existing options ...";
LOG(WARNING) << "No OpenCL platform matched given existing options ...";
return;
}
if (platform_matched.size() > 1) {
LOG(WARNING) << "Multiple OpenCL platforms matched, use the first one ... ";
}
w->platform_id = platform_matched[0];
this->platform_id = platform_matched[0];
LOG(INFO) << "Initialize OpenCL platform \'"
<< cl::GetPlatformInfo(w->platform_id, CL_PLATFORM_NAME) << '\'';
<< cl::GetPlatformInfo(this->platform_id, CL_PLATFORM_NAME) << '\'';
std::vector<cl_device_id> devices_matched =
cl::GetDeviceIDs(w->platform_id, device_type);
CHECK_GT(devices_matched.size(), 0U)
<< "No OpenCL device any device matched given the options";
w->devices = devices_matched;
cl::GetDeviceIDs(this->platform_id, "gpu");
if (devices_matched.size() == 0) {
LOG(WARNING) << "No OpenCL device any device matched given the options";
return;
}
this->devices = devices_matched;
cl_int err_code;
w->context = clCreateContext(
nullptr, w->devices.size(), &(w->devices[0]),
this->context = clCreateContext(
nullptr, this->devices.size(), &(this->devices[0]),
nullptr, nullptr, &err_code);
OPENCL_CHECK_ERROR(err_code);
CHECK_EQ(w->queues.size(), 0U);
for (size_t i = 0; i < w->devices.size(); ++i) {
cl_device_id did = w->devices[i];
w->queues.push_back(
clCreateCommandQueue(w->context, did, 0, &err_code));
CHECK_EQ(this->queues.size(), 0U);
for (size_t i = 0; i < this->devices.size(); ++i) {
cl_device_id did = this->devices[i];
this->queues.push_back(
clCreateCommandQueue(this->context, did, 0, &err_code));
OPENCL_CHECK_ERROR(err_code);
LOG(INFO) << "opencl(" << i
<< ")=\'" << cl::GetDeviceInfo(did, CL_DEVICE_NAME)
<< "\' cl_device_id=" << did;
}
return true;
}
TVM_REGISTER_GLOBAL("module.init_opencl")
.set_body(InitOpenCL);
bool InitOpenCL(TVMArgs args, TVMRetValue* rv) {
cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
w->Init();
return true;
}
TVM_REGISTER_GLOBAL("device_api.opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) {


@ -11,7 +11,7 @@
#include <vector>
#include <string>
#include <unordered_map>
#include "../void_addr_args.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"
@ -90,7 +90,8 @@ class OpenCLModuleNode : public ModuleNode {
// Initialize the programs
void InitProgram() {
cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
CHECK(w->initialized());
w->Init();
CHECK(w->context != nullptr) << "No OpenCL device";
if (fmt_ == "cl") {
const char* s = data_.c_str();
size_t len = data_.length();
@ -179,6 +180,7 @@ class OpenCLWrappedFunc {
std::string func_name,
std::vector<size_t> arg_size,
const std::vector<std::string>& thread_axis_tags) {
w_ = cl::OpenCLWorkspace::Global();
m_ = m;
sptr_ = sptr;
entry_ = entry;
@ -190,9 +192,7 @@ class OpenCLWrappedFunc {
void operator()(TVMArgs args,
TVMRetValue* rv,
void** void_args) const {
cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
cl::OpenCLThreadEntry* t = cl::OpenCLThreadEntry::ThreadLocal();
CHECK(w->initialized());
// get the kernel from thread local kernel table.
if (entry_.kernel_id >= t->kernel_table.size()) {
t->kernel_table.resize(entry_.kernel_id + 1);
@ -200,13 +200,13 @@ class OpenCLWrappedFunc {
const auto& e = t->kernel_table[entry_.kernel_id];
cl_kernel kernel = e.kernel;
if (kernel == nullptr || e.version != entry_.version) {
kernel = m_->InstallKernel(w, t, func_name_, entry_);
kernel = m_->InstallKernel(w_, t, func_name_, entry_);
}
// setup arguments.
for (cl_uint i = 0; i < arg_size_.size(); ++i) {
OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], void_args[i]));
}
cl_command_queue queue = w->GetQueue(t->context);
cl_command_queue queue = w_->GetQueue(t->context);
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
cl_uint work_dim = static_cast<cl_uint>(thread_axis_cfg_.work_dim());
for (cl_uint i = 0; i < work_dim; ++i) {
@ -221,6 +221,8 @@ class OpenCLWrappedFunc {
}
private:
// global workspace.
cl::OpenCLWorkspace* w_;
// The module
OpenCLModuleNode* m_;
// resource handle
@ -235,45 +237,12 @@ class OpenCLWrappedFunc {
ThreadAxisConfig thread_axis_cfg_;
};
/*!
* \brief Automatically detect and set cuda device.
* \param args The arguments.
*/
void AutoSetOpenCLDevice(const TVMArgs& args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 3);
TVMValue* values = static_cast<TVMValue*>(args[0].operator void*());
int* type_codes = static_cast<int*>(args[1].operator void*());
int num_args = args[2].operator int();
// TODO(tqchen): merge this with CUDA logic.
int device_id = -1;
for (int i = 0; i < num_args; ++i) {
if (type_codes[i] == kArrayHandle) {
TVMContext ctx = static_cast<TVMArray*>(values[i].v_handle)->ctx;
CHECK_EQ(ctx.device_type, kOpenCL)
<< "All operands need to be OpenCL";
if (device_id == -1) {
device_id = ctx.device_id;
} else {
CHECK_EQ(device_id, ctx.device_id)
<< "Operands comes from different devices ";
}
}
}
CHECK_NE(device_id, -1)
<< "Cannot detect device id from list";
cl::OpenCLThreadEntry::ThreadLocal()->context.device_id = device_id;
}
PackedFunc OpenCLModuleNode::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main)
<< "Device function do not have main";
if (name == symbol::tvm_entry_setdevice) {
return PackedFunc(AutoSetOpenCLDevice);
}
auto it = fmap_.find(name);
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
@ -289,7 +258,7 @@ PackedFunc OpenCLModuleNode::GetFunction(
// initialize the wrapped func.
f.Init(this, sptr_to_self, kid_map_.at(name),
name, arg_size, info.thread_axis_tags);
return PackFromVoidAddrArgs(f, info.arg_types);
return PackFuncVoidAddr(f, info.arg_types);
}
Module OpenCLModuleCreate(


@ -18,7 +18,7 @@ namespace runtime {
/*!
* \brief create an OpenCL module from data.
*
* \param data The module data, can be ptx, cubin
* \param data The module data.
* \param fmt The format of the data, can be "clbin", "cl"
* \param fmap The map function information map of each function.
*/

src/runtime/pack_args.h (new file, 233 lines)

@ -0,0 +1,233 @@
/*!
* Copyright (c) 2017 by Contributors
* \file pack_args.h
* \brief Utility to pack TVMArgs to other type-erased function calling conventions.
*
* Two type erased function signatures are supported.
* - cuda_style(void** args, int num_args);
* - Pack everything by address
* - metal_style(void** buffers, int num_buffers,
* union_32bit args[N], int num_args);
* - Pack buffer by address, pack rest parameter into 32bit union buffer.
*/
#ifndef TVM_RUNTIME_PACK_ARGS_H_
#define TVM_RUNTIME_PACK_ARGS_H_
#include <tvm/runtime/c_runtime_api.h>
#include <vector>
namespace tvm {
namespace runtime {
/*!
* \brief Argument union type of 32 bits.
* Chosen as 32 bits because most GPU APIs do not handle 64-bit arguments well.
*/
union ArgUnion {
int32_t v_int32;
uint32_t v_uint32;
float v_float32;
};
/*!
* \brief Create a packed function from void addr types.
*
* \param f The function with signature (TVMArgs args, TVMRetValue* rv, void** void_args)
* \param arg_types The argument types of the function.
* \tparam F the function type
*
* \return The wrapped packed function.
*/
template<typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
/*!
* \brief Create a packed function that passes buffer arguments by address
*  and packs the remaining arguments into a 32-bit ArgUnion array.
*
* \param f The function with signature (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args)
* \param arg_types The argument types of the function.
* \tparam F the function type
*
* \return The wrapped packed function.
*/
template<typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types);
/*!
* \brief Extract number of buffer argument from the argument types.
* \param arg_types The argument types.
* \return number of buffer arguments
*/
inline size_t NumBufferArgs(const std::vector<TVMType>& arg_types);
// implementations details
namespace detail {
template<typename T, int kSize>
class TempArray {
public:
explicit TempArray(int size) {}
T* data() {
return data_;
}
private:
T data_[kSize];
};
template<typename T>
class TempArray<T, 0> {
public:
explicit TempArray(int size) : data_(size) {}
T* data() {
return data_.data();
}
private:
std::vector<T> data_;
};
/*! \brief conversion code used in void arg. */
enum ArgConvertCode {
INT64_TO_INT64,
INT64_TO_INT32,
INT64_TO_UINT32,
FLOAT64_TO_FLOAT32,
FLOAT64_TO_FLOAT64,
HANDLE_TO_HANDLE
};
inline ArgConvertCode GetArgConvertCode(TVMType t) {
CHECK_EQ(t.lanes, 1U)
<< "Cannot pass vector type argument to devic function for now";
if (t.code == kInt) {
if (t.bits == 64U) return INT64_TO_INT64;
if (t.bits == 32U) return INT64_TO_INT32;
} else if (t.code == kUInt) {
if (t.bits == 32U) return INT64_TO_UINT32;
} else if (t.code == kFloat) {
if (t.bits == 64U) return FLOAT64_TO_FLOAT64;
if (t.bits == 32U) return FLOAT64_TO_FLOAT32;
} else if (t.code == kHandle) {
return HANDLE_TO_HANDLE;
}
LOG(FATAL) << "Cannot handle " << t << " as device function argument";
return HANDLE_TO_HANDLE;
}
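// Illustrative-only sketch (not part of this header): a few mappings that
// GetArgConvertCode produces. TVMType fields are {code, bits, lanes}.
inline void ArgConvertCodeExamples() {
  TVMType f32{kFloat, 32, 1};   // narrowed: v_float64 -> ArgUnion.v_float32
  TVMType i64{kInt, 64, 1};     // passed through by address, no conversion
  CHECK(GetArgConvertCode(f32) == FLOAT64_TO_FLOAT32);
  CHECK(GetArgConvertCode(i64) == INT64_TO_INT64);
}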
template<int N, typename F>
inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size());
auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
TempArray<void*, N> addr_(num_args);
TempArray<ArgUnion, N> holder_(num_args);
void** addr = addr_.data();
ArgUnion* holder = holder_.data();
for (int i = 0; i < num_args; ++i) {
switch (codes[i]) {
case INT64_TO_INT64:
case FLOAT64_TO_FLOAT64:
case HANDLE_TO_HANDLE: {
addr[i] = (void*)&(args.values[i]); // NOLINT(*)
break;
}
case INT64_TO_INT32: {
holder[i].v_int32 = static_cast<int32_t>(args.values[i].v_int64);
addr[i] = &(holder[i]);
break;
}
case INT64_TO_UINT32 : {
holder[i].v_uint32 = static_cast<uint32_t>(args.values[i].v_int64);
addr[i] = &(holder[i]);
break;
}
case FLOAT64_TO_FLOAT32: {
holder[i].v_float32 = static_cast<float>(args.values[i].v_float64);
addr[i] = &(holder[i]);
break;
}
}
}
f(args, ret, addr);
};
return PackedFunc(ret);
}
template<int N, typename F>
inline PackedFunc PackFuncNonBufferArg_(
F f, int base, const std::vector<ArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size());
auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) {
TempArray<ArgUnion, N> holder_(num_args);
ArgUnion* holder = holder_.data();
for (int i = 0; i < num_args; ++i) {
switch (codes[i]) {
case INT64_TO_INT64:
case FLOAT64_TO_FLOAT64: {
LOG(FATAL) << "Donot support 64bit argument to device function"; break;
}
case INT64_TO_INT32: {
holder[i].v_int32 = static_cast<int32_t>(args.values[base + i].v_int64);
break;
}
case INT64_TO_UINT32 : {
holder[i].v_uint32 = static_cast<uint32_t>(args.values[base + i].v_int64);
break;
}
case FLOAT64_TO_FLOAT32: {
holder[i].v_float32 = static_cast<float>(args.values[base + i].v_float64);
break;
}
case HANDLE_TO_HANDLE: {
LOG(FATAL) << "not reached"; break;
}
}
}
f(args, ret, holder);
};
return PackedFunc(ret);
}
} // namespace detail
template<typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types) {
std::vector<detail::ArgConvertCode> codes(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
codes[i] = detail::GetArgConvertCode(arg_types[i]);
}
size_t num_void_args = arg_types.size();
// specialization
if (num_void_args <= 4) {
return detail::PackFuncVoidAddr_<4>(f, codes);
} else if (num_void_args <= 8) {
return detail::PackFuncVoidAddr_<8>(f, codes);
} else {
return detail::PackFuncVoidAddr_<0>(f, codes);
}
}
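// The <4> and <8> instantiations above keep the scratch arrays on the stack
// (TempArray<T, kSize>), while the <0> fallback uses the std::vector-backed
// specialization for the rare signatures with more arguments.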
inline size_t NumBufferArgs(const std::vector<TVMType>& arg_types) {
size_t base = arg_types.size();
for (size_t i = 0; i < arg_types.size(); ++i) {
if (arg_types[i].code != kHandle) {
base = i; break;
}
}
for (size_t i = base; i < arg_types.size(); ++i) {
CHECK(arg_types[i].code != kHandle)
<< "Device functions must take all buffer (handle) arguments before other arguments";
}
return base;
}
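// Worked example: for a device function f(A, B, n, s) with
// arg_types = {handle, handle, int32, float32}, NumBufferArgs returns 2.
// PackFuncNonBufferArg then binds A and B by address (e.g. as MTLBuffer
// arguments) and packs n and s into a two-element ArgUnion array.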
template<typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types) {
size_t num_buffer = NumBufferArgs(arg_types);
std::vector<detail::ArgConvertCode> codes;
for (size_t i = num_buffer; i < arg_types.size(); ++i) {
codes.push_back(detail::GetArgConvertCode(arg_types[i]));
}
int base = static_cast<int>(num_buffer);
size_t nargs = codes.size();
// specialization
if (nargs <= 4) {
return detail::PackFuncNonBufferArg_<4>(f, base, codes);
} else {
return detail::PackFuncNonBufferArg_<0>(f, base, codes);
}
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_PACK_ARGS_H_


@ -1,164 +0,0 @@
/*!
* Copyright (c) 2017 by Contributors
* \file void_addr_args.h
* \brief Utility to convert TVMArgs to void* array type-erasure function call.
*
* Array of argument address is a typical way of type-erasure for functions.
* The function signiture looks like function(void** args, int num_args);
* Where args takes the address of each input.
*/
#ifndef TVM_RUNTIME_VOID_ADDR_ARGS_H_
#define TVM_RUNTIME_VOID_ADDR_ARGS_H_
#include <tvm/runtime/c_runtime_api.h>
#include <vector>
namespace tvm {
namespace runtime {
/*!
* \brief Create a packed function from void addr types
* \param f with signiture (TVMArgs args, TVMRetValue* rv, void* void_args)
* \param arg_types The arguments that wish to get from
* \tparam T the function type
*
* \return The wrapped packed function.
*/
template<typename F>
inline PackedFunc PackFromVoidAddrArgs(
F f, const std::vector<TVMType>& arg_types);
// implementations details
namespace detail {
/*!
* \brief void addr argument data content
* holder in case conversion is needed.
*/
union VoidArgHolder {
int32_t v_int32;
uint32_t v_uint32;
float v_float32;
};
template<int MAX_NARG>
class VoidAddrArray {
public:
explicit VoidAddrArray(int num_args) {
}
void** addr() {
return addr_;
}
VoidArgHolder* holder() {
return holder_;
}
private:
void* addr_[MAX_NARG];
VoidArgHolder holder_[MAX_NARG];
};
template<>
class VoidAddrArray<0> {
public:
explicit VoidAddrArray(int num_args)
: addr_(num_args), holder_(num_args) {
}
void** addr() {
return addr_.data();
}
VoidArgHolder* holder() {
return holder_.data();
}
private:
std::vector<void*> addr_;
std::vector<VoidArgHolder> holder_;
};
/*! \brief conversion code used in void arg. */
enum VoidArgConvertCode {
INT64_TO_INT64,
INT64_TO_INT32,
INT64_TO_UINT32,
FLOAT64_TO_FLOAT32,
FLOAT64_TO_FLOAT64,
HANDLE_TO_HANDLE
};
template<int N, typename F>
inline PackedFunc PackFromVoidAddrArgs_(
F f, const std::vector<VoidArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size());
auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
VoidAddrArray<N> temp(num_args);
void** addr = temp.addr();
VoidArgHolder* holder = temp.holder();
for (int i = 0; i < num_args; ++i) {
switch (codes[i]) {
case INT64_TO_INT64:
case FLOAT64_TO_FLOAT64:
case HANDLE_TO_HANDLE: {
addr[i] = (void*)&(args.values[i]); // NOLINT(*)
break;
}
case INT64_TO_INT32: {
holder[i].v_int32 = static_cast<int32_t>(args.values[i].v_int64);
addr[i] = &(holder[i]);
break;
}
case INT64_TO_UINT32 : {
holder[i].v_uint32 = static_cast<uint32_t>(args.values[i].v_int64);
addr[i] = &(holder[i]);
break;
}
case FLOAT64_TO_FLOAT32: {
holder[i].v_float32 = static_cast<float>(args.values[i].v_float64);
addr[i] = &(holder[i]);
break;
}
}
}
f(args, ret, addr);
};
return PackedFunc(ret);
}
inline VoidArgConvertCode GetVoidArgConvertCode(TVMType t) {
CHECK_EQ(t.lanes, 1U);
if (t.code == kInt) {
if (t.bits == 64U) return INT64_TO_INT64;
if (t.bits == 32U) return INT64_TO_INT32;
} else if (t.code == kUInt) {
if (t.bits == 32U) return INT64_TO_UINT32;
} else if (t.code == kFloat) {
if (t.bits == 64U) return FLOAT64_TO_FLOAT64;
if (t.bits == 32U) return FLOAT64_TO_FLOAT32;
} else if (t.code == kHandle) {
return HANDLE_TO_HANDLE;
}
LOG(FATAL) << "Cannot handle " << t;
return HANDLE_TO_HANDLE;
}
} // namespace detail
template<typename F>
inline PackedFunc PackFromVoidAddrArgs(
F f, const std::vector<TVMType>& arg_types) {
std::vector<detail::VoidArgConvertCode> codes(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
codes[i] = detail::GetVoidArgConvertCode(arg_types[i]);
}
size_t num_void_args = arg_types.size();
// specialization
if (num_void_args <= 4) {
return detail::PackFromVoidAddrArgs_<4>(f, codes);
} else if (num_void_args <= 8) {
return detail::PackFromVoidAddrArgs_<8>(f, codes);
} else {
return detail::PackFromVoidAddrArgs_<0>(f, codes);
}
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_VOID_ADDR_ARGS_H_


@ -17,6 +17,7 @@ def lower(s, args, name="mydot"):
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, name, arg_list, 0)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
return fapi
@ -35,7 +36,7 @@ def test_dot():
fapi = lower(s, [A, B, C])
def verify(target):
if not tvm.codegen.enabled(target):
if not tvm.module.enabled(target):
print("Target %s is not enabled" % target)
return
f = tvm.codegen.build_module(fapi, target)


@ -1,5 +1,6 @@
import tvm
import numpy as np
import time
def test_exp():
# graph
@ -8,21 +9,21 @@ def test_exp():
B = tvm.compute(A.shape, lambda *i: tvm.exp(A(*i)), name='B')
s = tvm.create_schedule(B.op)
# create iter var and assign them tags.
num_thread = 64
num_thread = 8
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
# one line to build the function.
def check_device(device, host="stackvm"):
if not tvm.codegen.enabled(host):
if not tvm.module.enabled(host):
return
if not tvm.codegen.enabled(device):
if not tvm.module.enabled(device):
return
fexp = tvm.build(s, [A, B],
device, host,
name="myexp")
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
ctx = tvm.context(device, 0)
# launch the kernel.
n = 1024
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
@ -31,8 +32,6 @@ def test_exp():
np.testing.assert_allclose(
b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)
if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
check_device("cuda", "llvm")
check_device("opencl")
@ -46,7 +45,7 @@ def test_log_llvm():
# create iter var and assign them tags.
bx, tx = s[B].split(B.op.axis[0], factor=32)
# one line to build the function.
if not tvm.codegen.enabled("llvm"):
if not tvm.module.enabled("llvm"):
return
flog = tvm.build(s, [A, B],
@ -60,17 +59,26 @@ def test_log_llvm():
np.testing.assert_allclose(
b.asnumpy(), np.log(a.asnumpy()), rtol=1e-5)
from tvm.contrib import nvcc_compiler
@tvm.register_func
def tvm_callback_cuda_compile(code):
print(code)
ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_52"])
return ptx
def test_add():
# graph
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
bias = tvm.var("bias", dtype="float32")
scale = tvm.var("scale", dtype="float32")
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name='C')
# schedule
s = tvm.create_schedule(C.op)
# create iter var and assign them tags.
num_thread = 256
num_thread = 32
bx, x = s[C].split(C.op.axis[0], factor=num_thread*4)
tx, x = s[C].split(x, nparts=num_thread)
_, x = s[C].split(x, factor=4)
@ -80,26 +88,28 @@ def test_add():
# one line to build the function.
def check_device(device):
if not tvm.codegen.enabled(device):
if not tvm.module.enabled(device):
print("skip because %s is not enabled.." % device)
return
fadd = tvm.build(s, [A, B, C],
fadd = tvm.build(s, [A, B, C, bias, scale],
device,
name="myadd")
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
ctx = tvm.context(device, 0)
print(fadd.imported_modules[0].get_source())
# launch the kernel.
n = 1024
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
fadd(a, b, c)
vbias = np.random.uniform()
vscale = np.random.uniform()
fadd(a, b, c, vbias, vscale)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
c.asnumpy(), a.asnumpy() + b.asnumpy() * vscale + vbias, rtol=1e-6)
if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
check_device("cuda")
check_device("opencl")
check_device("metal")
check_device("cuda")
if __name__ == "__main__":


@ -1,6 +1,13 @@
import tvm
from tvm.contrib import nvcc_compiler
from tvm.contrib import metal_compiler
import numpy as np
import time
#@tvm.register_func
def tvm_callback_metal_compile(code):
lib = metal_compiler.compile_source(code)
return lib
def test_gemm():
# graph
@ -63,15 +70,14 @@ def test_gemm():
s = s.normalize()
# one line to build the function.
def check_device(device, host="stackvm"):
if not tvm.codegen.enabled(host):
return
if not tvm.codegen.enabled(device):
def check_device(device):
if not tvm.module.enabled(device):
print("skip because %s is not enabled.." % device)
return
f = tvm.build(s, [A, B, C], device, host,
f = tvm.build(s, [A, B, C], device,
max_auto_unroll_step=max_auto_unroll_step)
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
ctx = tvm.context(device, 0)
# launch the kernel.
n = nn
m = n
@ -81,15 +87,20 @@ def test_gemm():
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
for i in range(4):
f(a, b, c)
f(a, b, c)
ctx.sync()
tbegin = time.time()
f(a, b, c)
tpush = time.time()
ctx.sync()
tend = time.time()
print("launch=%g sec, exec=%g sec" % (tpush - tbegin, tend - tbegin))
np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
check_device("cuda")
if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
check_device("metal")
check_device("opencl")
check_device("cuda")
if __name__ == "__main__":
test_gemm()


@ -19,16 +19,16 @@ def test_reduce_prims():
# one line to build the function.
def check_device(device, host="stackvm"):
if not tvm.codegen.enabled(host):
if not tvm.module.enabled(host):
return
if not tvm.codegen.enabled(device):
if not tvm.module.enabled(device):
print("skip because %s is not enabled.." % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
ctx = tvm.context(device, 0)
freduce = tvm.build(s,
args=[A, B],
target=device, target_host=host,
name="myreduce")
print(freduce.imported_modules[0].get_source())
# launch the kernel.
n = 1028
m = 129
@ -41,9 +41,7 @@ def test_reduce_prims():
res[:2] = 0
np.testing.assert_allclose(npy, res, rtol=1e-4)
if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
check_device("metal")
check_device("cuda")
check_device("opencl")
test_prim(tvm.sum, np.sum)
@ -64,7 +62,7 @@ def test_rfactor():
s[BF].parallel(BF.op.axis[0])
# one line to build the function.
def check_target(target="llvm"):
if not tvm.codegen.enabled(target):
if not tvm.module.enabled(target):
return
ctx = tvm.cpu(0)
fapi = tvm.lower(s, args=[A, B])
@ -105,15 +103,14 @@ def test_rfactor_threads():
# one line to build the function.
def check_target(device, host="stackvm"):
if not tvm.codegen.enabled(device):
if not tvm.module.enabled(device):
print("skip because %s is not enabled.." % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
ctx = tvm.context(device, 0)
fapi = tvm.lower(s, args=[A, B])
fapi2 = tvm.ir_pass.LowerThreadAllreduce(fapi, 32)
fsum = tvm.build(fapi,
target=device,
name="mysum")
print(fsum.imported_modules[0].get_source())
# launch the kernel.
n = nn
m = mm
@ -125,9 +122,8 @@ def test_rfactor_threads():
np.testing.assert_allclose(
b.asnumpy(), res, rtol=1e-4)
if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
check_target("cuda")
check_target("metal")
check_target("opencl")
if __name__ == "__main__":


@ -23,15 +23,14 @@ def test_scan():
s[s_update].bind(xi, thread_x)
# one line to build the function.
def check_device(device, host="stackvm"):
if not tvm.codegen.enabled(host):
return
if not tvm.codegen.enabled(device):
def check_device(device):
if not tvm.module.enabled(device):
print("skip because %s is not enabled.." % device)
return
fscan = tvm.build(s, [X, res],
device, host,
device,
name="myscan")
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
ctx = tvm.context(device, 0)
# launch the kernel.
n = 1024
m = 10
@ -42,10 +41,8 @@ def test_scan():
np.testing.assert_allclose(
b.asnumpy(), np.cumsum(a_np, axis=0))
if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
check_device("cuda")
check_device("metal")
check_device("opencl")


@ -99,9 +99,9 @@ def test_gemm():
# correctness
def check_device(device, host="stackvm"):
if not tvm.codegen.enabled(host):
if not tvm.module.enabled(host):
return
if not tvm.codegen.enabled(device):
if not tvm.module.enabled(device):
return
f = tvm.build(s, [A, B, C], device, host,
max_auto_unroll_step=max_auto_unroll_step)


@ -29,17 +29,16 @@ def test_add_pipeline():
fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])
def check_target(device, host="stackvm"):
if not tvm.codegen.enabled(host):
if not tvm.module.enabled(host):
return
if not tvm.codegen.enabled(device):
if not tvm.module.enabled(device):
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
ctx = tvm.context(device, 0)
mhost = tvm.codegen.build_module(fsplits[0], host)
mdev = tvm.codegen.build_module(fsplits[1:], device)
mhost.import_module(mdev)
code = mdev.get_source()
f = mhost.entry_func
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
@ -50,11 +49,11 @@ def test_add_pipeline():
c.asnumpy(), a.asnumpy() + b.asnumpy())
def check_module_save(device, host="stackvm"):
if not tvm.codegen.enabled(host):
if not tvm.module.enabled(host):
return
if not tvm.codegen.enabled(device):
if not tvm.module.enabled(device):
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
ctx = tvm.context(device, 0)
fmt = "ptx" if device == "cuda" else "cl"
mhost = tvm.codegen.build_module(fsplits[0], host)
mdev = tvm.codegen.build_module(fsplits[1:], device)


@ -19,7 +19,7 @@ def test_add_pipeline():
s = tvm.create_schedule(C.op)
def check_llvm():
if not tvm.codegen.enabled("llvm"):
if not tvm.module.enabled("llvm"):
return
# build and invoke the kernel.
f = tvm.build(s, [A, C], "llvm")
@ -51,7 +51,7 @@ def test_pack_buffer_simple():
def check_target(target):
if not tvm.codegen.enabled(target):
if not tvm.module.enabled(target):
return
# build and invoke the kernel.
f = tvm.build(s, [A, C], target)
@ -81,7 +81,7 @@ def test_pack_buffer_intermediate():
s = tvm.create_schedule(C.op)
def check_target(target):
if not tvm.codegen.enabled(target):
if not tvm.module.enabled(target):
return
# build and invoke the kernel.
f = tvm.build(s, [A, C], target)


@ -12,7 +12,7 @@ def test_llvm_add_pipeline():
s[C].parallel(xo)
s[C].vectorize(xi)
def check_llvm():
if not tvm.codegen.enabled("llvm"):
if not tvm.module.enabled("llvm"):
return
# build and invoke the kernel.
f = tvm.build(s, [A, B, C], "llvm")
@ -30,7 +30,7 @@ def test_llvm_add_pipeline():
def test_llvm_flip_pipeline():
def check_llvm(nn, base):
if not tvm.codegen.enabled("llvm"):
if not tvm.module.enabled("llvm"):
return
n = tvm.convert(nn)
A = tvm.placeholder((n + base), name='A')
@ -57,7 +57,7 @@ def test_llvm_flip_pipeline():
def test_llvm_madd_pipeline():
def check_llvm(nn, base, stride):
if not tvm.codegen.enabled("llvm"):
if not tvm.module.enabled("llvm"):
return
n = tvm.convert(nn)
A = tvm.placeholder((n + base, stride), name='A')
@ -89,7 +89,7 @@ def test_llvm_temp_space():
s = tvm.create_schedule(C.op)
def check_llvm():
if not tvm.codegen.enabled("llvm"):
if not tvm.module.enabled("llvm"):
return
# build and invoke the kernel.
f = tvm.build(s, [A, C], "llvm")


@ -3,7 +3,7 @@ import numpy as np
def run_jit(fapi, check):
for target in ["llvm", "stackvm"]:
if not tvm.codegen.enabled(target):
if not tvm.module.enabled(target):
continue
f = tvm.codegen.build_module(fapi, target)
s = f.get_source()


@ -22,7 +22,7 @@ print("Finish runtime checking...")
"""
def test_dso_module_load():
if not tvm.codegen.enabled("llvm"):
if not tvm.module.enabled("llvm"):
return
dtype = 'int64'
temp = util.tempdir()
@ -38,6 +38,7 @@ def test_dso_module_load():
tvm.make.Load(dtype, Ab.data, i) + 1,
i + 1))
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
m = tvm.codegen.build_module(fapi, "llvm")
for name in names:
m.save(name)


@ -2,14 +2,14 @@ import tvm
import numpy as np
def enabled_ctx_list():
if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
ctx_list = [('cpu', tvm.cpu(0)),
('gpu', tvm.gpu(0)),
('cl', tvm.opencl(0)),
('cpu', tvm.vpi(0))]
ctx_list = [x[1] for x in ctx_list if tvm.module.enabled(x[0])]
('metal', tvm.metal(0)),
('vpi', tvm.vpi(0))]
for k, v in ctx_list:
assert tvm.context(k, 0) == v
ctx_list = [x[1] for x in ctx_list if x[1].exist]
return ctx_list
ENABLED_CTX_LIST = enabled_ctx_list()
@ -29,7 +29,8 @@ def test_nd_create():
np.testing.assert_equal(x, y.asnumpy())
np.testing.assert_equal(x, z.asnumpy())
# no need here, just to test usability
tvm.nd.sync(ctx)
ctx.sync()
if __name__ == "__main__":
test_nd_create()


@ -18,16 +18,17 @@ if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then
fi
cp make/config.mk config.mk
echo "USE_CUDA=0" >> config.mk
echo "ENABLE_CUDA=0" >> config.mk
if [ ${TRAVIS_OS_NAME} == "osx" ]; then
echo "USE_OPENCL=1" >> config.mk
echo "ENABLE_OPENCL=1" >> config.mk
echo "ENABLE_METAL=1" >> config.mk
else
# use g++-4.8 for linux
if [ ${CXX} == "g++" ]; then
export CXX=g++-4.8
fi
echo "USE_OPENCL=0" >> config.mk
echo "ENABLE_OPENCL=0" >> config.mk
fi
if [ ${TASK} == "verilog_test" ] || [ ${TASK} == "all_test" ]; then


@ -40,13 +40,12 @@ def test_add_pipeline():
fapi = lower(s, [A, B, C], "myadd")
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])
print(fsplits[1].body)
print("------")
def check_target(device, host="stackvm"):
if not tvm.codegen.enabled(host):
if not tvm.module.enabled(host):
return
if not tvm.codegen.enabled(device):
if not tvm.module.enabled(device):
return
ctx = tvm.vpi(0)
mhost = tvm.codegen.build_module(fsplits[0], host)


@ -228,7 +228,6 @@ np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
# The following codeblocks generate opencl code, creates array on opencl
# device, and verifies the correctness of the code.
#
tvm.module.init_opencl()
fadd_cl = tvm.build(s, [A, B, C], "opencl", name="myadd")
print("------opencl code------")
print(fadd_cl.imported_modules[0].get_source())


@ -65,7 +65,6 @@ print(fcuda.imported_modules[0].get_source())
# We can find that the code works for both CUDA and opencl.
# The same tvm.exp can also be used for float64 data types.
#
tvm.module.init_opencl()
fopencl = tvm.build(s, [A, B], "opencl", name="myexp")
print(fopencl.imported_modules[0].get_source())