[BACKEND] Vulkan Runtime and SPIRV Codegen (#861)
* [BACKEND] Vulkan Runtime and SPIRV Codegen
* fix doc
Parent: 108e9f3f78
Commit: 79d503fd3b
|
@ -1,4 +1,4 @@
|
|||
cmake_minimum_required(VERSION 3.5)
|
||||
cmake_minimum_required(VERSION 3.7)
|
||||
project(tvm C CXX)
|
||||
|
||||
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/build/private/local_config.cmake)
|
||||
|
@ -22,6 +22,7 @@ endif()
|
|||
|
||||
tvm_option(USE_CUDA "Build with CUDA" OFF)
|
||||
tvm_option(USE_OPENCL "Build with OpenCL" OFF)
|
||||
tvm_option(USE_VULKAN "Build with Vulkan" OFF)
|
||||
tvm_option(USE_OPENGL "Build with OpenGL" OFF)
|
||||
tvm_option(USE_METAL "Build with Metal" OFF)
|
||||
tvm_option(USE_RPC "Build with RPC" ON)
|
||||
|
@ -88,9 +89,11 @@ file(GLOB_RECURSE HALIDEIR_SRCS HalideIR/src/*.cpp)
|
|||
list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})
|
||||
file(GLOB RUNTIME_SRCS src/runtime/*.cc)
|
||||
file(GLOB COMPILER_LLVM_SRCS src/codegen/llvm/*.cc)
|
||||
file(GLOB COMPILER_VULKAN_SRCS src/codegen/spirv/*.cc)
|
||||
file(GLOB RUNTIME_CUDA_SRCS src/runtime/cuda/*.cc)
|
||||
file(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc)
|
||||
file(GLOB RUNTIME_OPENGL_SRCS src/runtime/opengl/*.cc)
|
||||
file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/*.cc)
|
||||
file(GLOB RUNTIME_METAL_SRCS src/runtime/metal/*.mm)
|
||||
file(GLOB RUNTIME_RPC_SRCS src/runtime/rpc/*.cc)
|
||||
file(GLOB RUNTIME_GRAPH_SRCS src/runtime/graph/*.cc)
|
||||
|
@ -151,6 +154,22 @@ else(USE_OPENGL)
|
|||
add_definitions(-DTVM_OPENGL_RUNTIME=0)
|
||||
endif(USE_OPENGL)
|
||||
|
||||
if(USE_VULKAN)
|
||||
find_package(Vulkan REQUIRED)
|
||||
message(STATUS "Build with VULKAN support")
|
||||
include_directories(${Vulkan_INCLUDE_DIRS})
|
||||
list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARIES})
|
||||
list(APPEND RUNTIME_SRCS ${RUNTIME_VULKAN_SRCS})
|
||||
list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS})
|
||||
get_filename_component(VULKAN_LIB_PATH ${Vulkan_LIBRARY} DIRECTORY)
|
||||
find_library(SPIRV_TOOLS_LIB SPIRV-Tools
|
||||
${VULKAN_LIB_PATH}/spirv-tools)
|
||||
list(APPEND TVM_LINKER_LIBS ${SPIRV_TOOLS_LIB})
|
||||
add_definitions(-DTVM_VULKAN_RUNTIME=1)
|
||||
else(USE_VULKAN)
|
||||
add_definitions(-DTVM_VULKAN_RUNTIME=0)
|
||||
endif(USE_VULKAN)
|
||||
|
||||
if(USE_METAL)
|
||||
find_package(OpenCL QUIET REQUIRED)
|
||||
message(STATUS "Build with Metal support")
|
||||
|
@ -174,7 +193,7 @@ if(USE_GRAPH_RUNTIME)
|
|||
endif(USE_GRAPH_RUNTIME)
|
||||
|
||||
if(USE_LLVM)
|
||||
find_package(LLVM CONFIG REQUIRED)
|
||||
find_package(LLVM CONFIG REQUIRED)
|
||||
include_directories(${LLVM_INCLUDE_DIRS})
|
||||
add_definitions(${LLVM_DEFINITIONS})
|
||||
set(TVM_LLVM_VERSION ${LLVM_VERSION_MAJOR}${LLVM_VERSION_MINOR})
|
||||
|
@ -252,4 +271,4 @@ if(MSVC)
|
|||
target_compile_definitions(tvm_runtime PRIVATE -DHalide_EXPORTS)
|
||||
target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS)
|
||||
target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS)
|
||||
endif()
|
||||
endif()
|
||||
|
|
HalideIR
|
@ -1 +1 @@
|
|||
Subproject commit 87b089a0ba20f2e8257038ee9211d6816088ce95
|
||||
Subproject commit aadbf02d6bd7a545edbf6652494a7b07a97a06c1
|
Makefile
|
@ -56,6 +56,7 @@ CUDA_SRC = $(wildcard src/runtime/cuda/*.cc)
|
|||
ROCM_SRC = $(wildcard src/runtime/rocm/*.cc)
|
||||
OPENCL_SRC = $(wildcard src/runtime/opencl/*.cc)
|
||||
OPENGL_SRC = $(wildcard src/runtime/opengl/*.cc)
|
||||
VULKAN_SRC = $(wildcard src/runtime/vulkan/*.cc)
|
||||
RPC_SRC = $(wildcard src/runtime/rpc/*.cc)
|
||||
GRAPH_SRC = $(wildcard src/runtime/graph/*.cc)
|
||||
RUNTIME_SRC = $(wildcard src/runtime/*.cc)
|
||||
|
@ -69,6 +70,7 @@ CUDA_OBJ = $(patsubst src/%.cc, build/%.o, $(CUDA_SRC))
|
|||
ROCM_OBJ = $(patsubst src/%.cc, build/%.o, $(ROCM_SRC))
|
||||
OPENCL_OBJ = $(patsubst src/%.cc, build/%.o, $(OPENCL_SRC))
|
||||
OPENGL_OBJ = $(patsubst src/%.cc, build/%.o, $(OPENGL_SRC))
|
||||
VULKAN_OBJ = $(patsubst src/%.cc, build/%.o, $(VULKAN_SRC))
|
||||
RPC_OBJ = $(patsubst src/%.cc, build/%.o, $(RPC_SRC))
|
||||
GRAPH_OBJ = $(patsubst src/%.cc, build/%.o, $(GRAPH_SRC))
|
||||
CC_OBJ = $(patsubst src/%.cc, build/%.o, $(CC_SRC)) $(LLVM_OBJ)
|
||||
|
@ -129,6 +131,20 @@ else
|
|||
CFLAGS += -DTVM_OPENCL_RUNTIME=0
|
||||
endif
|
||||
|
||||
ifdef VULKAN_SDK
|
||||
CFLAGS += -I$(VULKAN_SDK)/include
|
||||
LDFLAGS += -L$(VULKAN_SDK)/lib
|
||||
LDFLAGS += -L$(VULKAN_SDK)/lib/spirv-tools
|
||||
endif
|
||||
|
||||
ifeq ($(USE_VULKAN), 1)
|
||||
CFLAGS += -DTVM_VULKAN_RUNTIME=1
|
||||
LDFLAGS += -lvulkan -lSPIRV-Tools
|
||||
RUNTIME_DEP += $(VULKAN_OBJ)
|
||||
else
|
||||
CFLAGS += -DTVM_VULKAN_RUNTIME=0
|
||||
endif
|
||||
|
||||
ifeq ($(USE_OPENGL), 1)
|
||||
CFLAGS += -DTVM_OPENGL_RUNTIME=1
|
||||
EMCC_FLAGS += -DTVM_OPENGL_RUNTIME=1
|
||||
|
|
|
@ -421,6 +421,18 @@ LoweredFunc LowerTVMBuiltin(LoweredFunc f);
|
|||
*/
|
||||
LoweredFunc CombineContextCall(LoweredFunc f);
|
||||
|
||||
/*!
|
||||
* \brief Rewrite the pointer content type of arguments,
|
||||
* as well as Alloc internal to the function to use
|
||||
* the most frequently accessed type for load/store
|
||||
* to avoid pointer casting in backend when possible.
|
||||
*
|
||||
* \note implemented in storage_rewrite.cc
|
||||
* \param f The function to be transformed
|
||||
* \return Transformed function.
|
||||
*/
|
||||
LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
|
||||
|
||||
/*!
|
||||
* \brief Lower intrinsic function calls.
|
||||
* \param f The device function to be lowered.
|
||||
|
|
|
@ -55,8 +55,8 @@ typedef int64_t tvm_index_t;
|
|||
|
||||
/*! \brief Extension device types in TVM */
|
||||
typedef enum {
|
||||
kDLVulkan = 7,
|
||||
kOpenGL = 11,
|
||||
|
||||
// Extension DRAM type, used for quickly testing extension devices
|
||||
// The device api can differ depending on the xpu driver registered.
|
||||
kExtDev = 12,
|
||||
|
|
|
@ -17,7 +17,8 @@ from . import ir_builder
|
|||
from . import target
|
||||
|
||||
from . import ndarray as nd
|
||||
from .ndarray import context, cpu, gpu, opencl, cl, metal, mtl, vpi, rocm, opengl, ext_dev
|
||||
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
|
||||
from .ndarray import vpi, rocm, opengl, ext_dev
|
||||
|
||||
from ._ffi.runtime_ctypes import TypeCode
|
||||
from ._ffi.function import Function
|
||||
|
|
|
@ -94,6 +94,7 @@ class TVMContext(ctypes.Structure):
|
|||
1 : 'cpu',
|
||||
2 : 'gpu',
|
||||
4 : 'opencl',
|
||||
7 : 'vulkan',
|
||||
8 : 'metal',
|
||||
9 : 'vpi',
|
||||
10: 'rocm',
|
||||
|
@ -109,6 +110,7 @@ class TVMContext(ctypes.Structure):
|
|||
'nvptx': 2,
|
||||
'cl': 4,
|
||||
'opencl': 4,
|
||||
'vulkan': 7,
|
||||
'metal': 8,
|
||||
'vpi': 9,
|
||||
'rocm': 10,
|
||||
|
|
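With the 'vulkan' entries added to both lookup tables above, a context can now be created by device name from Python. A minimal sketch (assumes a runtime built with USE_VULKAN and the tvm package from this commit):

import tvm

ctx = tvm.context("vulkan", 0)   # resolved through STR2MASK to device_type 7
assert ctx.device_type == 7
print(ctx)                       # printed via MASK2STR, e.g. "vulkan(0)"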
|
@ -0,0 +1,43 @@
|
|||
"""Utility for Interacting with SPIRV Tools"""
|
||||
import subprocess
|
||||
import os
|
||||
from . import util
|
||||
|
||||
|
||||
def optimize(spv_bin):
|
||||
"""Optimize SPIRV using spirv-opt via CLI
|
||||
|
||||
Note that spirv-opt is still experimental.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spv_bin : bytearray
|
||||
The SPIRV binary to optimize
|
||||
|
||||
Return
|
||||
------
|
||||
opt_spv_bin : bytearray
|
||||
The optimized SPIRV binary
|
||||
"""
|
||||
|
||||
tmp_dir = util.tempdir()
|
||||
tmp_in = tmp_dir.relpath("input.spv")
|
||||
tmp_out = tmp_dir.relpath("output.spv")
|
||||
with open(tmp_in, "wb") as out_file:
|
||||
out_file.write(bytes(spv_bin))
|
||||
|
||||
sdk = os.environ.get("VULKAN_SDK", None)
|
||||
cmd = os.path.join(sdk, "bin/spirv-opt") if sdk else "spirv-opt"
|
||||
args = [cmd, "-O", tmp_in, "-o", tmp_out]
|
||||
proc = subprocess.Popen(
|
||||
args,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT)
|
||||
(out, _) = proc.communicate()
|
||||
|
||||
if proc.returncode != 0:
|
||||
msg = "Opitmizationerror using spirv-opt:\n"
|
||||
msg += str(out)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
return bytearray(open(tmp_out, "rb").read())
|
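Since spirv-opt is driven through the CLI, this helper is most naturally wired in through the tvm_callback_vulkan_postproc hook that the SPIRV codegen queries (see build_vulkan.cc later in this diff). A minimal, hedged sketch of that wiring; the module path tvm.contrib.spirv is assumed from this file's location:

import tvm
from tvm.contrib import spirv

@tvm.register_func("tvm_callback_vulkan_postproc")
def vulkan_postproc(spv_bin):
    # spv_bin is the SPIRV binary produced by CodeGenSPIRV; return the optimized bytes.
    return spirv.optimize(spv_bin)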
|
@ -120,6 +120,23 @@ def vpi(dev_id=0):
|
|||
"""
|
||||
return TVMContext(9, dev_id)
|
||||
|
||||
|
||||
def vulkan(dev_id=0):
|
||||
"""Construct a Vulkan device
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dev_id : int, optional
|
||||
The integer device id
|
||||
|
||||
Returns
|
||||
-------
|
||||
ctx : TVMContext
|
||||
The created context
|
||||
"""
|
||||
return TVMContext(7, dev_id)
|
||||
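A short usage sketch for the new constructor (assumes numpy and a runtime built with Vulkan enabled):

import numpy as np
import tvm

ctx = tvm.vulkan(0)                                    # same as tvm.context("vulkan", 0)
a = tvm.nd.array(np.ones((16,), dtype="float32"), ctx)
print(a.asnumpy())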
|
||||
|
||||
def opengl(dev_id=0):
|
||||
"""Construct a OpenGL device
|
||||
|
||||
|
@ -135,6 +152,7 @@ def opengl(dev_id=0):
|
|||
"""
|
||||
return TVMContext(11, dev_id)
|
||||
|
||||
|
||||
def ext_dev(dev_id=0):
|
||||
"""Construct a extension device
|
||||
|
||||
|
|
|
@ -116,7 +116,7 @@ class Target(object):
|
|||
# For now assume rocm schedule for opencl
|
||||
self.keys += ("rocm", "gpu")
|
||||
self.max_num_threads = 256
|
||||
elif target_name in ("metal",):
|
||||
elif target_name in ("metal", "vulkan"):
|
||||
self.keys += ("gpu",)
|
||||
self.max_num_threads = 256
|
||||
elif target_name in ("opengl",):
|
||||
|
|
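Because "vulkan" now picks up the gpu schedule keys and a max_num_threads of 256, it can be used like the other GPU targets. A hedged end-to-end sketch (the schedule below is illustrative and not part of this commit):

import tvm

n = 1024
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
f = tvm.build(s, [A, B], target="vulkan")              # dispatches to codegen.build_vulkan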
|
@ -666,6 +666,8 @@ void CodeGenC::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*)
|
|||
}
|
||||
|
||||
void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*)
|
||||
// constraint of current logic
|
||||
CHECK_EQ(op->base.type(), Int(32));
|
||||
os << "((int" << op->lanes << ")(";
|
||||
for (int i = 0; i < op->lanes; i++) {
|
||||
os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")";
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file codegen_common.h
|
||||
* \brief Common utility for codegen.
|
||||
*/
|
||||
#ifndef TVM_CODEGEN_CODEGEN_COMMON_H_
|
||||
#define TVM_CODEGEN_CODEGEN_COMMON_H_
|
||||
|
||||
#include <tvm/arithmetic.h>
|
||||
#include "../arithmetic/compute_expr.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace codegen {
|
||||
|
||||
/*!
|
||||
* \brief Visit AssertStmt recursively, update align_map from condition.
|
||||
* \param op The AssertStmt
|
||||
* \param align_map The alignmap
|
||||
* \param fvisit The recursive visitor
|
||||
* \tparam FVisit the recursive visitor
|
||||
*/
|
||||
template<typename FVisit>
|
||||
inline void VisitAssert(
|
||||
const ir::AssertStmt* op,
|
||||
std::unordered_map<const Variable*, arith::ModularEntry>* align_map,
|
||||
FVisit fvisit) {
|
||||
using namespace ir;
|
||||
auto& align_map_ = *align_map;
|
||||
// Detect useful invariant pattern and use them to visit child.
|
||||
// Pattern: Var % const == 0
|
||||
// TODO(tqchen) merge these pattern to a generic scope info visitor.
|
||||
if (const EQ* eq = op->condition.as<EQ>()) {
|
||||
const Mod* mod = eq->a.as<Mod>();
|
||||
int64_t factor = 0, offset = 0;
|
||||
if (mod && arith::GetConst(eq->b, &offset)) {
|
||||
const Variable *var = mod->a.as<Variable>();
|
||||
if (var && arith::GetConst(mod->b, &factor)) {
|
||||
arith::ModularEntry old = align_map_[var];
|
||||
if (factor > old.coeff) {
|
||||
arith::ModularEntry e;
|
||||
e.coeff = static_cast<int>(factor);
|
||||
e.base = static_cast<int>(offset);
|
||||
// new alignment info,
|
||||
align_map_[var] = e;
|
||||
fvisit(op->body);
|
||||
// restore old info
|
||||
align_map_[var] = old;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
fvisit(op->body);
|
||||
}
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_CODEGEN_CODEGEN_COMMON_H_
|
|
@ -9,6 +9,7 @@
|
|||
#include <tvm/runtime/c_runtime_api.h>
|
||||
#include "./codegen_llvm.h"
|
||||
#include "./codegen_cpu.h"
|
||||
#include "../codegen_common.h"
|
||||
#include "../../pass/ir_util.h"
|
||||
#include "../../arithmetic/compute_expr.h"
|
||||
|
||||
|
@ -341,7 +342,7 @@ void CodeGenLLVM::GetAlignment(Type t,
|
|||
int align_bits = t.bits();
|
||||
while (align_bits < max_align_bits &&
|
||||
me.base % 2 == 0 &&
|
||||
me.coeff %2 == 0) {
|
||||
me.coeff % 2 == 0) {
|
||||
me.base = me.base / 2;
|
||||
me.coeff = me.coeff / 2;
|
||||
align_bits *= 2;
|
||||
|
@ -1026,31 +1027,9 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
|
|||
}
|
||||
|
||||
void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
|
||||
// Detect useful invariant pattern and use them to visit child.
|
||||
// Pattern: Var % const == 0
|
||||
// TODO(tqchen) move these pattern to a generic scope info visitor.
|
||||
if (const EQ* eq = op->condition.as<EQ>()) {
|
||||
const Mod* mod = eq->a.as<Mod>();
|
||||
int64_t factor = 0, offset = 0;
|
||||
if (mod && arith::GetConst(eq->b, &offset)) {
|
||||
const Variable *var = mod->a.as<Variable>();
|
||||
if (var && arith::GetConst(mod->b, &factor)) {
|
||||
arith::ModularEntry old = align_map_[var];
|
||||
if (factor > old.coeff) {
|
||||
arith::ModularEntry e;
|
||||
e.coeff = static_cast<int>(factor);
|
||||
e.base = static_cast<int>(offset);
|
||||
// new alignment info,
|
||||
align_map_[var] = e;
|
||||
this->VisitStmt(op->body);
|
||||
// restore old info
|
||||
align_map_[var] = old;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
this->VisitStmt(op->body);
|
||||
VisitAssert(op, &align_map_, [this](const Stmt& body) {
|
||||
this->VisitStmt(body);
|
||||
});
|
||||
}
|
||||
|
||||
void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file build_vulkan.cc
|
||||
* \brief Build SPIRV block
|
||||
*/
|
||||
#if TVM_VULKAN_RUNTIME
|
||||
|
||||
// Use libspirv for parsing and validating code.
|
||||
#include <vulkan/libspirv.h>
|
||||
#include <dmlc/memory_io.h>
|
||||
#include <tvm/ir_pass.h>
|
||||
|
||||
#include "./codegen_spirv.h"
|
||||
#include "../build_common.h"
|
||||
#include "../../runtime/vulkan/vulkan_module.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace codegen {
|
||||
|
||||
class SPIRVTools {
|
||||
public:
|
||||
SPIRVTools() {
|
||||
ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0);
|
||||
}
|
||||
~SPIRVTools() {
|
||||
spvContextDestroy(ctx_);
|
||||
}
|
||||
std::string BinaryToText(const std::vector<uint32_t>& bin) {
|
||||
spv_text text = nullptr;
|
||||
spv_diagnostic diagnostic;
|
||||
spv_const_binary_t spv_bin{bin.data(), bin.size()};
|
||||
spv_result_t res;
|
||||
|
||||
res = spvBinaryToText(
|
||||
ctx_, spv_bin.code, spv_bin.wordCount,
|
||||
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES |
|
||||
SPV_BINARY_TO_TEXT_OPTION_INDENT,
|
||||
&text, &diagnostic);
|
||||
|
||||
CHECK_EQ(res, SPV_SUCCESS)
|
||||
<< " line=" << diagnostic->position.line
|
||||
<< " column=" << diagnostic->position.column
|
||||
<< " index=" << diagnostic->position.index
|
||||
<< " error:" << diagnostic->error;
|
||||
|
||||
std::string ret(text->str);
|
||||
spvTextDestroy(text);
|
||||
return ret;
|
||||
}
|
||||
|
||||
private:
|
||||
spv_context ctx_;
|
||||
};
|
||||
|
||||
runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) {
|
||||
using tvm::runtime::Registry;
|
||||
using tvm::runtime::VulkanShader;
|
||||
|
||||
std::ostringstream code_data;
|
||||
static SPIRVTools spirv_tools;
|
||||
std::unordered_map<std::string, VulkanShader> smap;
|
||||
|
||||
const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc");
|
||||
|
||||
CodeGenSPIRV cg;
|
||||
for (LoweredFunc f : funcs) {
|
||||
f = PointerValueTypeRewrite(f);
|
||||
VulkanShader shader;
|
||||
shader.data = cg.BuildFunction(f);
|
||||
|
||||
if (postproc != nullptr) {
|
||||
TVMByteArray arr;
|
||||
arr.data = reinterpret_cast<const char*>(dmlc::BeginPtr(shader.data));
|
||||
arr.size = shader.data.size() * sizeof(uint32_t);
|
||||
std::string transformed = (*postproc)(arr);
|
||||
CHECK_EQ(transformed.length() % 4U, 0U);
|
||||
shader.data.resize(transformed.size() / 4U);
|
||||
std::copy(transformed.begin(), transformed.end(),
|
||||
reinterpret_cast<char*>(dmlc::BeginPtr(shader.data)));
|
||||
}
|
||||
code_data << spirv_tools.BinaryToText(shader.data);
|
||||
smap[f->name] = std::move(shader);
|
||||
}
|
||||
return runtime::VulkanModuleCreate(
|
||||
smap, ExtractFuncInfo(funcs), code_data.str());
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("codegen.build_vulkan")
|
||||
.set_body([](TVMArgs args, TVMRetValue* rv) {
|
||||
*rv = BuildSPIRV(args[0]);
|
||||
});
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
||||
#endif // TVM_VULKAN_RUNTIME
|
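The disassembled text produced by SPIRVTools::BinaryToText is stored alongside the shader map on the Vulkan module, so it can usually be inspected from Python after a build. A hedged sketch (get_source support on the imported module is assumed):

import tvm

n = 64
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] * 2.0, name="B")
s = tvm.create_schedule(B.op)
s[B].bind(B.op.axis[0], tvm.thread_axis("threadIdx.x"))
f = tvm.build(s, [A, B], target="vulkan")
print(f.imported_modules[0].get_source())              # SPIRV disassembly from BinaryToText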
|
@ -0,0 +1,638 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file codegen_spirv.cc
|
||||
* \brief Generate SPIRV block
|
||||
*/
|
||||
|
||||
#if TVM_VULKAN_RUNTIME
|
||||
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/ir_pass.h>
|
||||
#include "../codegen_common.h"
|
||||
#include "./codegen_spirv.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace codegen {
|
||||
|
||||
std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const LoweredFunc& f) {
|
||||
this->InitFuncState();
|
||||
CHECK(f->is_restricted)
|
||||
<< "SPIRV only takes restricted memory model";
|
||||
std::vector<Var> pod_args;
|
||||
uint32_t num_buffer = 0;
|
||||
for (Var arg : f->args) {
|
||||
Type t = arg.type();
|
||||
if (t.is_handle()) {
|
||||
auto it = f->handle_data_type.find(arg);
|
||||
if (it != f->handle_data_type.end()) {
|
||||
Type value_type = (*it).second.type();
|
||||
spirv::Value arg_value = builder_->BufferArgument(
|
||||
builder_->GetSType(value_type), 0, num_buffer);
|
||||
storage_info_[arg.get()].UpdateContentType(value_type);
|
||||
var_map_[arg.get()] = arg_value;
|
||||
} else {
|
||||
LOG(FATAL) << "require all handles to be typed";
|
||||
}
|
||||
++num_buffer;
|
||||
} else {
|
||||
pod_args.push_back(arg);
|
||||
}
|
||||
}
|
||||
spirv::Value func_ptr = builder_->DeclareKenrelFunction(f->name);
|
||||
builder_->StartFunction(func_ptr);
|
||||
|
||||
// All the POD arguments are passed in through PushConstant
|
||||
if (pod_args.size() != 0) {
|
||||
std::vector<spirv::SType> value_types;
|
||||
for (size_t i = 0; i < pod_args.size(); ++i) {
|
||||
value_types.push_back(builder_->GetSType(pod_args[i].type()));
|
||||
}
|
||||
spirv::Value ptr = builder_->DeclarePushConstant(value_types);
|
||||
for (size_t i = 0; i < pod_args.size(); ++i) {
|
||||
spirv::Value value = builder_->GetPushConstant(
|
||||
ptr, value_types[i], static_cast<uint32_t>(i));
|
||||
var_map_[pod_args[i].get()] = value;
|
||||
}
|
||||
}
|
||||
this->VisitStmt(f->body);
|
||||
builder_->SetLocalSize(func_ptr, workgroup_size_);
|
||||
builder_->MakeInst(spv::OpReturn);
|
||||
builder_->MakeInst(spv::OpFunctionEnd);
|
||||
|
||||
return builder_->Finalize();
|
||||
}
|
||||
|
||||
void CodeGenSPIRV::InitFuncState() {
|
||||
std::fill(workgroup_size_, workgroup_size_ + 3, 1);
|
||||
var_map_.clear();
|
||||
storage_info_.clear();
|
||||
align_map_.clear();
|
||||
builder_.reset(new spirv::IRBuilder());
|
||||
builder_->InitHeader();
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::GetThreadIndex(
|
||||
const IterVar& iv, const Expr& extent) {
|
||||
runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
|
||||
spirv::Value v;
|
||||
if (ts.rank == 1) {
|
||||
v = builder_->GetLocalID(ts.dim_index);
|
||||
int size;
|
||||
CHECK(arith::GetConstInt(extent, &size))
|
||||
<< "SPIRV only allows constant thread group size " << " get " << extent;
|
||||
CHECK_LT(ts.dim_index, 3);
|
||||
workgroup_size_[ts.dim_index] = static_cast<uint32_t>(size);
|
||||
} else {
|
||||
v = builder_->GetWorkgroupID(ts.dim_index);
|
||||
}
|
||||
return builder_->Cast(builder_->GetSType(iv->var.type()), v);
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::CreateStorageSync(const Call* op) {
|
||||
const std::string& sync = op->args[0].as<StringImm>()->value;
|
||||
spirv::Value value;
|
||||
if (sync == "warp") {
|
||||
return value;
|
||||
} else if (sync == "shared") {
|
||||
builder_->MakeInst(
|
||||
spv::OpControlBarrier,
|
||||
spv::ScopeWorkgroup,
|
||||
spv::ScopeWorkgroup,
|
||||
spv::MemorySemanticsSequentiallyConsistentMask |
|
||||
spv::MemorySemanticsWorkgroupMemoryMask);
|
||||
} else {
|
||||
LOG(FATAL) << "Do not support sync " << sync;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Variable* op) {
|
||||
auto it = var_map_.find(op);
|
||||
CHECK(it != var_map_.end()) << "cannot find variable " << op->name_hint;
|
||||
return it->second;
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const IntImm* op) {
|
||||
return builder_->IntImm(builder_->GetSType(op->type), op->value);
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const UIntImm* op) {
|
||||
return builder_->UIntImm(builder_->GetSType(op->type), op->value);
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImm* op) {
|
||||
return builder_->FloatImm(builder_->GetSType(op->type), op->value);
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const StringImm* op) {
|
||||
LOG(FATAL) << "StringImm is not supported in Device code";
|
||||
return spirv::Value();
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Cast* op) {
|
||||
return builder_->Cast(builder_->GetSType(op->type), MakeValue(op->value));
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Add* op) {
|
||||
return builder_->Add(MakeValue(op->a), MakeValue(op->b));
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Sub* op) {
|
||||
return builder_->Sub(MakeValue(op->a), MakeValue(op->b));
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Mul* op) {
|
||||
return builder_->Mul(MakeValue(op->a), MakeValue(op->b));
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Div* op) {
|
||||
return builder_->Div(MakeValue(op->a), MakeValue(op->b));
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Mod* op) {
|
||||
return builder_->Mod(MakeValue(op->a), MakeValue(op->b));
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Min* op) {
|
||||
spirv::Value a = MakeValue(op->a);
|
||||
spirv::Value b = MakeValue(op->b);
|
||||
return builder_->Select(builder_->LT(a, b), a, b);
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Max* op) {
|
||||
spirv::Value a = MakeValue(op->a);
|
||||
spirv::Value b = MakeValue(op->b);
|
||||
return builder_->Select(builder_->GT(a, b), a, b);
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const LT* op) {
|
||||
return builder_->LT(MakeValue(op->a), MakeValue(op->b));
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const LE* op) {
|
||||
return builder_->LE(MakeValue(op->a), MakeValue(op->b));
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const GT* op) {
|
||||
return builder_->GT(MakeValue(op->a), MakeValue(op->b));
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const GE* op) {
|
||||
return builder_->GE(MakeValue(op->a), MakeValue(op->b));
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const EQ* op) {
|
||||
return builder_->EQ(MakeValue(op->a), MakeValue(op->b));
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const NE* op) {
|
||||
return builder_->NE(MakeValue(op->a), MakeValue(op->b));
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const And* op) {
|
||||
spirv::Value a = MakeValue(op->a);
|
||||
spirv::Value b = MakeValue(op->b);
|
||||
return builder_->MakeValue(spv::OpLogicalAnd, a.stype, a, b);
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Or* op) {
|
||||
spirv::Value a = MakeValue(op->a);
|
||||
spirv::Value b = MakeValue(op->b);
|
||||
return builder_->MakeValue(spv::OpLogicalOr, a.stype, a, b);
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Not* op) {
|
||||
spirv::Value a = MakeValue(op->a);
|
||||
return builder_->MakeValue(spv::OpLogicalNot, a.stype, a);
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Select* op) {
|
||||
return builder_->Select(MakeValue(op->condition),
|
||||
MakeValue(op->true_value),
|
||||
MakeValue(op->false_value));
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Let* op) {
|
||||
CHECK(!var_map_.count(op->var.get()));
|
||||
var_map_[op->var.get()] = MakeValue(op->value);
|
||||
align_map_[op->var.get()] = EvalModular(op->value, align_map_);
|
||||
return MakeValue(op->body);
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) {
|
||||
if (op->is_intrinsic("spirv_glsl450")) {
|
||||
CHECK_GE(op->args.size(), 2U);
|
||||
uint32_t inst_id = op->args[0].as<UIntImm>()->value;
|
||||
std::vector<spirv::Value> values;
|
||||
for (size_t i = 1; i < op->args.size(); ++i) {
|
||||
values.push_back(MakeValue(op->args[i]));
|
||||
}
|
||||
return builder_->CallGLSL450(
|
||||
builder_->GetSType(op->type), inst_id, values);
|
||||
} else if (op->is_intrinsic(Call::bitwise_and)) {
|
||||
CHECK_EQ(op->args.size(), 2U);
|
||||
spirv::Value a = MakeValue(op->args[0]);
|
||||
spirv::Value b = MakeValue(op->args[1]);
|
||||
return builder_->MakeValue(spv::OpBitwiseAnd, a.stype, a, b);
|
||||
} else if (op->is_intrinsic(Call::bitwise_xor)) {
|
||||
CHECK_EQ(op->args.size(), 2U);
|
||||
spirv::Value a = MakeValue(op->args[0]);
|
||||
spirv::Value b = MakeValue(op->args[1]);
|
||||
return builder_->MakeValue(spv::OpBitwiseXor, a.stype, a, b);
|
||||
} else if (op->is_intrinsic(Call::bitwise_or)) {
|
||||
CHECK_EQ(op->args.size(), 2U);
|
||||
spirv::Value a = MakeValue(op->args[0]);
|
||||
spirv::Value b = MakeValue(op->args[1]);
|
||||
return builder_->MakeValue(spv::OpBitwiseOr, a.stype, a, b);
|
||||
} else if (op->is_intrinsic(Call::bitwise_not)) {
|
||||
CHECK_EQ(op->args.size(), 1U);
|
||||
spirv::Value a = MakeValue(op->args[0]);
|
||||
return builder_->MakeValue(spv::OpNot, a.stype, a);
|
||||
} else if (op->is_intrinsic(Call::shift_left)) {
|
||||
CHECK_EQ(op->args.size(), 2U);
|
||||
spirv::Value a = MakeValue(op->args[0]);
|
||||
spirv::Value b = MakeValue(op->args[1]);
|
||||
return builder_->MakeValue(spv::OpShiftLeftLogical, a.stype, a, b);
|
||||
} else if (op->is_intrinsic(Call::shift_right)) {
|
||||
CHECK_EQ(op->args.size(), 2U);
|
||||
spirv::Value a = MakeValue(op->args[0]);
|
||||
spirv::Value b = MakeValue(op->args[1]);
|
||||
if (op->args[0].type().is_int()) {
|
||||
return builder_->MakeValue(spv::OpShiftRightArithmetic, a.stype, a, b);
|
||||
} else {
|
||||
return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b);
|
||||
}
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
|
||||
return this->CreateStorageSync(op);
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
|
||||
CHECK_EQ(op->args.size(), 3U);
|
||||
spirv::Value cond = MakeValue(op->args[0]);
|
||||
spirv::Label then_label = builder_->NewLabel();
|
||||
spirv::Label else_label = builder_->NewLabel();
|
||||
spirv::Label merge_label = builder_->NewLabel();
|
||||
builder_->MakeInst(
|
||||
spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
|
||||
builder_->MakeInst(
|
||||
spv::OpBranchConditional, cond, then_label, else_label);
|
||||
// then block, must get label after we see the value
|
||||
builder_->StartLabel(then_label);
|
||||
spirv::Value then_value = MakeValue(op->args[1]);
|
||||
spirv::Label then_value_label = builder_->CurrentLabel();
|
||||
builder_->MakeInst(spv::OpBranch, merge_label);
|
||||
// else block
|
||||
builder_->StartLabel(else_label);
|
||||
spirv::Value else_value = MakeValue(op->args[2]);
|
||||
spirv::Label else_value_label = builder_->CurrentLabel();
|
||||
builder_->MakeInst(spv::OpBranch, merge_label);
|
||||
// merge block
|
||||
builder_->StartLabel(merge_label);
|
||||
spirv::PhiValue phi = builder_->MakePhi(then_value.stype, 2);
|
||||
phi.SetIncoming(0, then_value, then_value_label);
|
||||
phi.SetIncoming(1, else_value, else_value_label);
|
||||
return phi;
|
||||
} else if (op->is_intrinsic("popcount")) {
|
||||
return builder_->MakeValue(
|
||||
spv::OpBitCount,
|
||||
builder_->GetSType(op->type),
|
||||
MakeValue(op->args[0]));
|
||||
} else {
|
||||
if (op->call_type == Call::Intrinsic ||
|
||||
op->call_type == Call::PureIntrinsic) {
|
||||
LOG(FATAL) << "Unresolved intrinsic " << op->name
|
||||
<< " with return type " << op->type;
|
||||
} else if (op->call_type == Call::Extern ||
|
||||
op->call_type == Call::PureExtern) {
|
||||
LOG(FATAL) << "Unresolved extern " << op->name
|
||||
<< " with return type " << op->type;
|
||||
} else {
|
||||
LOG(FATAL) << "Unresolved call type " << op->call_type;
|
||||
}
|
||||
return spirv::Value();
|
||||
}
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Ramp* op) {
|
||||
std::vector<spirv::Value> values;
|
||||
spirv::Value base = MakeValue(op->base);
|
||||
for (int i = 0; i < op->lanes; ++i) {
|
||||
spirv::Value v = base;
|
||||
if (i != 0) {
|
||||
spirv::Value offset = MakeValue(
|
||||
arith::ComputeExpr<Mul>(make_const(op->stride.type(), i), op->stride));
|
||||
v = builder_->Add(v, offset);
|
||||
}
|
||||
values.push_back(v);
|
||||
}
|
||||
return builder_->Concat(values);
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Broadcast* op) {
|
||||
std::vector<spirv::Value> values;
|
||||
spirv::Value v = MakeValue(op->value);
|
||||
for (int i = 0; i < op->lanes; i++) {
|
||||
values.push_back(v);
|
||||
}
|
||||
return builder_->Concat(values);
|
||||
}
|
||||
|
||||
spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) {
|
||||
CHECK(is_one(op->predicate));
|
||||
auto it = storage_info_.find(op->buffer_var.get());
|
||||
CHECK(it != storage_info_.end());
|
||||
StorageInfo& info = it->second;
|
||||
if (!info.content_fixed) {
|
||||
info.UpdateContentType(op->type);
|
||||
}
|
||||
|
||||
spirv::SType content_type = builder_->GetSType(info.content_type);
|
||||
spirv::Value buffer = MakeValue(op->buffer_var);
|
||||
spirv::SType ptr_type = builder_->GetPointerType(
|
||||
content_type, buffer.stype.storage_class);
|
||||
|
||||
uint32_t mask = spv::MemoryAccessMaskNone;
|
||||
if (info.is_volatile) {
|
||||
mask |= spv::MemoryAccessVolatileMask;
|
||||
}
|
||||
if (op->type.lanes() == 1) {
|
||||
CHECK_EQ(info.content_type, op->type)
|
||||
<< "Vulkan only allow one type access to the same buffer";
|
||||
spirv::Value index = MakeValue(op->index);
|
||||
spirv::Value ptr = builder_->StructArrayAccess(
|
||||
ptr_type, buffer, index);
|
||||
return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
|
||||
} else {
|
||||
if (op->type.element_of() == info.content_type) {
|
||||
// because the content type is the element type, we can only do a scalarized load.
|
||||
std::vector<spirv::Value> values;
|
||||
auto f = [&](int i, spirv::Value index) {
|
||||
spirv::Value ptr = builder_->StructArrayAccess(
|
||||
ptr_type, buffer, index);
|
||||
values.emplace_back(
|
||||
builder_->MakeValue(spv::OpLoad, content_type, ptr, mask));
|
||||
};
|
||||
this->Scalarize(op->index, f);
|
||||
return builder_->Concat(values);
|
||||
} else {
|
||||
if (const Ramp* ramp = op->index.as<Ramp>()) {
|
||||
if (is_one(ramp->stride)) {
|
||||
CHECK_EQ(ramp->lanes, op->type.lanes());
|
||||
arith::ModularEntry me = arith::EvalModular(ramp->base, align_map_);
|
||||
CHECK((me.coeff % ramp->lanes) == 0 &&
|
||||
(me.base % ramp->lanes) == 0)
|
||||
<< "Only aligned vector access is allowed in SPIRV";
|
||||
Expr vec_index = ir::Simplify(
|
||||
ramp->base / make_const(ramp->base.type(), ramp->lanes));
|
||||
spirv::Value ptr = builder_->StructArrayAccess(
|
||||
ptr_type, buffer, MakeValue(vec_index));
|
||||
return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
|
||||
}
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
|
||||
}
|
||||
LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
|
||||
return spirv::Value();
|
||||
}
|
||||
|
||||
void CodeGenSPIRV::Scalarize(const Expr& e,
|
||||
std::function<void(int i, spirv::Value v)> f) {
|
||||
if (const Ramp* ramp = e.as<Ramp>()) {
|
||||
for (int i = 0; i < ramp->type.lanes(); ++i) {
|
||||
Expr offset = arith::ComputeExpr<Add>(
|
||||
ramp->base,
|
||||
arith::ComputeExpr<Mul>(ramp->stride, i));
|
||||
f(i, MakeValue(offset));
|
||||
}
|
||||
} else {
|
||||
spirv::SType etype = builder_->GetSType(e.type().element_of());
|
||||
spirv::Value value = MakeValue(e);
|
||||
for (int i = 0; i < e.type().lanes(); ++i) {
|
||||
f(i, builder_->MakeValue(
|
||||
spv::OpCompositeExtract, etype, value, i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CodeGenSPIRV::VisitStmt_(const Store* op) {
|
||||
CHECK(is_one(op->predicate));
|
||||
auto it = storage_info_.find(op->buffer_var.get());
|
||||
CHECK(it != storage_info_.end());
|
||||
StorageInfo& info = it->second;
|
||||
|
||||
if (!info.content_fixed) {
|
||||
info.UpdateContentType(op->value.type());
|
||||
}
|
||||
|
||||
spirv::SType content_type = builder_->GetSType(info.content_type);
|
||||
spirv::Value buffer = MakeValue(op->buffer_var);
|
||||
spirv::Value value = MakeValue(op->value);
|
||||
spirv::SType ptr_type = builder_->GetPointerType(
|
||||
content_type, buffer.stype.storage_class);
|
||||
|
||||
uint32_t mask = spv::MemoryAccessMaskNone;
|
||||
if (info.is_volatile) {
|
||||
mask |= spv::MemoryAccessVolatileMask;
|
||||
}
|
||||
|
||||
if (op->value.type().lanes() == 1) {
|
||||
CHECK_EQ(info.content_type, op->value.type())
|
||||
<< "Vulkan only allow one type access to the same buffer";
|
||||
spirv::Value index = MakeValue(op->index);
|
||||
spirv::Value ptr = builder_->StructArrayAccess(
|
||||
ptr_type, buffer, index);
|
||||
builder_->MakeInst(spv::OpStore, ptr, value, mask);
|
||||
} else {
|
||||
if (op->value.type().element_of() == info.content_type) {
|
||||
// because the content type is the element type, we can only do a scalarized store.
|
||||
auto f = [&](int i, spirv::Value index) {
|
||||
spirv::Value elem = builder_->MakeValue(
|
||||
spv::OpCompositeExtract, content_type, value, i);
|
||||
spirv::Value ptr = builder_->StructArrayAccess(
|
||||
ptr_type, buffer, index);
|
||||
builder_->MakeInst(spv::OpStore, ptr, elem, mask);
|
||||
};
|
||||
this->Scalarize(op->index, f);
|
||||
} else {
|
||||
if (const Ramp* ramp = op->index.as<Ramp>()) {
|
||||
if (is_one(ramp->stride)) {
|
||||
CHECK_EQ(ramp->lanes, op->value.type().lanes());
|
||||
arith::ModularEntry me = arith::EvalModular(ramp->base, align_map_);
|
||||
CHECK((me.coeff % ramp->lanes) == 0 &&
|
||||
(me.base % ramp->lanes) == 0)
|
||||
<< "Only aligned vector access is allowed in SPIRV";
|
||||
Expr vec_index = ir::Simplify(
|
||||
ramp->base / make_const(ramp->base.type(), ramp->lanes));
|
||||
spirv::Value ptr = builder_->StructArrayAccess(
|
||||
ptr_type, buffer, MakeValue(vec_index));
|
||||
builder_->MakeInst(spv::OpStore, ptr, value, mask);
|
||||
return;
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CodeGenSPIRV::VisitStmt_(const For* op) {
|
||||
CHECK(is_zero(op->min));
|
||||
spirv::Value init_value = MakeValue(op->min);
|
||||
spirv::Value extent_value = MakeValue(op->extent);
|
||||
// Must get init label after making value(to make sure they are correct)
|
||||
spirv::Label init_label = builder_->CurrentLabel();
|
||||
spirv::Label head_label = builder_->NewLabel();
|
||||
spirv::Label body_label = builder_->NewLabel();
|
||||
spirv::Label continue_label = builder_->NewLabel();
|
||||
spirv::Label merge_label = builder_->NewLabel();
|
||||
builder_->MakeInst(spv::OpBranch, head_label);
|
||||
|
||||
// Loop head
|
||||
builder_->StartLabel(head_label);
|
||||
spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2);
|
||||
loop_var.SetIncoming(0, init_value, init_label);
|
||||
spirv::Value loop_cond = builder_->LT(loop_var, extent_value);
|
||||
uint32_t control = (
|
||||
op->for_type == ForType::Unrolled ?
|
||||
spv::LoopControlUnrollMask : spv::LoopControlMaskNone);
|
||||
builder_->MakeInst(
|
||||
spv::OpLoopMerge, merge_label, continue_label, control);
|
||||
builder_->MakeInst(
|
||||
spv::OpBranchConditional, loop_cond, body_label, merge_label,
|
||||
weight_likely_branch_, 1);
|
||||
|
||||
// loop body
|
||||
builder_->StartLabel(body_label);
|
||||
var_map_[op->loop_var.get()] = spirv::Value(loop_var);
|
||||
this->VisitStmt(op->body);
|
||||
builder_->MakeInst(spv::OpBranch, continue_label);
|
||||
|
||||
// loop continue
|
||||
builder_->StartLabel(continue_label);
|
||||
spirv::Value one =
|
||||
op->loop_var.type().is_int() ?
|
||||
builder_->IntImm(loop_var.stype, 1) :
|
||||
builder_->UIntImm(loop_var.stype, 1);
|
||||
spirv::Value next_value = builder_->Add(loop_var, one);
|
||||
loop_var.SetIncoming(1, next_value, builder_->CurrentLabel());
|
||||
builder_->MakeInst(spv::OpBranch, head_label);
|
||||
// loop merge
|
||||
builder_->StartLabel(merge_label);
|
||||
}
|
||||
|
||||
void CodeGenSPIRV::VisitStmt_(const IfThenElse* op) {
|
||||
spirv::Value cond = MakeValue(op->condition);
|
||||
spirv::Label then_label = builder_->NewLabel();
|
||||
spirv::Label merge_label = builder_->NewLabel();
|
||||
if (op->else_case.defined()) {
|
||||
spirv::Label else_label = builder_->NewLabel();
|
||||
builder_->MakeInst(
|
||||
spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
|
||||
builder_->MakeInst(
|
||||
spv::OpBranchConditional, cond, then_label, else_label);
|
||||
// then block
|
||||
builder_->StartLabel(then_label);
|
||||
this->VisitStmt(op->then_case);
|
||||
builder_->MakeInst(spv::OpBranch, merge_label);
|
||||
// else block
|
||||
builder_->StartLabel(else_label);
|
||||
this->VisitStmt(op->else_case);
|
||||
builder_->MakeInst(spv::OpBranch, merge_label);
|
||||
} else {
|
||||
builder_->MakeInst(
|
||||
spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
|
||||
builder_->MakeInst(
|
||||
spv::OpBranchConditional, cond, then_label, merge_label,
|
||||
weight_likely_branch_, 1);
|
||||
// then block
|
||||
builder_->StartLabel(then_label);
|
||||
this->VisitStmt(op->then_case);
|
||||
builder_->MakeInst(spv::OpBranch, merge_label);
|
||||
}
|
||||
// start merge label;
|
||||
builder_->StartLabel(merge_label);
|
||||
}
|
||||
|
||||
void CodeGenSPIRV::VisitStmt_(const Allocate* op) {
|
||||
CHECK(!is_zero(op->condition));
|
||||
CHECK(!op->new_expr.defined());
|
||||
CHECK(!op->type.is_handle());
|
||||
int32_t constant_size = op->constant_allocation_size();
|
||||
CHECK_GT(constant_size, 0)
|
||||
<< "Can only handle constant size stack allocation in GPU";
|
||||
spirv::Value buf;
|
||||
StorageInfo& info = storage_info_[op->buffer_var.get()];
|
||||
spirv::SType etype = builder_->GetSType(op->type);
|
||||
if (info.scope.rank == 2) {
|
||||
buf = builder_->Allocate(
|
||||
etype, static_cast<uint32_t>(constant_size),
|
||||
spv::StorageClassFunction);
|
||||
} else {
|
||||
// shared memory
|
||||
CHECK_EQ(info.scope.rank, 1)
|
||||
<< "Can only allocate shared or local memory inside kernel";
|
||||
// Shared memory
|
||||
buf = builder_->Allocate(
|
||||
etype, static_cast<uint32_t>(constant_size),
|
||||
spv::StorageClassWorkgroup);
|
||||
}
|
||||
CHECK(!info.content_fixed);
|
||||
info.UpdateContentType(op->type);
|
||||
CHECK(!var_map_.count(op->buffer_var.get()));
|
||||
var_map_[op->buffer_var.get()] = buf;
|
||||
this->VisitStmt(op->body);
|
||||
}
|
||||
|
||||
void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) {
|
||||
if (op->attr_key == attr::thread_extent) {
|
||||
IterVar iv(op->node.node_);
|
||||
if (iv->thread_tag.length() != 0) {
|
||||
if (!var_map_.count(iv->var.get())) {
|
||||
var_map_[iv->var.get()] = GetThreadIndex(iv, op->value);
|
||||
}
|
||||
}
|
||||
} else if (op->attr_key == ir::attr::storage_scope) {
|
||||
const Variable* v = op->node.as<Variable>();
|
||||
CHECK(v);
|
||||
storage_info_[v].scope =
|
||||
runtime::StorageScope::make(op->value.as<StringImm>()->value);
|
||||
} else if (op->attr_key == ir::attr::volatile_scope) {
|
||||
const Variable* v = op->node.as<Variable>();
|
||||
CHECK(v);
|
||||
storage_info_[v].is_volatile = true;
|
||||
}
|
||||
this->VisitStmt(op->body);
|
||||
}
|
||||
|
||||
void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) {
|
||||
VisitAssert(op, &align_map_, [this](const Stmt& body) {
|
||||
this->VisitStmt(body);
|
||||
});
|
||||
}
|
||||
|
||||
void CodeGenSPIRV::VisitStmt_(const LetStmt* op) {
|
||||
CHECK(!var_map_.count(op->var.get()));
|
||||
CHECK(!align_map_.count(op->var.get()));
|
||||
CHECK(!op->var.type().is_handle());
|
||||
var_map_[op->var.get()] = MakeValue(op->value);
|
||||
align_map_[op->var.get()] = EvalModular(op->value, align_map_);
|
||||
this->VisitStmt(op->body);
|
||||
}
|
||||
|
||||
void CodeGenSPIRV::VisitStmt_(const Block* op) {
|
||||
VisitStmt(op->first);
|
||||
if (op->rest.defined()) {
|
||||
this->VisitStmt(op->rest);
|
||||
}
|
||||
}
|
||||
|
||||
void CodeGenSPIRV::VisitStmt_(const Evaluate* op) {
|
||||
MakeValue(op->value);
|
||||
}
|
||||
|
||||
void CodeGenSPIRV::VisitStmt_(const ProducerConsumer* op) {
|
||||
this->VisitStmt(op->body);
|
||||
}
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_VULKAN_RUNTIME
|
|
@ -0,0 +1,133 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file codegen_spirv.h
|
||||
* \brief Utility for building SPIRV code block
|
||||
*/
|
||||
#ifndef TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_
|
||||
#define TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_
|
||||
|
||||
#include <tvm/ir.h>
|
||||
#include <tvm/ir_functor_ext.h>
|
||||
#include <tvm/lowered_func.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "./ir_builder.h"
|
||||
#include "../../runtime/thread_storage_scope.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace codegen {
|
||||
|
||||
using namespace ir;
|
||||
|
||||
/*!
|
||||
* \brief Code generator into SPIRV
|
||||
*/
|
||||
class CodeGenSPIRV:
|
||||
public ExprFunctor<spirv::Value(const Expr&)>,
|
||||
public StmtFunctor<void(const Stmt&)> {
|
||||
public:
|
||||
/*!
|
||||
* \brief Compile and add function f to the current module.
|
||||
* \param f The function to be added.
|
||||
* \return The final spirv module.
|
||||
*/
|
||||
virtual std::vector<uint32_t> BuildFunction(const LoweredFunc& f);
|
||||
/*!
|
||||
* \brief Create Value for expression e
|
||||
* \param e The expression to be created value for.
|
||||
* \return created value.
|
||||
*/
|
||||
spirv::Value MakeValue(const Expr& e) {
|
||||
return VisitExpr(e);
|
||||
}
|
||||
// override codegen
|
||||
spirv::Value VisitExpr_(const Variable* op) override;
|
||||
spirv::Value VisitExpr_(const Cast* op) override;
|
||||
spirv::Value VisitExpr_(const IntImm* op) override;
|
||||
spirv::Value VisitExpr_(const UIntImm* op) override;
|
||||
spirv::Value VisitExpr_(const FloatImm* op) override;
|
||||
spirv::Value VisitExpr_(const StringImm* op) override;
|
||||
spirv::Value VisitExpr_(const Add* op) override;
|
||||
spirv::Value VisitExpr_(const Sub* op) override;
|
||||
spirv::Value VisitExpr_(const Mul* op) override;
|
||||
spirv::Value VisitExpr_(const Div* op) override;
|
||||
spirv::Value VisitExpr_(const Mod* op) override;
|
||||
spirv::Value VisitExpr_(const Min* op) override;
|
||||
spirv::Value VisitExpr_(const Max* op) override;
|
||||
spirv::Value VisitExpr_(const LT* op) override;
|
||||
spirv::Value VisitExpr_(const LE* op) override;
|
||||
spirv::Value VisitExpr_(const GT* op) override;
|
||||
spirv::Value VisitExpr_(const GE* op) override;
|
||||
spirv::Value VisitExpr_(const EQ* op) override;
|
||||
spirv::Value VisitExpr_(const NE* op) override;
|
||||
spirv::Value VisitExpr_(const And* op) override;
|
||||
spirv::Value VisitExpr_(const Or* op) override;
|
||||
spirv::Value VisitExpr_(const Not* op) override;
|
||||
spirv::Value VisitExpr_(const Select* op) override;
|
||||
spirv::Value VisitExpr_(const Let* op) override;
|
||||
spirv::Value VisitExpr_(const Call* op) override;
|
||||
spirv::Value VisitExpr_(const Ramp* op) override;
|
||||
spirv::Value VisitExpr_(const Broadcast* op) override;
|
||||
spirv::Value VisitExpr_(const Load* op) override;
|
||||
// stmt
|
||||
void VisitStmt_(const Store* op) override;
|
||||
void VisitStmt_(const For* op) override;
|
||||
void VisitStmt_(const IfThenElse* op) override;
|
||||
void VisitStmt_(const Allocate* op) override;
|
||||
void VisitStmt_(const AttrStmt* op) override;
|
||||
void VisitStmt_(const AssertStmt* op) override;
|
||||
void VisitStmt_(const LetStmt* op) override;
|
||||
void VisitStmt_(const Block* op) override;
|
||||
void VisitStmt_(const Evaluate* op) override;
|
||||
void VisitStmt_(const ProducerConsumer* op) override;
|
||||
|
||||
protected:
|
||||
/*! \brief The storage information */
|
||||
struct StorageInfo {
|
||||
/*! \brief The storage scope */
|
||||
runtime::StorageScope scope;
|
||||
/*! \brief Whether it is volatile */
|
||||
bool is_volatile{false};
|
||||
/*! \brief Whether the content type is fixed */
|
||||
bool content_fixed{false};
|
||||
/*! \brief Current content type */
|
||||
Type content_type{Handle()};
|
||||
|
||||
// Update content type if it hasn't been updated.
|
||||
void UpdateContentType(Type type) {
|
||||
if (content_fixed) {
|
||||
CHECK_EQ(type, content_type)
|
||||
<< "Cannot use two different content type in GLSL model";
|
||||
} else {
|
||||
this->content_type = type;
|
||||
content_fixed = true;
|
||||
}
|
||||
}
|
||||
};
|
||||
// Reset the state so it works for a new function.
|
||||
void InitFuncState();
|
||||
// Get the thread index
|
||||
spirv::Value GetThreadIndex(const IterVar& iv, const Expr& extent);
|
||||
spirv::Value CreateStorageSync(const Call* op);
|
||||
void Scalarize(const Expr& e,
|
||||
std::function<void(int i, spirv::Value v)> f);
|
||||
// The builder
|
||||
std::unique_ptr<spirv::IRBuilder> builder_;
|
||||
// Workgroup size in three dimensions
|
||||
uint32_t workgroup_size_[3];
|
||||
// Likely branch
|
||||
uint32_t weight_likely_branch_{128};
|
||||
// the storage scope of allocation
|
||||
std::unordered_map<const Variable*, StorageInfo> storage_info_;
|
||||
// The definition of local variable.
|
||||
std::unordered_map<const Variable*, spirv::Value> var_map_;
|
||||
// The alignment information
|
||||
std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
|
||||
};
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
||||
|
||||
|
||||
#endif // TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_
|
|
@ -0,0 +1,50 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file intrin_rule_spirv.cc
|
||||
*/
|
||||
#if TVM_VULKAN_RUNTIME
|
||||
|
||||
#include <tvm/packed_func_ext.h>
|
||||
#include <tvm/ir.h>
|
||||
#include <vulkan/GLSL.std.450.h>
|
||||
|
||||
namespace tvm {
|
||||
namespace codegen {
|
||||
namespace spirv {
|
||||
|
||||
using namespace runtime;
|
||||
|
||||
// num_signature means number of arguments used to query signature
|
||||
template<unsigned id>
|
||||
inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
|
||||
Expr e = targs[0];
|
||||
const ir::Call* call = e.as<ir::Call>();
|
||||
CHECK(call != nullptr);
|
||||
Array<Expr> cargs;
|
||||
// intrin id.
|
||||
cargs.push_back(ir::UIntImm::make(UInt(32), id));
|
||||
|
||||
for (Expr arg : call->args) {
|
||||
cargs.push_back(arg);
|
||||
}
|
||||
*rv = ir::Call::make(
|
||||
call->type, "spirv_glsl450", cargs, ir::Call::PureIntrinsic);
|
||||
}
|
||||
|
||||
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp")
|
||||
.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);
|
||||
|
||||
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log")
|
||||
.set_body(DispatchGLSLPureIntrin<GLSLstd450Log>);
|
||||
|
||||
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt")
|
||||
.set_body(DispatchGLSLPureIntrin<GLSLstd450Sqrt>);
|
||||
|
||||
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow")
|
||||
.set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);
|
||||
|
||||
} // namespace spirv
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_VULKAN_RUNTIME
|
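With these rules registered, math intrinsics used in a compute expression are rewritten into the "spirv_glsl450" call that CodeGenSPIRV lowers through the GLSL.std.450 extended instruction import. An illustrative sketch (the kernel itself is hypothetical):

import tvm

n = 256
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: tvm.exp(A[i]), name="B")
s = tvm.create_schedule(B.op)
s[B].bind(B.op.axis[0], tvm.thread_axis("threadIdx.x"))
f = tvm.build(s, [A, B], target="vulkan")              # exp is dispatched via GLSLstd450Exp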
|
@ -0,0 +1,548 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file ir_builder.cc
|
||||
* \brief IRBuilder for SPIRV block
|
||||
*/
|
||||
|
||||
#if TVM_VULKAN_RUNTIME
|
||||
|
||||
#include "./ir_builder.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace codegen {
|
||||
namespace spirv {
|
||||
|
||||
// implementations
|
||||
|
||||
void IRBuilder::InitHeader() {
|
||||
CHECK_EQ(header_.size(), 0U);
|
||||
header_.push_back(spv::MagicNumber);
|
||||
header_.push_back(spv::Version);
|
||||
// generator: set to 0, unknown
|
||||
header_.push_back(0U);
|
||||
// Bound: set during Finalize
|
||||
header_.push_back(0U);
|
||||
// Schema: reserved
|
||||
header_.push_back(0U);
|
||||
// shader
|
||||
ib_.Begin(spv::OpCapability).Add(spv::CapabilityShader).Commit(&header_);
|
||||
// memory model
|
||||
ib_.Begin(spv::OpMemoryModel).AddSeq(
|
||||
spv::AddressingModelLogical,
|
||||
spv::MemoryModelGLSL450).Commit(&entry_);
|
||||
this->InitPreDefs();
|
||||
}
|
||||
|
||||
void IRBuilder::InitPreDefs() {
|
||||
ext_glsl450_ = ExtInstImport("GLSL.std.450");
|
||||
t_int32_ = DeclareType(Int(32));
|
||||
t_uint32_ = DeclareType(UInt(32));
|
||||
t_bool_ = DeclareType(UInt(1));
|
||||
t_fp32_ = DeclareType(Float(32));
|
||||
const_i32_zero_ = IntImm(t_int32_, 0);
|
||||
// declare void, and void functions
|
||||
t_void_.id = id_counter_++;
|
||||
ib_.Begin(spv::OpTypeVoid).Add(t_void_).Commit(&global_);
|
||||
t_void_func_.id = id_counter_++;
|
||||
ib_.Begin(spv::OpTypeFunction)
|
||||
.AddSeq(t_void_func_, t_void_).Commit(&global_);
|
||||
}
|
||||
|
||||
SType IRBuilder::GetSType(const Type& dtype) {
|
||||
if (dtype == Int(32)) {
|
||||
return t_int32_;
|
||||
} else if (dtype == UInt(1)) {
|
||||
return t_bool_;
|
||||
} else if (dtype == Float(32)) {
|
||||
return t_fp32_;
|
||||
} else if (dtype == UInt(32)) {
|
||||
return t_uint32_;
|
||||
}
|
||||
uint32_t type_key;
|
||||
type_key = static_cast<uint32_t>(dtype.code());
|
||||
type_key |= static_cast<uint32_t>(dtype.bits()) << 8U;
|
||||
type_key |= static_cast<uint32_t>(dtype.lanes()) << 16U;
|
||||
auto it = pod_type_tbl_.find(type_key);
|
||||
if (it != pod_type_tbl_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
SType t = DeclareType(dtype);
|
||||
pod_type_tbl_[type_key] = t;
|
||||
return t;
|
||||
}
|
||||
|
||||
SType IRBuilder::GetPointerType(const SType& value_type,
|
||||
spv::StorageClass storage_class) {
|
||||
CHECK_NE(storage_class, spv::StorageClassMax);
|
||||
auto key = std::make_pair(value_type.id, storage_class);
|
||||
auto it = pointer_type_tbl_.find(key);
|
||||
if (it != pointer_type_tbl_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
SType t;
|
||||
t.id = id_counter_++;
|
||||
t.type = Handle();
|
||||
t.element_type_id = value_type.id;
|
||||
t.storage_class = storage_class;
|
||||
ib_.Begin(spv::OpTypePointer)
|
||||
.AddSeq(t, storage_class, value_type).Commit(&global_);
|
||||
pointer_type_tbl_[key] = t;
|
||||
return t;
|
||||
}
|
||||
|
||||
SType IRBuilder::GetStructArrayType(const SType& value_type,
|
||||
uint32_t num_elems) {
|
||||
auto key = std::make_pair(value_type.id, num_elems);
|
||||
auto it = struct_array_type_tbl_.find(key);
|
||||
if (it != struct_array_type_tbl_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
SType arr_type;
|
||||
arr_type.id = id_counter_++;
|
||||
arr_type.type = Handle();
|
||||
arr_type.element_type_id = value_type.id;
|
||||
|
||||
if (num_elems != 0) {
|
||||
Value length = UIntImm(GetSType(UInt(32)), num_elems);
|
||||
ib_.Begin(spv::OpTypeArray)
|
||||
.AddSeq(arr_type, value_type, length).Commit(&global_);
|
||||
} else {
|
||||
ib_.Begin(spv::OpTypeRuntimeArray)
|
||||
.AddSeq(arr_type, value_type).Commit(&global_);
|
||||
}
|
||||
int nbits = value_type.type.bits() * value_type.type.lanes();
|
||||
CHECK_EQ(nbits % 8, 0);
|
||||
uint32_t nbytes = static_cast<uint32_t>(nbits) / 8;
|
||||
// decorate the array type.
|
||||
this->Decorate(spv::OpDecorate,
|
||||
arr_type, spv::DecorationArrayStride, nbytes);
|
||||
// declare struct of array
|
||||
SType struct_type;
|
||||
struct_type.id = id_counter_++;
|
||||
struct_type.type = Handle();
|
||||
struct_type.element_type_id = value_type.id;
|
||||
ib_.Begin(spv::OpTypeStruct)
|
||||
.AddSeq(struct_type, arr_type).Commit(&global_);
|
||||
// decorate the member offset of the struct.
|
||||
ib_.Begin(spv::OpMemberDecorate)
|
||||
.AddSeq(struct_type, 0, spv::DecorationOffset, 0)
|
||||
.Commit(&decorate_);
|
||||
// runtime arrays are always decorated as BufferBlock (shader storage buffer)
|
||||
if (num_elems == 0) {
|
||||
this->Decorate(spv::OpDecorate,
|
||||
struct_type, spv::DecorationBufferBlock);
|
||||
}
|
||||
struct_array_type_tbl_[key] = struct_type;
|
||||
return struct_type;
|
||||
}
|
||||
|
||||
Value IRBuilder::StructArrayAccess(const SType& res_type,
|
||||
Value buffer,
|
||||
Value index) {
|
||||
CHECK(buffer.flag == kStructArrayPtr);
|
||||
return MakeValue(spv::OpInBoundsAccessChain,
|
||||
res_type, buffer,
|
||||
const_i32_zero_, index);
|
||||
}
|
||||
|
||||
Value IRBuilder::IntImm(const SType& dtype, int64_t value) {
|
||||
return GetConst_(dtype, reinterpret_cast<uint64_t*>(&value));
|
||||
}
|
||||
|
||||
Value IRBuilder::UIntImm(const SType& dtype, uint64_t value) {
|
||||
return GetConst_(dtype, &value);
|
||||
}
|
||||
|
||||
Value IRBuilder::FloatImm(const SType& dtype, double value) {
|
||||
if (dtype.type.bits() == 64) {
|
||||
return GetConst_(dtype, reinterpret_cast<uint64_t*>(&value));
|
||||
} else if (dtype.type.bits() == 32) {
|
||||
float fvalue = static_cast<float>(value);
|
||||
uint64_t data = reinterpret_cast<uint32_t*>(&fvalue)[0];
|
||||
return GetConst_(dtype, &data);
|
||||
} else {
|
||||
CHECK_EQ(dtype.type.bits(), 16);
|
||||
return Cast(dtype,
|
||||
FloatImm(GetSType(Float(32)), value));
|
||||
}
|
||||
}
|
||||
|
||||
Value IRBuilder::BufferArgument(const SType& value_type,
|
||||
uint32_t descriptor_set,
|
||||
uint32_t binding) {
|
||||
SType sarr_type = GetStructArrayType(value_type, 0);
|
||||
SType ptr_type = GetPointerType(sarr_type, spv::StorageClassUniform);
|
||||
Value val = NewValue(ptr_type, kStructArrayPtr);
|
||||
ib_.Begin(spv::OpVariable)
|
||||
.AddSeq(ptr_type, val, spv::StorageClassUniform).Commit(&global_);
|
||||
this->Decorate(spv::OpDecorate,
|
||||
val, spv::DecorationDescriptorSet, descriptor_set);
|
||||
this->Decorate(spv::OpDecorate,
|
||||
val, spv::DecorationBinding, binding);
|
||||
return val;
|
||||
}
|
||||
|
||||
Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
|
||||
CHECK_EQ(push_const_.id, 0);
|
||||
SType struct_type;
|
||||
struct_type.id = id_counter_++;
|
||||
struct_type.type = Handle();
|
||||
ib_.Begin(spv::OpTypeStruct).Add(struct_type);
|
||||
for (const SType& vtype : value_types) {
|
||||
ib_.Add(vtype);
|
||||
}
|
||||
ib_.Commit(&global_);
|
||||
|
||||
uint32_t offset = 0;
|
||||
for (uint32_t i = 0; i < value_types.size(); ++i) {
|
||||
ib_.Begin(spv::OpMemberDecorate)
|
||||
.AddSeq(struct_type, i, spv::DecorationOffset, offset)
|
||||
.Commit(&decorate_);
|
||||
Type t = value_types[i].type;
|
||||
uint32_t nbits = t.bits() * t.lanes();
|
||||
CHECK_EQ(nbits % 8 , 0);
|
||||
offset += nbits / 8;
|
||||
}
|
||||
// Decorate push constants as UBO
|
||||
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
|
||||
|
||||
SType ptr_type = GetPointerType(
|
||||
struct_type, spv::StorageClassPushConstant);
|
||||
Value val = NewValue(ptr_type, kPushConstantPtr);
|
||||
ib_.Begin(spv::OpVariable)
|
||||
.AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_);
|
||||
return val;
|
||||
}
|
||||
|
||||
Value IRBuilder::GetPushConstant(
|
||||
Value ptr_push_const, const SType& v_type, uint32_t index) {
|
||||
SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassPushConstant);
|
||||
Value ptr = this->MakeValue(
|
||||
spv::OpAccessChain, ptr_vtype, ptr_push_const,
|
||||
IntImm(t_int32_, static_cast<int64_t>(index)));
|
||||
return this->MakeValue(spv::OpLoad, v_type, ptr);
|
||||
}
|
||||
|
||||
Value IRBuilder::DeclareKenrelFunction(const std::string& name) {
|
||||
Value val = NewValue(t_void_func_, kFunction);
|
||||
ib_.Begin(spv::OpEntryPoint)
|
||||
.AddSeq(spv::ExecutionModelGLCompute, val, name)
|
||||
.Commit(&entry_);
|
||||
return val;
|
||||
}
|
||||
|
||||
void IRBuilder::StartFunction(const Value& func) {
|
||||
CHECK_EQ(func.flag, kFunction);
|
||||
this->MakeInst(
|
||||
spv::OpFunction, t_void_, func, 0, t_void_func_);
|
||||
spirv::Label start_label = this->NewLabel();
|
||||
this->StartLabel(start_label);
|
||||
}
|
||||
|
||||
void IRBuilder::SetLocalSize(const Value& func,
|
||||
uint32_t local_size[3]) {
|
||||
CHECK_EQ(func.flag, kFunction);
|
||||
ib_.Begin(spv::OpExecutionMode)
|
||||
.AddSeq(func, spv::ExecutionModeLocalSize,
|
||||
local_size[0], local_size[1], local_size[2])
|
||||
.Commit(&exec_mode_);
|
||||
}
|
||||
|
||||
Value IRBuilder::Allocate(const SType& value_type,
|
||||
uint32_t num_elems,
|
||||
spv::StorageClass storage_class) {
|
||||
CHECK_NE(num_elems, 0U);
|
||||
SType sarr_type = GetStructArrayType(value_type, num_elems);
|
||||
SType ptr_type = GetPointerType(sarr_type, storage_class);
|
||||
Value val = NewValue(ptr_type, kStructArrayPtr);
|
||||
if (storage_class == spv::StorageClassFunction) {
|
||||
ib_.Begin(spv::OpVariable)
|
||||
.AddSeq(ptr_type, val, storage_class).Commit(&function_);
|
||||
} else {
|
||||
ib_.Begin(spv::OpVariable)
|
||||
.AddSeq(ptr_type, val, storage_class).Commit(&global_);
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
Value IRBuilder::GetWorkgroupID(uint32_t dim_index) {
|
||||
if (workgroup_id_.id == 0) {
|
||||
SType vec3_type = this->GetSType(Int(32).with_lanes(3));
|
||||
SType ptr_type = this->GetPointerType(
|
||||
vec3_type, spv::StorageClassInput);
|
||||
workgroup_id_ = NewValue(ptr_type, kVectorPtr);
|
||||
ib_.Begin(spv::OpVariable)
|
||||
.AddSeq(ptr_type, workgroup_id_, spv::StorageClassInput)
|
||||
.Commit(&global_);
|
||||
this->Decorate(spv::OpDecorate, workgroup_id_,
|
||||
spv::DecorationBuiltIn, spv::BuiltInWorkgroupId);
|
||||
}
|
||||
SType pint_type = this->GetPointerType(t_int32_, spv::StorageClassInput);
|
||||
Value ptr = this->MakeValue(
|
||||
spv::OpAccessChain, pint_type, workgroup_id_,
|
||||
IntImm(t_int32_, static_cast<int64_t>(dim_index)));
|
||||
return this->MakeValue(spv::OpLoad, t_int32_, ptr);
|
||||
}
|
||||
|
||||
Value IRBuilder::GetLocalID(uint32_t dim_index) {
|
||||
if (local_id_.id == 0) {
|
||||
SType vec3_type = this->GetSType(Int(32).with_lanes(3));
|
||||
SType ptr_type = this->GetPointerType(vec3_type, spv::StorageClassInput);
|
||||
local_id_ = NewValue(ptr_type, kVectorPtr);
|
||||
ib_.Begin(spv::OpVariable)
|
||||
.AddSeq(ptr_type, local_id_, spv::StorageClassInput)
|
||||
.Commit(&global_);
|
||||
this->Decorate(spv::OpDecorate, local_id_,
|
||||
spv::DecorationBuiltIn, spv::BuiltInLocalInvocationId);
|
||||
}
|
||||
SType pint_type = this->GetPointerType(t_int32_, spv::StorageClassInput);
|
||||
Value ptr = this->MakeValue(
|
||||
spv::OpAccessChain, pint_type, local_id_,
|
||||
UIntImm(t_int32_, static_cast<int64_t>(dim_index)));
|
||||
return this->MakeValue(spv::OpLoad, t_int32_, ptr);
|
||||
}
|
||||
|
||||
Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) {
|
||||
auto key = std::make_pair(dtype.id, pvalue[0]);
|
||||
auto it = const_tbl_.find(key);
|
||||
if (it != const_tbl_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
CHECK_LE(dtype.type.bits(), 64);
|
||||
Value ret = NewValue(dtype, kConstant);
|
||||
ib_.Begin(spv::OpConstant).AddSeq(dtype, ret);
|
||||
uint64_t mask = 0xFFFFFFFFUL;
|
||||
ib_.Add(static_cast<uint32_t>(pvalue[0] & mask));
|
||||
if (dtype.type.bits() > 32) {
|
||||
if (dtype.type.is_int()) {
|
||||
int64_t sign_mask = 0xFFFFFFFFL;
|
||||
const int64_t* sign_ptr =
|
||||
reinterpret_cast<const int64_t*>(pvalue);
|
||||
ib_.Add(static_cast<uint32_t>((sign_ptr[0] >> 32L) & sign_mask));
|
||||
} else {
|
||||
ib_.Add(static_cast<uint32_t>((pvalue[0] >> 32UL) & mask));
|
||||
}
|
||||
}
|
||||
ib_.Commit(&global_);
|
||||
const_tbl_[key] = ret;
|
||||
return ret;
|
||||
}
|
||||
|
||||
SType IRBuilder::DeclareType(const Type& dtype) {
|
||||
if (dtype.lanes() == 1) {
|
||||
SType t;
|
||||
t.id = id_counter_++;
|
||||
t.type = dtype;
|
||||
if (dtype.bits() == 1) {
|
||||
CHECK(dtype.is_uint());
|
||||
ib_.Begin(spv::OpTypeBool).Add(t).Commit(&global_);
|
||||
} else if (dtype.is_int()) {
|
||||
ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 1).Commit(&global_);
|
||||
} else if (dtype.is_uint()) {
|
||||
ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 0).Commit(&global_);
|
||||
} else if (dtype.is_float()) {
|
||||
ib_.Begin(spv::OpTypeFloat).AddSeq(t, dtype.bits()).Commit(&global_);
|
||||
} else {
|
||||
LOG(FATAL) << "declare type do not support handle";
|
||||
}
|
||||
return t;
|
||||
} else {
|
||||
SType t;
|
||||
t.id = id_counter_++;
|
||||
t.type = dtype;
|
||||
SType base_type = GetSType(dtype.element_of());
|
||||
ib_.Begin(spv::OpTypeVector).AddSeq(
|
||||
t, base_type, dtype.lanes()).Commit(&global_);
|
||||
return t;
|
||||
}
|
||||
}
|
||||
|
||||
PhiValue IRBuilder::MakePhi(const SType& out_type, uint32_t num_incoming) {
|
||||
Value val = NewValue(out_type, kNormal);
|
||||
ib_.Begin(spv::OpPhi).AddSeq(out_type, val);
|
||||
for (uint32_t i = 0; i < 2 * num_incoming; ++i) {
|
||||
ib_.Add(0);
|
||||
}
|
||||
PhiValue phi;
|
||||
phi.id = val.id;
|
||||
phi.stype = out_type;
|
||||
phi.flag = kNormal;
|
||||
phi.instr = ib_.Commit(&function_);
|
||||
CHECK_EQ(phi.instr.WordCount(), 2 * num_incoming + 3);
|
||||
return phi;
|
||||
}
|
||||
|
||||
Value IRBuilder::CallGLSL450(const SType& ret_type,
|
||||
uint32_t inst_id,
|
||||
const std::vector<Value>& args) {
|
||||
Value val = NewValue(ret_type, kNormal);
|
||||
ib_.Begin(spv::OpExtInst)
|
||||
.AddSeq(ret_type, val, ext_glsl450_, inst_id);
|
||||
for (const Value& v : args) {
|
||||
ib_.Add(v);
|
||||
}
|
||||
ib_.Commit(&function_);
|
||||
return val;
|
||||
}
|
||||
|
||||
Value IRBuilder::Concat(const std::vector<Value>& vec) {
|
||||
bool is_const = vec[0].flag == kConstant;
|
||||
Type etype = vec[0].stype.type;
|
||||
int lanes = etype.lanes();
|
||||
for (size_t i = 1; i < vec.size(); ++i) {
|
||||
CHECK_EQ(etype, vec[i].stype.type.element_of())
|
||||
<< "Cannot concat vector of different element type";
|
||||
lanes += vec[i].stype.type.lanes();
|
||||
is_const = is_const && (vec[i].flag == kConstant);
|
||||
}
|
||||
Value ret = NewValue(GetSType(etype.with_lanes(lanes)), kNormal);
|
||||
if (is_const && vec.size() == static_cast<size_t>(lanes)) {
|
||||
ib_.Begin(spv::OpConstantComposite);
|
||||
ib_.AddSeq(ret.stype, ret);
|
||||
for (const Value& v : vec) {
|
||||
ib_.Add(v);
|
||||
}
|
||||
ib_.Commit(&global_);
|
||||
} else {
|
||||
ib_.Begin(spv::OpCompositeConstruct);
|
||||
ib_.AddSeq(ret.stype, ret);
|
||||
for (const Value& v : vec) {
|
||||
ib_.Add(v);
|
||||
}
|
||||
ib_.Commit(&function_);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) {
|
||||
CHECK_NE(value.stype.id, 0U);
|
||||
if (value.stype.id == dst_type.id) return value;
|
||||
const tvm::Type& from = value.stype.type;
|
||||
const tvm::Type& to = dst_type.type;
|
||||
CHECK_EQ(from.lanes(), to.lanes());
|
||||
|
||||
if (from.is_int() && to.is_int()) {
|
||||
return MakeValue(spv::OpSConvert, dst_type, value);
|
||||
} else if (from.is_uint() && to.is_uint()) {
|
||||
return MakeValue(spv::OpUConvert, dst_type, value);
|
||||
} else if (from.is_uint() && to.is_int()) {
|
||||
if (from.bits() != to.bits()) {
|
||||
value = MakeValue(
|
||||
spv::OpUConvert, GetSType(from.with_bits(to.bits())), value);
|
||||
}
|
||||
return MakeValue(spv::OpBitcast, dst_type, value);
|
||||
} else if (from.is_int() && to.is_uint()) {
|
||||
if (from.bits() != to.bits()) {
|
||||
value = MakeValue(
|
||||
spv::OpSConvert, GetSType(from.with_bits(to.bits())), value);
|
||||
}
|
||||
return MakeValue(spv::OpBitcast, dst_type, value);
|
||||
} else if (from.is_float() && to.is_int()) {
|
||||
return MakeValue(spv::OpConvertFToS, dst_type, value);
|
||||
} else if (from.is_float() && to.is_uint()) {
|
||||
return MakeValue(spv::OpConvertFToU, dst_type, value);
|
||||
} else if (from.is_int() && to.is_float()) {
|
||||
return MakeValue(spv::OpConvertSToF, dst_type, value);
|
||||
} else if (from.is_uint() && to.is_float()) {
|
||||
return MakeValue(spv::OpConvertUToF, dst_type, value);
|
||||
} else if (from.is_float() && to.is_float()) {
|
||||
return MakeValue(spv::OpFConvert, dst_type, value);
|
||||
} else {
|
||||
LOG(FATAL) << "do not support type cast from "
|
||||
<< from << " to " << to;
|
||||
return Value();
|
||||
}
|
||||
}
|
||||
|
||||
#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \
|
||||
Value IRBuilder::_OpName(Value a, Value b) { \
|
||||
CHECK_EQ(a.stype.id, b.stype.id); \
|
||||
if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
|
||||
return MakeValue(spv::OpI ## _Op, a.stype, a, b); \
|
||||
} else { \
|
||||
CHECK(a.stype.type.is_float()); \
|
||||
return MakeValue(spv::OpF ## _Op, a.stype, a, b); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \
|
||||
Value IRBuilder::_OpName(Value a, Value b) { \
|
||||
CHECK_EQ(a.stype.id, b.stype.id); \
|
||||
if (a.stype.type.is_int()) { \
|
||||
return MakeValue(spv::OpS ## _Op, a.stype, a, b); \
|
||||
} else if (a.stype.type.is_uint()) { \
|
||||
return MakeValue(spv::OpU ## _Op, a.stype, a, b); \
|
||||
} else { \
|
||||
CHECK(a.stype.type.is_float()); \
|
||||
return MakeValue(spv::OpF ## _Op, a.stype, a, b); \
|
||||
} \
|
||||
}
|
||||
|
||||
DEFINE_BUILDER_BINARY_USIGN_OP(Add, Add);
|
||||
DEFINE_BUILDER_BINARY_USIGN_OP(Sub, Sub);
|
||||
DEFINE_BUILDER_BINARY_USIGN_OP(Mul, Mul);
|
||||
DEFINE_BUILDER_BINARY_SIGN_OP(Div, Div);
|
||||
|
||||
Value IRBuilder::Mod(Value a, Value b) {
|
||||
CHECK_EQ(a.stype.id, b.stype.id);
|
||||
if (a.stype.type.is_int()) {
|
||||
return MakeValue(spv::OpSRem, a.stype, a, b);
|
||||
} else if (a.stype.type.is_uint()) {
|
||||
return MakeValue(spv::OpUMod, a.stype, a, b);
|
||||
} else {
|
||||
CHECK(a.stype.type.is_float());
|
||||
return MakeValue(spv::OpFRem, a.stype, a, b);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \
|
||||
Value IRBuilder:: _OpName(Value a, Value b) { \
|
||||
CHECK_EQ(a.stype.id, b.stype.id); \
|
||||
if (t_bool_.id == 0) { \
|
||||
t_bool_ = DeclareType(UInt(1)); \
|
||||
} \
|
||||
if (a.stype.type.is_int()) { \
|
||||
return MakeValue(spv::OpS ## _Op, t_bool_, a, b); \
|
||||
} else if (a.stype.type.is_uint()) { \
|
||||
return MakeValue(spv::OpU ## _Op, t_bool_, a, b); \
|
||||
} else { \
|
||||
CHECK(a.stype.type.is_float()); \
|
||||
return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b); \
|
||||
} \
|
||||
}
|
||||
|
||||
DEFINE_BUILDER_CMP_OP(LT, LessThan);
|
||||
DEFINE_BUILDER_CMP_OP(LE, LessThanEqual);
|
||||
DEFINE_BUILDER_CMP_OP(GT, GreaterThan);
|
||||
DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual);
|
||||
|
||||
#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \
|
||||
Value IRBuilder:: _OpName(Value a, Value b) { \
|
||||
CHECK_EQ(a.stype.id, b.stype.id); \
|
||||
if (t_bool_.id == 0) { \
|
||||
t_bool_ = DeclareType(UInt(1)); \
|
||||
} \
|
||||
if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
|
||||
return MakeValue(spv::OpI ## _Op, t_bool_, a, b); \
|
||||
} else { \
|
||||
CHECK(a.stype.type.is_float()); \
|
||||
return MakeValue(spv::OpFOrd ## _Op, t_bool_, a, b); \
|
||||
} \
|
||||
}
|
||||
|
||||
DEFINE_BUILDER_CMP_UOP(EQ, Equal);
|
||||
DEFINE_BUILDER_CMP_UOP(NE, NotEqual);
|
||||
|
||||
Value IRBuilder::Select(Value cond, Value a, Value b) {
|
||||
CHECK_EQ(a.stype.id, b.stype.id);
|
||||
CHECK_EQ(cond.stype.type, UInt(1));
|
||||
return MakeValue(spv::OpSelect, a.stype, cond, a, b);
|
||||
}
|
||||
|
||||
} // namespace spirv
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_VULKAN_RUNTIME
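As a usage note (not part of this commit): a minimal sketch of how a caller could combine the constant and arithmetic helpers defined above. EmitAxpyLikeExpr and the include path are hypothetical; only IRBuilder methods shown in this file are assumed.

// Hypothetical example, not in this commit.
#include "codegen/spirv/ir_builder.h"   // assumed include path

namespace tvm {
namespace codegen {
namespace spirv {

// Emits (x + 1.0f) * 2.0f, exercising FloatImm (32- and 64-bit paths),
// Cast, and the float overloads of Add/Mul defined above.
Value EmitAxpyLikeExpr(IRBuilder* builder, Value x) {
  SType fp32 = builder->GetSType(Float(32));
  Value one = builder->FloatImm(fp32, 1.0);
  Value two = builder->Cast(fp32, builder->FloatImm(builder->GetSType(Float(64)), 2.0));
  return builder->Mul(builder->Add(x, one), two);
}

}  // namespace spirv
}  // namespace codegen
}  // namespace tvm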
@ -0,0 +1,597 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file ir_builder.h
|
||||
* \brief Utility for building SPIRV code block
|
||||
*/
|
||||
#ifndef TVM_CODEGEN_SPIRV_IR_BUILDER_H_
|
||||
#define TVM_CODEGEN_SPIRV_IR_BUILDER_H_
|
||||
|
||||
#include <tvm/runtime/packed_func.h>
|
||||
#include <tvm/ir.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
#include <unordered_map>
|
||||
#include <vulkan/spirv.hpp>
|
||||
|
||||
namespace tvm {
|
||||
namespace codegen {
|
||||
namespace spirv {
|
||||
|
||||
/*! \brief Represent the SPIRV Type */
|
||||
struct SType {
|
||||
/*! \brief The Id to represent type */
|
||||
uint32_t id{0};
|
||||
/*! \brief corresponding TVM type */
|
||||
tvm::Type type;
|
||||
/*! \brief content type id if it is a pointer/struct-array class */
|
||||
uint32_t element_type_id{0};
|
||||
/*! \brief The storage class, if it is a pointer */
|
||||
spv::StorageClass storage_class{spv::StorageClassMax};
|
||||
};
|
||||
|
||||
enum ValueKind {
|
||||
kNormal,
|
||||
kConstant,
|
||||
kVectorPtr,
|
||||
kStructArrayPtr,
|
||||
kPushConstantPtr,
|
||||
kFunction,
|
||||
kExtInst
|
||||
};
|
||||
|
||||
/*! \brief Represent the SPIRV Value */
|
||||
struct Value {
|
||||
/*! \brief The Id to represent value */
|
||||
uint32_t id{0};
|
||||
/*! \brief The data type */
|
||||
SType stype;
|
||||
/*! \brief additional flags about the value */
|
||||
ValueKind flag{kNormal};
|
||||
};
|
||||
|
||||
/*! \brief Represent the SPIRV Label */
|
||||
struct Label {
|
||||
/*! \brief The Id to represent label */
|
||||
uint32_t id{0};
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief A SPIRV instruction,
|
||||
* can be used as handle to modify its content later
|
||||
*/
|
||||
class Instr {
|
||||
public:
|
||||
/*! \return the word count */
|
||||
uint32_t WordCount() const {
|
||||
return word_count_;
|
||||
}
|
||||
/*!
|
||||
* \brief Access idx-th word of instruction
|
||||
* \param idx The index
|
||||
* \return reference to idx-th word.
|
||||
*/
|
||||
uint32_t& operator[](uint32_t idx) {
|
||||
CHECK_LT(idx, word_count_);
|
||||
return (*data_)[begin_ + idx];
|
||||
}
|
||||
|
||||
private:
|
||||
friend class InstrBuilder;
|
||||
/*!
|
||||
* \brief the data that backs this instruction
|
||||
* Use a pointer to the owning vector because the
* vector may reallocate as more instructions are appended.
|
||||
*/
|
||||
std::vector<uint32_t>* data_{nullptr};
|
||||
/*! \brief begin location of instruction */
|
||||
uint32_t begin_{0};
|
||||
/*! \brief word count */
uint32_t word_count_{0};
|
||||
};
|
||||
|
||||
/*! \brief Representation of phi value */
|
||||
struct PhiValue : public Value {
|
||||
/*! \brief The corresponding instr */
|
||||
Instr instr;
|
||||
/*!
|
||||
* \brief Add incoming information of a PhiValue
|
||||
* \param index The location of Phi
|
||||
* \param value The value to come
|
||||
* \param parent The parent label.
|
||||
*/
|
||||
void SetIncoming(uint32_t index,
|
||||
const Value& value,
|
||||
const Label& parent) {
|
||||
CHECK_EQ(this->stype.id, value.stype.id);
|
||||
instr[3 + index * 2] = value.id;
|
||||
instr[3 + index * 2 + 1] = parent.id;
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Helper class to build SPIRV instruction.
|
||||
*
|
||||
* \code
|
||||
*
|
||||
* std::vector<uint32_t> func_seg_vec_;
|
||||
* InstrBuilder ib;
|
||||
*
|
||||
* // construct and append to the end of func_seg_vec_;
|
||||
* ib.Begin(spv::OpIAdd)
|
||||
* .Add(result).Add(v1).Add(v2)
|
||||
* .Commit(&func_seg_vec_);
|
||||
*
|
||||
* \endcode
|
||||
*/
|
||||
class InstrBuilder {
|
||||
public:
|
||||
/*!
|
||||
* \brief Begin construction of instruction.
|
||||
* \param op The op code
|
||||
* \return reference to self.
|
||||
*/
|
||||
InstrBuilder& Begin(spv::Op op) { // NOLINT(*);
|
||||
// finish previous build
|
||||
CHECK_EQ(data_.size(), 0U);
|
||||
op_ = op;
|
||||
data_.push_back(0);
|
||||
return *this;
|
||||
}
|
||||
/*!
|
||||
* \brief Add v to end of instruction.
|
||||
* \param v The value to be appended to the instruction.
|
||||
* \return reference to self.
|
||||
*/
|
||||
InstrBuilder& Add(const Value& v) {
|
||||
data_.push_back(v.id);
|
||||
return *this;
|
||||
}
|
||||
/*!
|
||||
* \brief Add v to end of instruction.
|
||||
* \param v The type to be appended to the instruction.
|
||||
* \return reference to self.
|
||||
*/
|
||||
InstrBuilder& Add(const SType& v) {
|
||||
data_.push_back(v.id);
|
||||
return *this;
|
||||
}
|
||||
/*!
|
||||
* \brief Add v to end of instruction.
|
||||
* \param v The label to be appended to the instruction.
|
||||
* \return reference to self.
|
||||
*/
|
||||
InstrBuilder& Add(const Label& v) {
|
||||
data_.push_back(v.id);
|
||||
return *this;
|
||||
}
|
||||
/*!
|
||||
* \brief Add a word to end of instruction.
|
||||
* \param v The value to be added.
|
||||
* \return reference to self.
|
||||
*/
|
||||
InstrBuilder& Add(const uint32_t& v) {
|
||||
data_.push_back(v);
|
||||
return *this;
|
||||
}
|
||||
/*!
|
||||
* \brief Add string literal of end of instruction.
|
||||
* \param v The string literal to be appended.
|
||||
* \return reference to self.
|
||||
*/
|
||||
InstrBuilder& Add(const std::string& v) {
|
||||
const uint32_t kWordSize = sizeof(uint32_t);
|
||||
uint32_t nwords =
|
||||
(static_cast<uint32_t>(v.length()) + kWordSize) / kWordSize;
|
||||
size_t begin = data_.size();
|
||||
data_.resize(begin + nwords, 0U);
|
||||
std::copy(v.begin(), v.end(),
|
||||
reinterpret_cast<char*>(&data_[begin]));
|
||||
return *this;
|
||||
}
|
||||
/*!
|
||||
* \brief add sequence of values to instruction
|
||||
* \param args The instruction sequence
|
||||
* \return reference to self.
|
||||
* \tparams Args The positional arguments
|
||||
*/
|
||||
template<typename... Args>
|
||||
InstrBuilder& AddSeq(Args&& ...args) {
|
||||
AddSeqHelper helper;
|
||||
helper.builder = this;
|
||||
runtime::detail::for_each(helper, std::forward<Args>(args)...);
|
||||
return *this;
|
||||
}
|
||||
/*!
|
||||
* \brief Finish build, commit the current
|
||||
* instruction to the end of seg.
|
||||
*
|
||||
* \param seg The code segment to commit to
|
||||
* \return The result instruction.
|
||||
*/
|
||||
Instr Commit(std::vector<uint32_t>* seg) {
|
||||
Instr ret;
|
||||
ret.data_ = seg;
|
||||
ret.begin_ = seg->size();
|
||||
ret.word_count_ = static_cast<uint32_t>(data_.size());
|
||||
data_[0] = op_ | (ret.word_count_ << spv::WordCountShift);
|
||||
seg->insert(seg->end(), data_.begin(), data_.end());
|
||||
data_.clear();
|
||||
return ret;
|
||||
}
|
||||
|
||||
private:
|
||||
// current op code.
|
||||
spv::Op op_;
|
||||
// The internal data to store code
|
||||
std::vector<uint32_t> data_;
|
||||
// helper class to support variadic arguments
|
||||
struct AddSeqHelper {
|
||||
// The reference to builder
|
||||
InstrBuilder* builder;
|
||||
// invoke function
|
||||
template<typename T>
|
||||
void operator()(size_t, const T& v) const {
|
||||
builder->Add(v);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Builder to build up a single SPIR-V module
|
||||
*
|
||||
* This is a thin wrapper to build SPIRV binary.
|
||||
* SPIRV adopts structure control-flow.
|
||||
* We can build the code by always appending to the end of the
|
||||
* binary code block and revisit some
|
||||
*
|
||||
* This IRBuilder did not introduce concept of BasicBlock.
|
||||
* instead instructions are append to end of each segment.
|
||||
*/
|
||||
class IRBuilder {
|
||||
public:
|
||||
/*! \brief Initialize header */
|
||||
void InitHeader();
|
||||
/*! \brief Initialize the predefined contents */
|
||||
void InitPreDefs();
|
||||
/*!
|
||||
* \brief Import additional extension libraries.
|
||||
* \param name The name of the library.
|
||||
* \return The value representing the imported extended instruction set.
*/
|
||||
Value ExtInstImport(const std::string& name) {
|
||||
Value val = NewValue(SType(), kExtInst);
|
||||
ib_.Begin(spv::OpExtInstImport).AddSeq(val, name).Commit(&header_);
|
||||
return val;
|
||||
}
|
||||
/*!
|
||||
* \brief Get the final binary built from the builder
|
||||
* \return The finalized binary instruction.
|
||||
*/
|
||||
std::vector<uint32_t> Finalize() {
|
||||
std::vector<uint32_t> data;
|
||||
// set bound
|
||||
const int kBoundLoc = 3;
|
||||
header_[kBoundLoc] = id_counter_;
|
||||
data.insert(data.end(), header_.begin(), header_.end());
|
||||
data.insert(data.end(), entry_.begin(), entry_.end());
|
||||
data.insert(data.end(), exec_mode_.begin(), exec_mode_.end());
|
||||
data.insert(data.end(), debug_.begin(), debug_.end());
|
||||
data.insert(data.end(), decorate_.begin(), decorate_.end());
|
||||
data.insert(data.end(), global_.begin(), global_.end());
|
||||
data.insert(data.end(), function_.begin(), function_.end());
|
||||
return data;
|
||||
}
|
||||
/*!
|
||||
* \brief Create new label
|
||||
* \return The created new label
|
||||
*/
|
||||
Label NewLabel() {
|
||||
Label label;
|
||||
label.id = id_counter_++;
|
||||
return label;
|
||||
}
|
||||
/*!
|
||||
* \brief Start a new block with given label
|
||||
* \param label The label we use.
|
||||
*/
|
||||
void StartLabel(Label label) {
|
||||
MakeInst(spv::OpLabel, label);
|
||||
curr_label_ = label;
|
||||
}
|
||||
/*! \return The current label */
|
||||
Label CurrentLabel() const {
|
||||
return curr_label_;
|
||||
}
|
||||
/*!
|
||||
* \brief Add code to debug segment.
|
||||
* \param op The operator
|
||||
* \param args The instruction sequence
|
||||
* \tparams Args The positional arguments
|
||||
*/
|
||||
template<typename... Args>
|
||||
void Debug(spv::Op op, Args&& ...args) {
|
||||
ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&debug_);
|
||||
}
|
||||
/*!
|
||||
* \brief Add Execution mode to a function.
|
||||
* \param func The function value
|
||||
* \param args The instruction sequence
|
||||
* \tparams Args The positional arguments
|
||||
*/
|
||||
template<typename... Args>
|
||||
void ExecutionMode(Value func, Args&& ...args) {
|
||||
ib_.Begin(spv::OpExecutionMode).AddSeq(
|
||||
func, std::forward<Args>(args)...).Commit(&exec_mode_);
|
||||
}
|
||||
/*!
|
||||
* \brief Add code to decorate segment.
|
||||
* \param op The operator
|
||||
* \param args The instruction sequence
|
||||
* \tparams Args The positional arguments
|
||||
*/
|
||||
template<typename... Args>
|
||||
void Decorate(spv::Op op, Args&& ...args) {
|
||||
ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&decorate_);
|
||||
}
|
||||
/*!
|
||||
* \brief Add code to global segment.
|
||||
* \param op The operator
|
||||
* \param args The instruction sequence
|
||||
* \tparams Args The positional arguments
|
||||
*/
|
||||
template<typename... Args>
|
||||
void DeclareGlobal(spv::Op op, Args&& ...args) {
ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&global_);
|
||||
}
|
||||
/*!
|
||||
* \brief Make a new instruction and append it to end of function segment.
|
||||
*
|
||||
* \param op The operator
|
||||
* \param args The instruction sequence
|
||||
* \return The result SSA value.
|
||||
* \tparams Args The positional arguments
|
||||
*/
|
||||
template<typename... Args>
|
||||
Instr MakeInst(spv::Op op, Args&& ...args) {
|
||||
return ib_.Begin(op).AddSeq(std::forward<Args>(args)...).Commit(&function_);
|
||||
}
|
||||
/*!
|
||||
* \brief Make a new SSA value,
|
||||
*
|
||||
* \param op The operator.
|
||||
* \param out_type The result type.
|
||||
* \param args The instruction sequence
|
||||
* \return The result SSA value.
|
||||
* \tparams Args The positional arguments
|
||||
*/
|
||||
template<typename... Args>
|
||||
Value MakeValue(spv::Op op, const SType& out_type, Args&& ...args) {
|
||||
Value val = NewValue(out_type, kNormal);
|
||||
MakeInst(op, out_type, val, std::forward<Args>(args)...);
|
||||
return val;
|
||||
}
|
||||
/*!
|
||||
* \brief Make a phi value.
|
||||
*
|
||||
* \param out_type The output data type.
|
||||
* \param num_incoming number of incoming blocks.
|
||||
* \return The result Phi value.
|
||||
*/
|
||||
PhiValue MakePhi(const SType& out_type, uint32_t num_incoming);
|
||||
/*!
|
||||
* \brief Create a GLSL450 call
|
||||
*
|
||||
* \param ret_type The result type.
|
||||
* \param inst_id The instance id of the function.
|
||||
* \param args The arguments
|
||||
* \return The result value.
|
||||
*/
|
||||
Value CallGLSL450(const SType& ret_type,
|
||||
uint32_t inst_id,
|
||||
const std::vector<Value>& args);
|
||||
/*!
|
||||
* \brief Build vector by concatenating components
|
||||
*
|
||||
* \param vec The vector component
|
||||
* \tparams Args The positional arguments
|
||||
*/
|
||||
Value Concat(const std::vector<Value>& vec);
|
||||
/*!
|
||||
* \brief Get the spirv type for a given tvm data type.
|
||||
* \param dtype The data type.
|
||||
* \return The corresponding spirv type.
|
||||
*/
|
||||
SType GetSType(const tvm::Type& dtype);
|
||||
/*!
|
||||
* \brief Get the pointer type that points to value_type
|
||||
* \param value_type The value type to point to.
* \param storage_class The storage class
|
||||
* \return The corresponding spirv type.
|
||||
*/
|
||||
SType GetPointerType(const SType& value_type,
|
||||
spv::StorageClass storage_class);
|
||||
/*!
|
||||
* \brief Get a struct{ value_type[num_elems] } type.
|
||||
* \param value_type the content value type.
|
||||
* \param num_elems number of elements in array
|
||||
* num_elems = 0 means runtime array with BufferBlock Decoration
|
||||
*
|
||||
* \return The corresponding spirv type.
|
||||
*/
|
||||
SType GetStructArrayType(const SType& value_type,
|
||||
uint32_t num_elems);
|
||||
/*!
|
||||
* \brief Get a struct array access with a given index.
|
||||
* \param ptr_type The pointer type.
|
||||
* \param buffer The buffer ptr to struct array
|
||||
* \param index The array index.
|
||||
*/
|
||||
Value StructArrayAccess(const SType& ptr_type,
|
||||
Value buffer,
|
||||
Value index);
|
||||
/*!
|
||||
* \brief Create a cast that cast value to dst_type
|
||||
* \param dst_type The target type.
|
||||
* \param value the source value.
|
||||
* \return The result value
|
||||
*/
|
||||
Value Cast(const SType& dst_type, Value value);
|
||||
/*
|
||||
* \brief Create a const integer.
|
||||
* \param dtype The content data type.
|
||||
* \param value The data value.
|
||||
*/
|
||||
Value IntImm(const SType& dtype, int64_t value);
|
||||
/*
|
||||
* \brief Create a const unsigned integer.
|
||||
* \param dtype The content data type.
|
||||
* \param value The data value.
|
||||
*/
|
||||
Value UIntImm(const SType& dtype, uint64_t value);
|
||||
/*
|
||||
* \brief Create a const float.
|
||||
* \param dtype The content data type.
|
||||
* \param value The data value.
|
||||
*/
|
||||
Value FloatImm(const SType& dtype, double value);
|
||||
/*
|
||||
* \brief Declare buffer argument of function
|
||||
*
|
||||
* \param value_type The value type of the buffer content.
* \param descriptor_set The descriptor set we want to use.
* \param binding The binding location in the descriptor set.
* \return The created buffer argument value.
|
||||
*/
|
||||
Value BufferArgument(const SType& value_type,
|
||||
uint32_t descriptor_set,
|
||||
uint32_t binding);
|
||||
/*!
|
||||
* \brief Declare POD arguments through push constants.
|
||||
*
|
||||
* \note Only call this function once!
|
||||
* \param value_types The values in the push constant
|
||||
* \return reference to self.
|
||||
*/
|
||||
Value DeclarePushConstant(const std::vector<SType>& value_types);
|
||||
/*!
|
||||
* \brief Get the i-th push constant
* \param ptr_push_const The pointer to the push constant block.
* \param v_type The value type
* \param index The push constant index
* \return The value of the push constant
|
||||
*/
|
||||
Value GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index);
|
||||
/*!
|
||||
* \brief Declare a kernel function
|
||||
* \param name Name of the entry point.
|
||||
* \return The created function ID.
|
||||
*/
|
||||
Value DeclareKernelFunction(const std::string& name);
/*!
|
||||
* \brief Start function scope.
|
||||
* \param func function to be started.
|
||||
*/
|
||||
void StartFunction(const Value& func);
|
||||
/*!
|
||||
* \brief Set the local size of the function
|
||||
* \param func function of interest
|
||||
* \param local_size The local workgroup_size
|
||||
*/
|
||||
void SetLocalSize(const Value& func, uint32_t local_size[3]);
|
||||
/*
|
||||
* \brief Allocate space
|
||||
* \param value_type The content value type
|
||||
* \param num_elems Number of elements to allocate.
|
||||
* \param storage_class The storage class we want to store to.
|
||||
*/
|
||||
Value Allocate(const SType& value_type,
|
||||
uint32_t num_elems,
|
||||
spv::StorageClass storage_class);
|
||||
/*
|
||||
* \brief Get the i-th workgroup id.
|
||||
* \return The value representing the workgroup id.
|
||||
*/
|
||||
Value GetWorkgroupID(uint32_t dim_index);
|
||||
/*
|
||||
* \brief Get the i-th local id.
|
||||
* \return The value representing the local id.
|
||||
*/
|
||||
Value GetLocalID(uint32_t dim_index);
|
||||
// Expressions
|
||||
Value Add(Value a, Value b);
|
||||
Value Sub(Value a, Value b);
|
||||
Value Mul(Value a, Value b);
|
||||
Value Div(Value a, Value b);
|
||||
Value Mod(Value a, Value b);
|
||||
Value EQ(Value a, Value b);
|
||||
Value NE(Value a, Value b);
|
||||
Value LT(Value a, Value b);
|
||||
Value LE(Value a, Value b);
|
||||
Value GT(Value a, Value b);
|
||||
Value GE(Value a, Value b);
|
||||
Value Select(Value cond, Value a, Value b);
|
||||
|
||||
private:
|
||||
/*!
|
||||
* \brief Create new value
|
||||
* \return The created new label
|
||||
*/
|
||||
Value NewValue(const SType& stype, ValueKind flag) {
|
||||
Value val;
|
||||
val.id = id_counter_++;
|
||||
val.stype = stype;
|
||||
val.flag = flag;
|
||||
return val;
|
||||
}
|
||||
// get constant given value encoded in uint64_t
|
||||
Value GetConst_(const SType& dtype, const uint64_t* pvalue);
|
||||
// declare type
|
||||
SType DeclareType(const Type& dtype);
|
||||
/*! \brief internal instruction builder */
|
||||
InstrBuilder ib_;
|
||||
/*! \brief Current label */
|
||||
Label curr_label_;
|
||||
/*! \brief The current maximum id */
|
||||
uint32_t id_counter_{1};
|
||||
/*! \brief glsl 450 extension */
|
||||
Value ext_glsl450_;
|
||||
/*! \brief Cached frequently used types: bool, int32, uint32, fp32, void, void function */
SType t_bool_, t_int32_, t_uint32_, t_fp32_, t_void_, t_void_func_;
|
||||
/*! \brief quick cache for constant zero i32 */
Value const_i32_zero_;
|
||||
/*! \brief cache value for workgroup_id, local_id */
|
||||
Value workgroup_id_, local_id_;
|
||||
/*! \brief the push constant pointer value, id is 0 if not yet declared */
Value push_const_;
|
||||
/*! \brief map from type code to the type */
|
||||
std::unordered_map<uint32_t, SType> pod_type_tbl_;
|
||||
/*! \brief map from value to array type */
|
||||
std::map<std::pair<uint32_t, uint32_t>, SType> struct_array_type_tbl_;
|
||||
/*! \brief map from value to its pointer type */
|
||||
std::map<std::pair<uint32_t, spv::StorageClass>, SType> pointer_type_tbl_;
|
||||
/*! \brief map from constant int to its value */
|
||||
std::map<std::pair<uint32_t, uint64_t>, Value> const_tbl_;
|
||||
/*! \brief Header segment, include import */
|
||||
std::vector<uint32_t> header_;
|
||||
/*! \brief entry point segment */
std::vector<uint32_t> entry_;
|
||||
/*! \brief Execution mode segment */
std::vector<uint32_t> exec_mode_;
|
||||
/*! \brief Debug segment */
|
||||
std::vector<uint32_t> debug_;
|
||||
/*! \brief Annotation segment */
|
||||
std::vector<uint32_t> decorate_;
|
||||
/*! \brief Global segment: types, constants, global variables */
std::vector<uint32_t> global_;
|
||||
/*! \brief Function segment */
|
||||
std::vector<uint32_t> function_;
|
||||
};
|
||||
|
||||
} // namespace spirv
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_CODEGEN_SPIRV_IR_BUILDER_H_
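To show how the pieces above fit together, here is a hedged sketch (not part of this commit) of the intended call order when emitting an empty compute kernel. BuildEmptyKernel and the include path are made up for illustration; everything else uses only methods declared in this header.

// Hypothetical driver, not in this commit.
#include <vector>
#include "codegen/spirv/ir_builder.h"   // assumed include path

std::vector<uint32_t> BuildEmptyKernel() {
  using namespace tvm::codegen::spirv;
  IRBuilder builder;
  builder.InitHeader();                        // module header words
  builder.InitPreDefs();                       // predefined types and extension imports
  Value func = builder.DeclareKernelFunction("main");   // emits OpEntryPoint
  uint32_t local_size[3] = {64, 1, 1};
  builder.SetLocalSize(func, local_size);      // OpExecutionMode LocalSize
  builder.StartFunction(func);                 // OpFunction plus the first label
  builder.MakeInst(spv::OpReturn);
  builder.MakeInst(spv::OpFunctionEnd);
  return builder.Finalize();                   // concatenate segments into SPIR-V words
}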
@ -59,7 +59,7 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
|
|||
return VisitExpr(op->a);
|
||||
}
|
||||
bool VisitExpr_(const Let* op) final {
|
||||
return VisitExpr(op->body) && VisitExpr(op->value);
|
||||
return VisitExpr(op->body) || VisitExpr(op->value);
|
||||
}
|
||||
bool VisitExpr_(const Cast* op) final {
|
||||
return VisitExpr(op->value);
|
||||
|
@ -84,7 +84,7 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
|
|||
private:
|
||||
template<typename T>
|
||||
bool BinaryOp(const T* op) {
|
||||
return VisitExpr(op->a) && VisitExpr(op->b);
|
||||
return VisitExpr(op->a) || VisitExpr(op->b);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -903,20 +903,42 @@ class VectorAllocRewriter : public IRMutator {
|
|||
return stmt;
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
void UpdateTypeMap(const Variable* buffer, Type t) {
|
||||
auto& tvec = acc_map_[buffer];
|
||||
if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) {
|
||||
tvec.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
// Internal access map
|
||||
std::unordered_map<const Variable*,
|
||||
std::vector<Type> > acc_map_;
|
||||
std::unordered_map<const Variable*, std::vector<Type> > acc_map_;
|
||||
};
|
||||
|
||||
|
||||
LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
|
||||
std::shared_ptr<LoweredFuncNode> n =
|
||||
std::make_shared<LoweredFuncNode>(*f.operator->());
|
||||
VectorAllocRewriter rewriter;
|
||||
n->body = rewriter.Mutate(n->body);
|
||||
for (Var arg : f->args) {
|
||||
if (arg.type().is_handle()) {
|
||||
const auto& tvec = rewriter.acc_map_[arg.get()];
|
||||
if (tvec.size() == 1) {
|
||||
Expr dtype = make_const(tvec[0], 0);
|
||||
n->handle_data_type.Set(arg, dtype);
|
||||
} else {
|
||||
// always set data type to be non vectorized so
|
||||
// load/store can still work via scalarization
|
||||
if (tvec.size() != 0 && !n->handle_data_type.count(arg)) {
|
||||
Expr dtype = make_const(tvec[0].with_lanes(1), 0);
|
||||
n->handle_data_type.Set(arg, dtype);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return LoweredFunc(n);
|
||||
}
|
||||
|
||||
Stmt StorageRewrite(Stmt stmt) {
|
||||
stmt = StoragePlanRewriter().Rewrite(stmt, true);
|
||||
return VectorAllocRewriter().Mutate(stmt);
|
||||
|
|
|
@ -28,6 +28,7 @@ inline std::string DeviceName(int type) {
|
|||
case kDLCPU: return "cpu";
|
||||
case kDLGPU: return "gpu";
|
||||
case kDLOpenCL: return "opencl";
|
||||
case kDLVulkan: return "vulkan";
|
||||
case kDLMetal: return "metal";
|
||||
case kDLVPI: return "vpi";
|
||||
case kDLROCM: return "rocm";
|
||||
|
|
|
@ -119,6 +119,8 @@ bool RuntimeEnabled(const std::string& target) {
|
|||
f_name = "device_api.opengl";
|
||||
} else if (target == "mtl" || target == "metal") {
|
||||
f_name = "device_api.metal";
|
||||
} else if (target == "vulkan") {
|
||||
f_name = "device_api.vulkan";
|
||||
} else if (target == "stackvm") {
|
||||
f_name = "codegen.build_stackvm";
|
||||
} else if (target == "rpc") {
|
||||
|
|
|
@ -44,11 +44,12 @@ class ROCMDeviceAPI final : public DeviceAPI {
|
|||
value = 64;
|
||||
break;
|
||||
}
|
||||
case kComputeVersion:
|
||||
case kComputeVersion: {
|
||||
hipDeviceProp_t prop;
|
||||
ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
|
||||
*rv = prop.gcnArch;
|
||||
return;
|
||||
}
|
||||
}
|
||||
*rv = value;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,284 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file vulkan_common.h
|
||||
* \brief Vulkan common header
|
||||
*/
|
||||
#ifndef TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
|
||||
#define TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
|
||||
|
||||
#include <tvm/runtime/config.h>
|
||||
#include <tvm/runtime/c_runtime_api.h>
|
||||
#include <tvm/runtime/packed_func.h>
|
||||
#include <tvm/runtime/device_api.h>
|
||||
#include <dmlc/logging.h>
|
||||
|
||||
#if TVM_VULKAN_RUNTIME
|
||||
|
||||
#include <vulkan/vulkan.h>
|
||||
#include <mutex>
|
||||
#include <string>
#include <memory>
#include <vector>
|
||||
#include "../workspace_pool.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace runtime {
|
||||
namespace vulkan {
|
||||
|
||||
inline const char* VKGetErrorString(VkResult error) {
|
||||
switch (error) {
|
||||
case VK_SUCCESS: return "VK_SUCCESS";
|
||||
case VK_NOT_READY: return "VK_NOT_READY";
|
||||
case VK_TIMEOUT: return "VK_TIMEOUT";
|
||||
case VK_EVENT_SET: return "VK_EVENT_SET";
|
||||
case VK_EVENT_RESET: return "VK_EVENT_RESET";
|
||||
case VK_INCOMPLETE: return "VK_INCOMPLETE";
|
||||
case VK_ERROR_OUT_OF_HOST_MEMORY: return "VK_ERROR_OUT_OF_HOST_MEMORY";
|
||||
case VK_ERROR_OUT_OF_DEVICE_MEMORY: return "VK_ERROR_OUT_OF_DEVICE_MEMORY";
|
||||
case VK_ERROR_INITIALIZATION_FAILED: return "VK_ERROR_INITIALIZATION_FAILED";
|
||||
case VK_ERROR_DEVICE_LOST: return "VK_ERROR_DEVICE_LOST";
|
||||
case VK_ERROR_MEMORY_MAP_FAILED: return "VK_ERROR_MEMORY_MAP_FAILED";
|
||||
case VK_ERROR_LAYER_NOT_PRESENT: return "VK_ERROR_LAYER_NOT_PRESENT";
|
||||
case VK_ERROR_EXTENSION_NOT_PRESENT: return "VK_ERROR_EXTENSION_NOT_PRESENT";
|
||||
case VK_ERROR_FEATURE_NOT_PRESENT: return "VK_ERROR_FEATURE_NOT_PRESENT";
|
||||
case VK_ERROR_INCOMPATIBLE_DRIVER: return "VK_ERROR_INCOMPATIBLE_DRIVER";
|
||||
case VK_ERROR_TOO_MANY_OBJECTS: return "VK_ERROR_TOO_MANY_OBJECTS";
|
||||
case VK_ERROR_FORMAT_NOT_SUPPORTED: return "VK_ERROR_FORMAT_NOT_SUPPORTED";
|
||||
case VK_ERROR_FRAGMENTED_POOL: return "VK_ERROR_FRAGMENTED_POOL";
|
||||
default: return "Unknown Vulkan error code";
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Protected Vulkan call
|
||||
* \param func Expression to call.
|
||||
*/
|
||||
#define VULKAN_CHECK_ERROR(__e) \
|
||||
{ \
|
||||
CHECK(__e == VK_SUCCESS) \
|
||||
<< "Vulan Error, code=" << __e << ": " << vulkan::VKGetErrorString(__e); \
|
||||
}
|
||||
|
||||
#define VULKAN_CALL(func) \
|
||||
{ \
|
||||
VkResult __e = (func); \
|
||||
VULKAN_CHECK_ERROR(__e); \
|
||||
}
|
||||
|
||||
/*! \brief Auxiliary context structure for vulkan */
|
||||
struct VulkanContext {
|
||||
// physical device
VkPhysicalDevice phy_device{nullptr};
// Physical device properties
VkPhysicalDeviceProperties phy_device_prop;
|
||||
// Memory type index for staging.
|
||||
uint32_t staging_mtype_index{0};
|
||||
// whether staging is coherent
|
||||
bool coherent_staging{false};
|
||||
// Memory type index for compute
|
||||
uint32_t compute_mtype_index{0};
|
||||
// The logical device
|
||||
VkDevice device{nullptr};
|
||||
// command queue
|
||||
VkQueue queue{nullptr};
|
||||
// queue family index
uint32_t queue_family_index{0};
// Queue family properties.
VkQueueFamilyProperties queue_prop;
|
||||
};
|
||||
|
||||
/*! \brief The buffer object */
|
||||
struct VulkanBuffer {
|
||||
/*! \brief underlying buffer */
|
||||
VkBuffer buffer{nullptr};
|
||||
/*! \brief underlying device memory */
VkDeviceMemory memory{nullptr};
|
||||
};
|
||||
|
||||
/*! \brief Buffer only used for staging */
struct VulkanStagingBuffer {
|
||||
/*! \brief the corresponding device */
|
||||
VkDevice device{nullptr};
|
||||
/*! \brief underlying buffer */
|
||||
VkBuffer buffer{nullptr};
|
||||
/*! \brief underlying device memory */
VkDeviceMemory memory{nullptr};
|
||||
/*! \brief host address */
|
||||
void* host_addr{nullptr};
|
||||
/*! \brief size of the memory */
|
||||
size_t size{0};
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Process global Vulkan workspace.
|
||||
*/
|
||||
class VulkanWorkspace final : public DeviceAPI {
|
||||
public:
|
||||
// global mutex
|
||||
std::mutex mu;
|
||||
// whether the workspace is initialized.
bool initialized_{false};
|
||||
// vulkan instance
|
||||
VkInstance instance_{nullptr};
|
||||
// The per-device contexts, mapped one-to-one to device ids
std::vector<VulkanContext> context_;
|
||||
// Destructor
|
||||
~VulkanWorkspace();
|
||||
// Initialize the workspace; does nothing if it has already been initialized.
void Init();
|
||||
// override device API
|
||||
void SetDevice(TVMContext ctx) final;
|
||||
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
|
||||
void* AllocDataSpace(TVMContext ctx,
|
||||
size_t nbytes,
|
||||
size_t alignment,
|
||||
TVMType type_hint) final;
|
||||
void FreeDataSpace(TVMContext ctx, void* ptr) final;
|
||||
void CopyDataFromTo(const void* from,
|
||||
size_t from_offset,
void* to,
size_t to_offset,
|
||||
size_t size,
|
||||
TVMContext ctx_from,
|
||||
TVMContext ctx_to,
|
||||
TVMStreamHandle stream) final;
|
||||
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
|
||||
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
|
||||
void FreeWorkspace(TVMContext ctx, void* data) final;
|
||||
// get the global workspace
|
||||
static const std::shared_ptr<VulkanWorkspace>& Global();
|
||||
};
|
||||
|
||||
/*! \brief Helper command buffer resource */
|
||||
struct VulkanCommandBuffer {
|
||||
/*! \brief fence to signal the resource is ready to use */
|
||||
VkFence fence{nullptr};
|
||||
/*! \brief The internal command buffer */
|
||||
VkCommandBuffer cmd_buffer{nullptr};
|
||||
/*! \brief Descriptor set used to bind arguments */
|
||||
VkDescriptorSet descriptor_set{nullptr};
|
||||
/*! \brief Internal utilities for write command */
|
||||
VkWriteDescriptorSet write_descriptor_set;
|
||||
|
||||
VulkanCommandBuffer() {
|
||||
write_descriptor_set.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
|
||||
write_descriptor_set.pNext = nullptr;
|
||||
write_descriptor_set.dstSet = nullptr;
|
||||
write_descriptor_set.dstBinding = 0;
|
||||
write_descriptor_set.dstArrayElement = 0;
|
||||
write_descriptor_set.descriptorCount = 1;
|
||||
write_descriptor_set.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
|
||||
write_descriptor_set.pImageInfo = nullptr;
|
||||
write_descriptor_set.pBufferInfo = nullptr;
|
||||
write_descriptor_set.pTexelBufferView = nullptr;
|
||||
}
|
||||
};
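For reference, a hypothetical snippet (not in this commit) showing how the write_descriptor_set template above could be used to bind one storage buffer argument; BindArgument is an invented name, and only standard Vulkan calls are used.

// Hypothetical example, not in this commit.
void BindArgument(tvm::runtime::vulkan::VulkanCommandBuffer* cmd,
                  VkDevice device, VkBuffer buffer, uint32_t binding) {
  VkDescriptorBufferInfo binfo = {};
  binfo.buffer = buffer;
  binfo.offset = 0;
  binfo.range = VK_WHOLE_SIZE;
  // Copy the pre-filled template, then point it at the target set and binding.
  VkWriteDescriptorSet write = cmd->write_descriptor_set;
  write.dstSet = cmd->descriptor_set;
  write.dstBinding = binding;
  write.pBufferInfo = &binfo;
  vkUpdateDescriptorSets(device, 1, &write, 0, nullptr);
}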
|
||||
|
||||
/*!
|
||||
* \brief Command pool backed by a fixed size ring buffer.
|
||||
*
|
||||
* Vulkan requires us not to reuse a command buffer until
* all its corresponding jobs have finished.
*
* This class facilitates automatic management
* of the command buffers. A fence is created
|
||||
* for each launch of command buffer jobs
|
||||
* and when we try to reuse the same entry
|
||||
* in the ring, we need to make sure that
|
||||
* the previous pending job already finishes.
|
||||
*
|
||||
*/
|
||||
class VulkanCommandPool {
|
||||
public:
|
||||
/*! \brief Maximum number of pending jobs in the pool */
|
||||
static constexpr const int kMaxPending = 4;
|
||||
/*! \brief Maximum number of buffer arguments per job */
static constexpr const int kMaxNumArgs = 16;
|
||||
/*!
|
||||
* \brief constructor
|
||||
* \param vctx The corresponding vulkan context.
|
||||
*/
|
||||
explicit VulkanCommandPool(const VulkanContext& vctx);
|
||||
/*! \brief destructor */
|
||||
~VulkanCommandPool();
|
||||
/*!
|
||||
* \brief Allocate a new command buffer entry
|
||||
*
|
||||
* The caller must only submit the entry once
|
||||
* with the given fence in the entry,
|
||||
* before calling next Alloc.
|
||||
*
|
||||
* This function may block to wait for a
|
||||
* previously unfinished command when
|
||||
* there is more than kMaxPending jobs.
|
||||
*
|
||||
* \returns The allocated entry.
|
||||
*/
|
||||
VulkanCommandBuffer* Alloc();
|
||||
|
||||
/*!
|
||||
* \brief Allocate a new command buffer entry
|
||||
* \param dlayout the descriptor layout.
|
||||
*
|
||||
* \returns The allocated entry.
|
||||
*/
|
||||
VulkanCommandBuffer* Alloc(const VkDescriptorSetLayout* dlayout);
|
||||
|
||||
private:
|
||||
/*! \brief Local ring buffer */
|
||||
std::vector<VulkanCommandBuffer> ring_;
|
||||
/*! \brief clock pointer */
|
||||
size_t clock_ptr_{0};
|
||||
/*! \brief the corresponding device*/
|
||||
VkDevice device_{nullptr};
|
||||
/*! \brief internal command buffer pool */
|
||||
VkCommandPool cmd_pool_{nullptr};
|
||||
/*! \brief Descriptor pool */
|
||||
VkDescriptorPool descriptor_pool_{nullptr};
|
||||
};
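The submit pattern the class comment above describes might look like the following hedged sketch (not part of this commit); SubmitOnce is an invented helper, and the command recording step is elided.

// Hypothetical example, not in this commit.
void SubmitOnce(const tvm::runtime::vulkan::VulkanContext& vctx,
                tvm::runtime::vulkan::VulkanCommandPool* pool) {
  using tvm::runtime::vulkan::VulkanCommandBuffer;
  VulkanCommandBuffer* cmd = pool->Alloc();   // may block on the slot's previous fence
  VkCommandBufferBeginInfo begin = {};
  begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
  begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
  VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &begin));
  // ... record copy/dispatch commands here ...
  VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
  VkSubmitInfo submit = {};
  submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
  submit.commandBufferCount = 1;
  submit.pCommandBuffers = &cmd->cmd_buffer;
  // Submit exactly once with the entry's fence, as Alloc() requires.
  VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &submit, cmd->fence));
}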
|
||||
|
||||
/*! \brief Thread local workspace */
|
||||
class VulkanThreadEntry {
|
||||
public:
|
||||
/*! \brief The current context */
|
||||
TVMContext context;
|
||||
/*! \brief workspace pool */
|
||||
WorkspacePool pool;
|
||||
/*! \brief The staging buffers */
|
||||
std::vector<VulkanStagingBuffer> staging_buffer_;
|
||||
/*!
|
||||
* \brief Get the command pool of the corresponding device.
* \param device_id The device id
* \return The corresponding command pool.
|
||||
*/
|
||||
VulkanCommandPool* CommandPool(int device_id);
|
||||
/*!
|
||||
* \brief Get the staging buffer.
* \param device_id The device id
* \param size The minimum size the staging buffer must have
* \return The corresponding staging buffer.
|
||||
*/
|
||||
VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);
|
||||
|
||||
// constructor
|
||||
VulkanThreadEntry()
|
||||
: pool(static_cast<DLDeviceType>(kDLVulkan), VulkanWorkspace::Global()) {
|
||||
context.device_id = 0;
|
||||
context.device_type = static_cast<DLDeviceType>(kDLVulkan);
|
||||
}
|
||||
~VulkanThreadEntry();
|
||||
// get the global workspace
|
||||
static VulkanThreadEntry* ThreadLocal();
|
||||
|
||||
private:
|
||||
/*! \brief the command pools */
|
||||
std::vector<std::unique_ptr<VulkanCommandPool> > pool_;
|
||||
};
|
||||
|
||||
// inline implementation
|
||||
|
||||
|
||||
} // namespace vulkan
|
||||
} // namespace runtime
|
||||
} // namespace tvm
|
||||
#endif // TVM_VULKAN_RUNTIME
|
||||
#endif // TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
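Finally, a small hedged sketch (not in this commit) of a round-trip allocation through the workspace singleton declared above; RoundTripAllocation and the include path are illustrative only.

// Hypothetical example, not in this commit.
#include "runtime/vulkan/vulkan_common.h"   // assumed include path

void RoundTripAllocation() {
  using tvm::runtime::vulkan::VulkanWorkspace;
  TVMContext ctx;
  ctx.device_type = static_cast<DLDeviceType>(kDLVulkan);
  ctx.device_id = 0;
  TVMType f32{kDLFloat, 32, 1};               // fp32 scalar type hint
  void* buf = VulkanWorkspace::Global()->AllocDataSpace(ctx, 4096, 64, f32);
  VulkanWorkspace::Global()->FreeDataSpace(ctx, buf);
}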
|
|
@ -0,0 +1,681 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file vulkan_device_api.cc
|
||||
*/
|
||||
#include "./vulkan_common.h"
|
||||
|
||||
#if TVM_VULKAN_RUNTIME
|
||||
|
||||
#include <tvm/runtime/registry.h>
|
||||
#include <dmlc/thread_local.h>
|
||||
#include <cstring>
|
||||
|
||||
|
||||
namespace tvm {
|
||||
namespace runtime {
|
||||
namespace vulkan {
|
||||
|
||||
VulkanWorkspace::~VulkanWorkspace() {
|
||||
for (VulkanContext& ctx : context_) {
|
||||
vkDestroyDevice(ctx.device, nullptr);
|
||||
}
|
||||
if (instance_ != nullptr) {
|
||||
vkDestroyInstance(instance_, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
const std::shared_ptr<VulkanWorkspace>& VulkanWorkspace::Global() {
|
||||
static std::shared_ptr<VulkanWorkspace> inst = std::make_shared<VulkanWorkspace>();
|
||||
return inst;
|
||||
}
|
||||
|
||||
void VulkanWorkspace::SetDevice(TVMContext ctx) {
|
||||
VulkanThreadEntry::ThreadLocal()->context.device_id = ctx.device_id;
|
||||
}
|
||||
|
||||
void VulkanWorkspace::GetAttr(
|
||||
TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
|
||||
this->Init();
|
||||
size_t index = static_cast<size_t>(ctx.device_id);
|
||||
if (kind == kExist) {
|
||||
*rv = static_cast<int>(index < context_.size());
return;
|
||||
}
|
||||
CHECK_LT(index, context_.size())
|
||||
<< "Invalid device id " << index;
|
||||
switch (kind) {
|
||||
case kMaxThreadsPerBlock: {
|
||||
VkPhysicalDeviceProperties phy_prop;
|
||||
vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop);
|
||||
int64_t value = phy_prop.limits.maxComputeWorkGroupSize[0];
|
||||
*rv = value;
|
||||
break;
|
||||
}
|
||||
case kWarpSize: {
|
||||
*rv = 1;
|
||||
break;
|
||||
}
|
||||
case kComputeVersion: {
|
||||
VkPhysicalDeviceProperties phy_prop;
|
||||
vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop);
|
||||
int64_t value = phy_prop.apiVersion;
|
||||
std::ostringstream os;
|
||||
os << VK_VERSION_MAJOR(value)
|
||||
<< "." << VK_VERSION_MINOR(value)
|
||||
<< "." << VK_VERSION_PATCH(value);
|
||||
*rv = os.str();
|
||||
break;
|
||||
}
|
||||
case kExist: break;
|
||||
}
|
||||
}
|
||||
|
||||
void* VulkanWorkspace::AllocDataSpace(
|
||||
TVMContext ctx, size_t size, size_t alignment, TVMType type_hint) {
|
||||
this->Init();
|
||||
|
||||
VulkanContext& vctx = context_[ctx.device_id];
|
||||
|
||||
VkBufferCreateInfo info;
|
||||
info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
|
||||
info.pNext = nullptr;
|
||||
info.flags = 0;
|
||||
info.size = size;
|
||||
info.queueFamilyIndexCount = 1;
|
||||
info.pQueueFamilyIndices = &(vctx.queue_family_index);
|
||||
info.usage =
|
||||
VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
|
||||
VK_BUFFER_USAGE_TRANSFER_DST_BIT |
|
||||
VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
|
||||
// create buffer
|
||||
VkBuffer buffer;
|
||||
VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
|
||||
// bind to memory
|
||||
VkMemoryAllocateInfo minfo;
|
||||
minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
|
||||
minfo.pNext = nullptr;
|
||||
minfo.allocationSize = size;
|
||||
minfo.memoryTypeIndex = vctx.compute_mtype_index;
|
||||
VkDeviceMemory memory;
|
||||
VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
|
||||
VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
|
||||
|
||||
VulkanBuffer* pbuf = new VulkanBuffer();
|
||||
pbuf->memory = memory;
|
||||
pbuf->buffer = buffer;
|
||||
return pbuf;
|
||||
}
|
||||
|
||||
void VulkanWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
|
||||
VulkanContext& vctx = context_[ctx.device_id];
|
||||
VulkanBuffer* pbuf = static_cast<VulkanBuffer*>(ptr);
|
||||
vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr);
|
||||
vkFreeMemory(vctx.device, pbuf->memory, nullptr);
|
||||
delete pbuf;
|
||||
}
|
||||
|
||||
void VulkanWorkspace::CopyDataFromTo(const void* from,
|
||||
size_t from_offset,
|
||||
void* to,
|
||||
size_t to_offset,
|
||||
size_t size,
|
||||
TVMContext ctx_from,
|
||||
TVMContext ctx_to,
|
||||
TVMStreamHandle stream) {
|
||||
this->Init();
|
||||
CHECK(stream == nullptr);
|
||||
TVMContext ctx = ctx_from;
|
||||
if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
|
||||
VulkanThreadEntry* tls = VulkanThreadEntry::ThreadLocal();
|
||||
VulkanCommandBuffer* cmd = tls->CommandPool(ctx.device_id)->Alloc();
|
||||
|
||||
VkCommandBufferBeginInfo cb_begin;
|
||||
cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
|
||||
cb_begin.pNext = nullptr;
|
||||
cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
|
||||
cb_begin.pInheritanceInfo = 0;
|
||||
|
||||
VkSubmitInfo cb_submit;
|
||||
cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
|
||||
cb_submit.pNext = nullptr;
|
||||
cb_submit.waitSemaphoreCount = 0;
|
||||
cb_submit.pWaitSemaphores = nullptr;
|
||||
cb_submit.pWaitDstStageMask = 0;
|
||||
cb_submit.commandBufferCount = 1;
|
||||
cb_submit.pCommandBuffers = &(cmd->cmd_buffer);
|
||||
cb_submit.signalSemaphoreCount = 0;
|
||||
cb_submit.pSignalSemaphores = nullptr;
|
||||
|
||||
|
||||
int from_dev_type = static_cast<int>(ctx_from.device_type);
|
||||
int to_dev_type = static_cast<int>(ctx_to.device_type);
|
||||
|
||||
if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) {
|
||||
CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
|
||||
<< "Vulkan disallow cross device copy.";
|
||||
const VulkanContext& vctx = context_[ctx_from.device_id];
|
||||
const VulkanBuffer* from_buf = static_cast<const VulkanBuffer*>(from);
|
||||
VulkanBuffer* to_buf = static_cast<VulkanBuffer*>(to);
|
||||
// The assumption is that subsequent ops only perform compute/transfer
// 0: begin
|
||||
VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
|
||||
// 1: copy
|
||||
VkBufferCopy copy_info;
|
||||
copy_info.srcOffset = from_offset;
|
||||
copy_info.dstOffset = to_offset;
|
||||
copy_info.size = size;
|
||||
vkCmdCopyBuffer(cmd->cmd_buffer, from_buf->buffer, to_buf->buffer, 1, ©_info);
|
||||
// 2: barrier(transfer-> compute|transfer)
|
||||
VkMemoryBarrier barrier_info;
|
||||
barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
|
||||
barrier_info.pNext = nullptr;
|
||||
barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
|
||||
barrier_info.dstAccessMask =
|
||||
(VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
|
||||
VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
|
||||
vkCmdPipelineBarrier(
|
||||
cmd->cmd_buffer,
|
||||
VK_PIPELINE_STAGE_TRANSFER_BIT,
|
||||
VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
|
||||
0, 1, &barrier_info, 0, nullptr, 0, nullptr);
|
||||
// 3: end
|
||||
VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
|
||||
// 4: submit with cmd->fence
|
||||
VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
|
||||
} else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) {
|
||||
const VulkanContext& vctx = context_[ctx_from.device_id];
|
||||
const VulkanBuffer* from_buf = static_cast<const VulkanBuffer*>(from);
|
||||
VulkanStagingBuffer* temp = tls->StagingBuffer(ctx_from.device_id, size);
|
||||
// 0: begin
|
||||
VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
|
||||
// 1: copy
|
||||
VkBufferCopy copy_info;
|
||||
copy_info.srcOffset = from_offset;
|
||||
copy_info.dstOffset = 0;
|
||||
copy_info.size = size;
|
||||
vkCmdCopyBuffer(cmd->cmd_buffer,
|
||||
from_buf->buffer,
|
||||
temp->buffer,
|
||||
1, ©_info);
|
||||
// 2: end
|
||||
VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
|
||||
// 4: submit with cmd->fence
|
||||
VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
|
||||
// Block until done, to make sure temp can be reused later.
|
||||
VULKAN_CALL(vkQueueWaitIdle(vctx.queue));
|
||||
// host side invalidation if access is not coherent.
|
||||
// so writes from the GPU are visible to the CPU
if (!vctx.coherent_staging) {
|
||||
VkMappedMemoryRange mrange;
|
||||
mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
|
||||
mrange.pNext = nullptr;
|
||||
mrange.memory = temp->memory;
|
||||
mrange.offset = 0;
|
||||
mrange.size = size;
|
||||
VULKAN_CALL(vkInvalidateMappedMemoryRanges(
|
||||
vctx.device, 1, &mrange));
|
||||
}
|
||||
memcpy(static_cast<char*>(to) + to_offset,
|
||||
static_cast<char*>(temp->host_addr),
|
||||
size);
|
||||
} else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) {
|
||||
const VulkanContext& vctx = context_[ctx_to.device_id];
|
||||
const VulkanBuffer* to_buf = static_cast<const VulkanBuffer*>(to);
|
||||
VulkanStagingBuffer* temp = tls->StagingBuffer(ctx_to.device_id, size);
|
||||
memcpy(temp->host_addr,
|
||||
static_cast<const char*>(from) + from_offset,
|
||||
size);
|
||||
// host side flush if access is not coherent.
|
||||
// so writes from the CPU are visible to the GPU
if (!vctx.coherent_staging) {
|
||||
VkMappedMemoryRange mrange;
|
||||
mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
|
||||
mrange.pNext = nullptr;
|
||||
mrange.memory = temp->memory;
|
||||
mrange.offset = 0;
|
||||
mrange.size = size;
|
||||
VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange));
|
||||
}
|
||||
VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
|
||||
// 0: barrier(host->transfer)
|
||||
VkMemoryBarrier barrier_info;
|
||||
barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
|
||||
barrier_info.pNext = nullptr;
|
||||
barrier_info.srcAccessMask = 0;
|
||||
barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
|
||||
vkCmdPipelineBarrier(cmd->cmd_buffer,
|
||||
VK_PIPELINE_STAGE_HOST_BIT,
|
||||
VK_PIPELINE_STAGE_TRANSFER_BIT,
|
||||
0, 1, &barrier_info,
|
||||
0, nullptr, 0, nullptr);
|
||||
// 1: copy
|
||||
VkBufferCopy copy_info;
|
||||
copy_info.srcOffset = 0;
|
||||
copy_info.dstOffset = to_offset;
|
||||
copy_info.size = size;
|
||||
vkCmdCopyBuffer(cmd->cmd_buffer,
|
||||
temp->buffer,
|
||||
to_buf->buffer,
|
||||
1, ©_info);
|
||||
// 2: end
|
||||
VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
|
||||
// 4: submit with cmd->fence
|
||||
VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
|
||||
// wait until copy finishes, so we can reuse temp next time.
|
||||
VULKAN_CALL(vkQueueWaitIdle(vctx.queue));
|
||||
} else {
|
||||
LOG(FATAL) << "Expect copy from/to Metal or between Metal"
|
||||
<< ", from=" << from_dev_type
|
||||
<< ", to=" << to_dev_type;
|
||||
}
|
||||
}
|
||||
|
||||
void VulkanWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
|
||||
CHECK(stream == nullptr);
|
||||
VulkanContext& vctx = context_[ctx.device_id];
|
||||
VULKAN_CALL(vkQueueWaitIdle(vctx.queue));
|
||||
}
|
||||
|
||||
void* VulkanWorkspace::AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) {
|
||||
return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
|
||||
}
|
||||
|
||||
void VulkanWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
|
||||
VulkanThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
|
||||
}
|
||||
|
||||
// VulkanCommandPool
VulkanCommandPool::VulkanCommandPool(const VulkanContext& vctx) {
  ring_.resize(kMaxPending, VulkanCommandBuffer());
  device_ = vctx.device;
  {
    // create command pool
    VkCommandPoolCreateInfo cmd_pool_cinfo;
    cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
    cmd_pool_cinfo.pNext = nullptr;
    cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
    cmd_pool_cinfo.queueFamilyIndex = vctx.queue_family_index;
    VULKAN_CALL(vkCreateCommandPool(device_, &cmd_pool_cinfo, nullptr, &cmd_pool_));
  }
  {
    // create descriptor pool
    VkDescriptorPoolSize pool_size;
    pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
    pool_size.descriptorCount = kMaxPending * kMaxNumArgs;
    VkDescriptorPoolCreateInfo descrip_pool_cinfo;
    descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
    descrip_pool_cinfo.pNext = nullptr;
    descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
    descrip_pool_cinfo.maxSets = kMaxPending + 2;
    descrip_pool_cinfo.poolSizeCount = 1;
    descrip_pool_cinfo.pPoolSizes = &pool_size;
    VULKAN_CALL(vkCreateDescriptorPool(
        device_, &descrip_pool_cinfo, nullptr, &descriptor_pool_));
  }
  VkCommandBufferAllocateInfo buffer_alloc_info;
  buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
  buffer_alloc_info.pNext = nullptr;
  buffer_alloc_info.commandPool = cmd_pool_;
  buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
  buffer_alloc_info.commandBufferCount = 1;

  VkFenceCreateInfo fence_cinfo;
  fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
  fence_cinfo.pNext = nullptr;
  fence_cinfo.flags = VK_FENCE_CREATE_SIGNALED_BIT;

  for (size_t i = 0; i < ring_.size(); ++i) {
    VULKAN_CALL(vkAllocateCommandBuffers(
        device_, &buffer_alloc_info, &(ring_[i].cmd_buffer)));
    VULKAN_CALL(vkCreateFence(
        device_, &fence_cinfo, nullptr, &(ring_[i].fence)));
  }
}
VulkanCommandPool::~VulkanCommandPool() {
  // Wait for the device to be idle so we know the buffers can be recycled.
  VULKAN_CALL(vkDeviceWaitIdle(device_));
  // start recycling.
  for (size_t i = 0; i < ring_.size(); ++i) {
    if (ring_[i].cmd_buffer != nullptr) {
      vkFreeCommandBuffers(device_, cmd_pool_, 1, &(ring_[i].cmd_buffer));
      ring_[i].cmd_buffer = nullptr;
    }
    if (ring_[i].fence != nullptr) {
      vkDestroyFence(device_, ring_[i].fence, nullptr);
    }
  }
  // delete the command pool and descriptor pool
  vkDestroyCommandPool(device_, cmd_pool_, nullptr);
  vkDestroyDescriptorPool(device_, descriptor_pool_, nullptr);
}
VulkanCommandBuffer* VulkanCommandPool::Alloc() {
  return Alloc(nullptr);
}

VulkanCommandBuffer* VulkanCommandPool::Alloc(
    const VkDescriptorSetLayout* dlayout) {
  // always allocate resources in a round-robin manner
  VulkanCommandBuffer* e = &(ring_[clock_ptr_]);
  clock_ptr_ = (clock_ptr_ + 1) % ring_.size();
  // Wait until the previous usage of the command buffer is finished.
  uint64_t timeout = 1UL << 30UL;
  VkResult res;
  res = vkWaitForFences(device_, 1, &(e->fence), 0, timeout);
  while (res == VK_TIMEOUT) {
    res = vkWaitForFences(device_, 1, &(e->fence), 0, timeout);
  }
  VULKAN_CHECK_ERROR(res);
  vkResetFences(device_, 1, (&e->fence));
  if (e->descriptor_set != nullptr) {
    VULKAN_CALL(vkFreeDescriptorSets(
        device_, descriptor_pool_, 1, &(e->descriptor_set)));
    e->descriptor_set = nullptr;
  }
  if (dlayout != nullptr) {
    VkDescriptorSetAllocateInfo alloc_info;
    alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
    alloc_info.pNext = nullptr;
    alloc_info.descriptorPool = descriptor_pool_;
    alloc_info.descriptorSetCount = 1;
    alloc_info.pSetLayouts = dlayout;
    VULKAN_CALL(vkAllocateDescriptorSets(
        device_, &alloc_info, &(e->descriptor_set)));
  }
  return e;
}
// VulkanThreadEntry
typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;

VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() {
  return VulkanThreadStore::Get();
}

VulkanCommandPool* VulkanThreadEntry::CommandPool(int device_id) {
  while (pool_.size() <= static_cast<size_t>(device_id)) {
    pool_.emplace_back(std::unique_ptr<VulkanCommandPool>());
  }
  if (pool_[device_id] == nullptr) {
    const VulkanContext& vctx =
        VulkanWorkspace::Global()->context_[device_id];
    pool_[device_id].reset(new VulkanCommandPool(vctx));
  }
  return pool_[device_id].get();
}
VulkanStagingBuffer*
VulkanThreadEntry::StagingBuffer(int device_id, size_t size) {
  if (staging_buffer_.size() <= static_cast<size_t>(device_id)) {
    staging_buffer_.resize(device_id + 1, VulkanStagingBuffer());
  }
  VulkanStagingBuffer& buf = staging_buffer_[device_id];

  if (buf.device != nullptr && buf.size < size) {
    // free the previous buffer
    if (buf.host_addr != nullptr) {
      vkUnmapMemory(buf.device, buf.memory);
    }
    if (buf.memory != nullptr) {
      vkFreeMemory(buf.device, buf.memory, nullptr);
    }
    if (buf.buffer != nullptr) {
      vkDestroyBuffer(buf.device, buf.buffer, nullptr);
    }
    buf.host_addr = nullptr;
    buf.memory = nullptr;
    buf.buffer = nullptr;
  }
  const VulkanContext& vctx =
      VulkanWorkspace::Global()->context_[device_id];

  if (buf.device == nullptr) {
    buf.device = vctx.device;
  }
  if (buf.memory == nullptr) {
    // allocate the staging buffer memory if necessary
    VkBufferCreateInfo info;
    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
    info.pNext = nullptr;
    info.flags = 0;
    info.size = size;
    info.queueFamilyIndexCount = 1;
    info.pQueueFamilyIndices = &(vctx.queue_family_index);
    info.usage =
        VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
        VK_BUFFER_USAGE_TRANSFER_DST_BIT;
    VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &(buf.buffer)));
    VkMemoryAllocateInfo minfo;
    minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
    minfo.pNext = nullptr;
    minfo.allocationSize = size;
    minfo.memoryTypeIndex = vctx.staging_mtype_index;
    VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &(buf.memory)));
    VULKAN_CALL(vkBindBufferMemory(vctx.device, (buf.buffer), buf.memory, 0));
    VULKAN_CALL(vkMapMemory(vctx.device, buf.memory, 0, size, 0, &(buf.host_addr)));
    buf.size = size;
  }
  memset(buf.host_addr, 0, size);
  return &buf;
}
VulkanThreadEntry::~VulkanThreadEntry() {
  // Because the thread entry refers to the Device API,
  // the command buffers will always be destroyed before
  // the instance and device get destroyed.
  // The destruction needs to be called manually
  // to ensure the destruction order.
  pool_.clear();
  for (VulkanStagingBuffer buf : staging_buffer_) {
    if (buf.host_addr != nullptr) {
      vkUnmapMemory(buf.device, buf.memory);
    }
    if (buf.memory != nullptr) {
      vkFreeMemory(buf.device, buf.memory, nullptr);
    }
    if (buf.buffer != nullptr) {
      vkDestroyBuffer(buf.device, buf.buffer, nullptr);
    }
  }
}
VkInstance CreateInstance() {
  VkApplicationInfo app_info;
  app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
  app_info.pNext = nullptr;
  app_info.pApplicationName = "TVM";
  app_info.applicationVersion = 0;
  app_info.pEngineName = "";
  app_info.engineVersion = 0;
  app_info.apiVersion = VK_MAKE_VERSION(1, 0, 65);

  VkInstanceCreateInfo inst_info;
  inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
  inst_info.pNext = nullptr;
  inst_info.flags = 0;
  inst_info.pApplicationInfo = &app_info;
  inst_info.enabledLayerCount = 0;
  inst_info.ppEnabledLayerNames = nullptr;
  inst_info.enabledExtensionCount = 0;
  inst_info.ppEnabledExtensionNames = nullptr;

  VkInstance inst;
  VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &inst));
  return inst;
}
// Find suitable mem_type_index for staging and compute.
void FindMemoryTypeIndex(VulkanContext* vctx) {
  // Find a suitable compute index.
  VkBuffer buffer;
  VkMemoryRequirements req_staging, req_compute;
  VkBufferCreateInfo info;
  info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
  info.pNext = nullptr;
  info.flags = 0;
  info.size = 1024;
  info.queueFamilyIndexCount = 1;
  info.pQueueFamilyIndices = &(vctx->queue_family_index);

  // get the staging requirement
  info.usage =
      VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
      VK_BUFFER_USAGE_TRANSFER_DST_BIT;
  VULKAN_CALL(vkCreateBuffer(vctx->device, &info, nullptr, &buffer));
  vkGetBufferMemoryRequirements(vctx->device, buffer, &req_staging);
  vkDestroyBuffer(vctx->device, buffer, nullptr);
  // get the compute requirement
  info.usage =
      VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
      VK_BUFFER_USAGE_TRANSFER_DST_BIT |
      VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
  VULKAN_CALL(vkCreateBuffer(vctx->device, &info, nullptr, &buffer));
  vkGetBufferMemoryRequirements(vctx->device, buffer, &req_compute);
  vkDestroyBuffer(vctx->device, buffer, nullptr);

  // Query the physical device memory properties and
  // find a memory type that is host visible; it need not be coherent.
  int win_rank = -1;
  VkPhysicalDeviceMemoryProperties prop;
  vkGetPhysicalDeviceMemoryProperties(vctx->phy_device, &prop);

  for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
    VkMemoryType ty = prop.memoryTypes[k];
    size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
    // host visible
    if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
    // match the copy requirement
    if (!(req_staging.memoryTypeBits & (1 << k))) continue;
    if (heap_size < 1024) continue;
    int rank = 0;
    rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
    if (rank > win_rank) {
      win_rank = rank;
      vctx->staging_mtype_index = k;
      vctx->coherent_staging =
          ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
    }
  }
  CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";

  win_rank = -1;
  for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
    VkMemoryType ty = prop.memoryTypes[k];
    size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
    // device local
    if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
    // match the copy requirement
    if (!(req_staging.memoryTypeBits & (1 << k))) continue;
    if (heap_size < 1024) continue;
    int rank = 0;
    // prefer memory that is not host visible
    rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
    if (rank > win_rank) {
      win_rank = rank;
      vctx->compute_mtype_index = k;
    }
  }
  CHECK_GE(win_rank, 0) << "Cannot find suitable compute memory on device.";
}
// Get all logical devices that support compute.
std::vector<VulkanContext> GetContext(VkInstance instance) {
  std::vector<VulkanContext> result;
  uint32_t phy_dev_count = 0;
  VULKAN_CALL(vkEnumeratePhysicalDevices(
      instance, &phy_dev_count, nullptr));
  std::vector<VkPhysicalDevice> all_phy_devs(phy_dev_count);
  VULKAN_CALL(vkEnumeratePhysicalDevices(
      instance, &phy_dev_count, dmlc::BeginPtr(all_phy_devs)));
  for (VkPhysicalDevice phy_dev : all_phy_devs) {
    uint32_t queue_prop_count = 0;
    vkGetPhysicalDeviceQueueFamilyProperties(
        phy_dev, &queue_prop_count, nullptr);
    std::vector<VkQueueFamilyProperties> queue_props(queue_prop_count);
    vkGetPhysicalDeviceQueueFamilyProperties(
        phy_dev, &queue_prop_count, dmlc::BeginPtr(queue_props));
    uint32_t queue_family_index = 0;
    std::vector<VkDeviceQueueCreateInfo> queue_create_info;

    for (uint32_t i = 0; i < queue_props.size(); i++) {
      // find queues that support compute
      if (VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) {
        float priority = 1.0f;

        VkDeviceQueueCreateInfo info;
        info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
        info.pNext = nullptr;
        info.flags = 0;
        info.queueFamilyIndex = i;
        info.queueCount = 1;
        info.pQueuePriorities = &priority;

        queue_create_info.push_back(info);
        // only use the first available queue for now
        if (queue_create_info.size() == 0) {
          queue_family_index = i;
        }
      }
    }
    if (queue_create_info.size() == 0) continue;

    VkDeviceCreateInfo device_create_info;
    device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
    device_create_info.pNext = nullptr;
    device_create_info.flags = 0;
    device_create_info.queueCreateInfoCount
        = static_cast<uint32_t>(queue_create_info.size());
    device_create_info.pQueueCreateInfos = queue_create_info.data();
    device_create_info.enabledLayerCount = 0;
    device_create_info.ppEnabledLayerNames = nullptr;
    device_create_info.enabledExtensionCount = 0;
    device_create_info.ppEnabledExtensionNames = nullptr;
    device_create_info.pEnabledFeatures = nullptr;

    VulkanContext ctx;
    // setup the context
    ctx.phy_device = phy_dev;
    vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop));
    VULKAN_CALL(vkCreateDevice(
        phy_dev, &device_create_info, nullptr, &(ctx.device)));
    vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue));
    ctx.queue_family_index = queue_family_index;
    // Find suitable memory types for staging and compute.
    FindMemoryTypeIndex(&ctx);
    result.push_back(ctx);
  }
  return result;
}
void VulkanWorkspace::Init() {
  if (initialized_) return;
  std::lock_guard<std::mutex> lock(this->mu);
  if (initialized_) return;
  initialized_ = true;
  instance_ = CreateInstance();
  context_ = GetContext(instance_);
  LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices..";
  for (size_t i = 0; i < context_.size(); ++i) {
    LOG(INFO) << "vulkan(" << i
              << ")=\'" << context_[i].phy_device_prop.deviceName
              << "\' phy_dev_id=" << context_[i].phy_device;
  }
}

bool InitVulkan(TVMArgs args, TVMRetValue* rv) {
  vulkan::VulkanWorkspace::Global()->Init();
  return true;
}

TVM_REGISTER_GLOBAL("device_api.vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    DeviceAPI* ptr = VulkanWorkspace::Global().get();
    *rv = static_cast<void*>(ptr);
  });

}  // namespace vulkan
}  // namespace runtime
}  // namespace tvm

#endif  // TVM_VULKAN_RUNTIME
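With the device API registered under `device_api.vulkan`, the runtime becomes reachable from Python once TVM is built with the new `USE_VULKAN` option. A minimal, hypothetical round-trip sketch (array size and names are made up) that simply exercises the staging-buffer copy path above:

```python
import numpy as np
import tvm

# Hedged sketch: copy an array to the Vulkan device and back.
# Assumes TVM was configured with USE_VULKAN=ON and a Vulkan device exists.
if tvm.module.enabled("vulkan"):
    ctx = tvm.vulkan(0)                           # backed by device_api.vulkan
    x = np.random.uniform(size=1024).astype("float32")
    xd = tvm.nd.array(x, ctx)                     # host -> device via the staging buffer
    np.testing.assert_allclose(xd.asnumpy(), x)   # device -> host
```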
@ -0,0 +1,424 @@
/*!
 *  Copyright (c) 2018 by Contributors
 * \file vulkan_module.cc
 */
#include "./vulkan_module.h"

#if TVM_VULKAN_RUNTIME

#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <array>
#include <string>
#include <mutex>
#include "./vulkan_common.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"


namespace tvm {
namespace runtime {

void VulkanShader::Save(dmlc::Stream* writer) const {
  writer->Write(flag);
  writer->Write(data);
}

bool VulkanShader::Load(dmlc::Stream* reader) {
  if (!reader->Read(&flag)) return false;
  if (!reader->Read(&data)) return false;
  return true;
}

// Multi-device enabled module.
class VulkanModuleNode final : public runtime::ModuleNode {
 public:
  // Pipeline cache states
  struct PipelineEntry {
    VkShaderModule shader{nullptr};
    VkPipelineLayout pipeline_layout{nullptr};
    VkDescriptorSetLayout descriptor_layout{nullptr};
    VkPipeline pipeline{nullptr};
  };
  // constructor
  explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
                            std::unordered_map<std::string, FunctionInfo> fmap,
                            std::string source)
      : smap_(smap), fmap_(fmap), source_(source) {
  }

  ~VulkanModuleNode() {
    // cleanup vulkan related caches.
    for (DeviceEntry& e : finfo_) {
      if (e.device == nullptr) continue;
      for (auto& kv : e.smap) {
        PipelineEntry& pe = kv.second;
        vkDestroyShaderModule(e.device, pe.shader, nullptr);
        vkDestroyDescriptorSetLayout(e.device, pe.descriptor_layout, nullptr);
        vkDestroyPipelineLayout(e.device, pe.pipeline_layout, nullptr);
        vkDestroyPipeline(e.device, pe.pipeline, nullptr);
      }
    }
  }
const char* type_key() const final {
|
||||
return "vulkan";
|
||||
}
|
||||
|
||||
PackedFunc GetFunction(
|
||||
const std::string& name,
|
||||
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
|
||||
|
||||
void SaveToFile(const std::string& file_name,
|
||||
const std::string& format) final {
|
||||
std::string fmt = GetFileFormat(file_name, format);
|
||||
CHECK_EQ(fmt, fmt_)
|
||||
<< "Can only save to customized format vulkan";
|
||||
std::string meta_file = GetMetaFilePath(file_name);
|
||||
SaveMetaDataToFile(meta_file, fmap_);
|
||||
std::string data_bin;
|
||||
dmlc::MemoryStringStream fs(&data_bin);
|
||||
dmlc::Stream* stream = &fs;
|
||||
uint32_t magic = kVulkanModuleMagic;
|
||||
stream->Write(magic);
|
||||
stream->Write(smap_);
|
||||
SaveBinaryToFile(file_name, data_bin);
|
||||
}
|
||||
|
||||
void SaveToBinary(dmlc::Stream* stream) final {
|
||||
stream->Write(fmt_);
|
||||
stream->Write(fmap_);
|
||||
stream->Write(smap_);
|
||||
}
|
||||
std::string GetSource(const std::string& format) final {
|
||||
// can only return source code.
|
||||
return source_;
|
||||
}
|
||||
|
||||
// get a from primary context in device_id
|
||||
PipelineEntry GetPipeline(size_t device_id,
|
||||
const std::string& func_name,
|
||||
size_t num_pack_args) {
|
||||
vulkan::VulkanWorkspace* w = vulkan::VulkanWorkspace::Global().get();
|
||||
CHECK_LT(device_id, w->context_.size());
|
||||
// start lock scope.
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (finfo_.size() <= device_id) {
|
||||
finfo_.resize(device_id + 1, DeviceEntry());
|
||||
}
|
||||
DeviceEntry& e = finfo_[device_id];
|
||||
auto it = e.smap.find(func_name);
|
||||
if (it != e.smap.end()) return it->second;
|
||||
PipelineEntry pe;
|
||||
if (e.device == nullptr) {
|
||||
e.device = w->context_[device_id].device;
|
||||
}
|
||||
{
|
||||
// create shader
|
||||
auto sit = smap_.find(func_name);
|
||||
CHECK(sit != smap_.end());
|
||||
const std::vector<uint32_t>& data = sit->second.data;
|
||||
VkShaderModuleCreateInfo shader_cinfo;
|
||||
shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
|
||||
shader_cinfo.pNext = nullptr;
|
||||
shader_cinfo.flags = 0;
|
||||
shader_cinfo.codeSize = data.size() * sizeof(uint32_t);
|
||||
shader_cinfo.pCode = data.data();
|
||||
VULKAN_CALL(vkCreateShaderModule(
|
||||
e.device, &shader_cinfo, nullptr, &(pe.shader)));
|
||||
}
|
||||
std::vector<VkDescriptorSetLayoutBinding> arg_binding;
|
||||
uint32_t num_pod = 0, num_buffer = 0;
|
||||
{
|
||||
auto fit = fmap_.find(func_name);
|
||||
CHECK(fit != fmap_.end());
|
||||
for (TVMType arg_type : fit->second.arg_types) {
|
||||
if (arg_type.code == kHandle) {
|
||||
VkDescriptorSetLayoutBinding bd;
|
||||
bd.binding = num_buffer;
|
||||
bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
|
||||
bd.descriptorCount = 1;
|
||||
bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
|
||||
bd.pImmutableSamplers = nullptr;
|
||||
arg_binding.push_back(bd);
|
||||
++num_buffer;
|
||||
} else {
|
||||
++num_pod;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
VkDescriptorSetLayoutCreateInfo descrip_cinfo;
|
||||
descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
|
||||
descrip_cinfo.pNext = nullptr;
|
||||
descrip_cinfo.flags = 0;
|
||||
descrip_cinfo.bindingCount = arg_binding.size();
|
||||
descrip_cinfo.pBindings = arg_binding.data();
|
||||
VULKAN_CALL(vkCreateDescriptorSetLayout(
|
||||
e.device, &descrip_cinfo, nullptr, &(pe.descriptor_layout)));
|
||||
|
||||
VkPushConstantRange crange;
|
||||
crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
|
||||
crange.offset = 0;
|
||||
crange.size = sizeof(ArgUnion) * num_pack_args;
|
||||
|
||||
VkPipelineLayoutCreateInfo playout_cinfo;
|
||||
playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
|
||||
playout_cinfo.pNext = nullptr;
|
||||
playout_cinfo.flags = 0;
|
||||
playout_cinfo.setLayoutCount = 1;
|
||||
playout_cinfo.pSetLayouts = &(pe.descriptor_layout);
|
||||
|
||||
if (num_pack_args != 0) {
|
||||
playout_cinfo.pushConstantRangeCount = 1;
|
||||
playout_cinfo.pPushConstantRanges = &crange;
|
||||
CHECK_LE(crange.size,
|
||||
w->context_[device_id].phy_device_prop.limits.maxPushConstantsSize);
|
||||
} else {
|
||||
playout_cinfo.pushConstantRangeCount = 0;
|
||||
playout_cinfo.pPushConstantRanges = nullptr;
|
||||
}
|
||||
|
||||
VULKAN_CALL(vkCreatePipelineLayout(
|
||||
e.device, &playout_cinfo, nullptr, &(pe.pipeline_layout)));
|
||||
VkComputePipelineCreateInfo pipeline_cinfo;
|
||||
pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
|
||||
pipeline_cinfo.pNext = nullptr;
|
||||
pipeline_cinfo.flags = 0;
|
||||
pipeline_cinfo.stage.sType =
|
||||
VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
|
||||
pipeline_cinfo.stage.pNext = nullptr;
|
||||
pipeline_cinfo.stage.flags = 0;
|
||||
pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
|
||||
pipeline_cinfo.stage.module = pe.shader;
|
||||
pipeline_cinfo.stage.pName = func_name.c_str();
|
||||
pipeline_cinfo.stage.pSpecializationInfo = nullptr;
|
||||
pipeline_cinfo.layout = pe.pipeline_layout;
|
||||
pipeline_cinfo.basePipelineHandle = nullptr;
|
||||
pipeline_cinfo.basePipelineIndex = 0;
|
||||
VULKAN_CALL(vkCreateComputePipelines(
|
||||
e.device, nullptr, 1, &pipeline_cinfo, nullptr, &(pe.pipeline)));
|
||||
e.smap[func_name] = pe;
|
||||
return pe;
|
||||
}
|
||||
|
||||
private:
|
||||
// device specific entry
|
||||
struct DeviceEntry {
|
||||
VkDevice device{nullptr};
|
||||
std::unordered_map<std::string, PipelineEntry> smap;
|
||||
};
|
||||
// the binary data
|
||||
std::vector<uint32_t> data_;
|
||||
// function information table.
|
||||
std::unordered_map<std::string, VulkanShader> smap_;
|
||||
// function information table.
|
||||
std::unordered_map<std::string, FunctionInfo> fmap_;
|
||||
// The format
|
||||
std::string fmt_{"vulkan"};
|
||||
// The source
|
||||
std::string source_;
|
||||
// device local pipeline information.
|
||||
std::vector<DeviceEntry> finfo_;
|
||||
// internal mutex when updating the module
|
||||
std::mutex mutex_;
|
||||
};
|
||||
|
||||
// A wrapped function class to get a packed func.
|
||||
class VulkanWrappedFunc {
|
||||
public:
|
||||
// initialize the VULKAN function.
|
||||
void Init(VulkanModuleNode* m,
|
||||
std::shared_ptr<ModuleNode> sptr,
|
||||
const std::string& func_name,
|
||||
size_t num_buffer_args,
|
||||
size_t num_pack_args,
|
||||
const std::vector<std::string>& thread_axis_tags) {
|
||||
w_ = vulkan::VulkanWorkspace::Global().get();
|
||||
m_ = m;
|
||||
sptr_ = sptr;
|
||||
func_name_ = func_name;
|
||||
num_buffer_args_ = num_buffer_args;
|
||||
num_pack_args_ = num_pack_args;
|
||||
thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
|
||||
}
|
||||
// invoke the function with void arguments
|
||||
void operator()(TVMArgs args,
|
||||
TVMRetValue* rv,
|
||||
const ArgUnion* pack_args) const {
|
||||
vulkan::VulkanThreadEntry* tls = vulkan::VulkanThreadEntry::ThreadLocal();
|
||||
int device_id = tls->context.device_id;
|
||||
CHECK_LT(device_id, kVulkanMaxNumDevice);
|
||||
const vulkan::VulkanContext& vctx = w_->context_[device_id];
|
||||
VulkanModuleNode::PipelineEntry& pe = scache_[device_id];
|
||||
if (pe.pipeline == nullptr) {
|
||||
pe = m_->GetPipeline(device_id, func_name_, num_pack_args_);
|
||||
}
|
||||
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
|
||||
vulkan::VulkanCommandBuffer* cmd = tls->CommandPool(device_id)->Alloc(
|
||||
&(pe.descriptor_layout));
|
||||
|
||||
cmd->write_descriptor_set.dstSet = cmd->descriptor_set;
|
||||
|
||||
// setup descriptors
|
||||
for (uint32_t i = 0; i < num_buffer_args_; ++i) {
|
||||
void* buf = args[static_cast<int>(i)];
|
||||
VkDescriptorBufferInfo binfo;
|
||||
binfo.buffer = static_cast<vulkan::VulkanBuffer*>(buf)->buffer;
|
||||
binfo.offset = 0;
|
||||
binfo.range = VK_WHOLE_SIZE;
|
||||
cmd->write_descriptor_set.dstBinding = i;
|
||||
cmd->write_descriptor_set.pBufferInfo = &binfo;
|
||||
vkUpdateDescriptorSets(
|
||||
vctx.device, 1, &(cmd->write_descriptor_set), 0, nullptr);
|
||||
}
|
||||
|
||||
// dispatch
|
||||
VkCommandBufferBeginInfo cb_begin;
|
||||
cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
|
||||
cb_begin.pNext = nullptr;
|
||||
cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
|
||||
cb_begin.pInheritanceInfo = 0;
|
||||
|
||||
VkSubmitInfo cb_submit;
|
||||
cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
|
||||
cb_submit.pNext = nullptr;
|
||||
cb_submit.waitSemaphoreCount = 0;
|
||||
cb_submit.pWaitSemaphores = nullptr;
|
||||
cb_submit.pWaitDstStageMask = 0;
|
||||
cb_submit.commandBufferCount = 1;
|
||||
cb_submit.pCommandBuffers = &(cmd->cmd_buffer);
|
||||
cb_submit.signalSemaphoreCount = 0;
|
||||
cb_submit.pSignalSemaphores = nullptr;
|
||||
// 0: begin
|
||||
VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
|
||||
// 1: dispatch
|
||||
vkCmdBindPipeline(
|
||||
cmd->cmd_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, pe.pipeline);
|
||||
vkCmdBindDescriptorSets(
|
||||
cmd->cmd_buffer, VK_PIPELINE_BIND_POINT_COMPUTE,
|
||||
pe.pipeline_layout, 0, 1, &(cmd->descriptor_set), 0, nullptr);
|
||||
// bind push constant if necessary
|
||||
if (num_pack_args_ != 0) {
|
||||
vkCmdPushConstants(
|
||||
cmd->cmd_buffer,
|
||||
pe.pipeline_layout,
|
||||
VK_SHADER_STAGE_COMPUTE_BIT,
|
||||
0, num_pack_args_ * sizeof(ArgUnion),
|
||||
pack_args);
|
||||
}
|
||||
vkCmdDispatch(
|
||||
cmd->cmd_buffer, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
|
||||
// 2: barrier(compute->compute|transfer)
|
||||
VkMemoryBarrier barrier_info;
|
||||
barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
|
||||
barrier_info.pNext = nullptr;
|
||||
barrier_info.srcAccessMask =
|
||||
VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
|
||||
barrier_info.dstAccessMask =
|
||||
(VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
|
||||
VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
|
||||
vkCmdPipelineBarrier(
|
||||
cmd->cmd_buffer,
|
||||
VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
|
||||
VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
|
||||
0, 1, &barrier_info, 0, nullptr, 0, nullptr);
|
||||
// 3: end
|
||||
VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
|
||||
// 4: submit with cmd->fence
|
||||
VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
|
||||
}
|
||||
|
||||
private:
|
||||
// Reference to global workspace.
|
||||
vulkan::VulkanWorkspace* w_;
|
||||
// internal module
|
||||
VulkanModuleNode* m_;
|
||||
// the resource holder
|
||||
std::shared_ptr<ModuleNode> sptr_;
|
||||
// The name of the function.
|
||||
std::string func_name_;
|
||||
// Number of buffer arguments
|
||||
size_t num_buffer_args_;
|
||||
// number of packed arguments.
|
||||
size_t num_pack_args_;
|
||||
// Device state cache per device.
|
||||
// mark as mutable, to enable lazy initialization
|
||||
mutable std::array<VulkanModuleNode::PipelineEntry, kVulkanMaxNumDevice> scache_;
|
||||
// thread axis configuration
|
||||
ThreadAxisConfig thread_axis_cfg_;
|
||||
};
|
||||
|
||||
PackedFunc VulkanModuleNode::GetFunction(
|
||||
const std::string& name,
|
||||
const std::shared_ptr<ModuleNode>& sptr_to_self) {
|
||||
CHECK_EQ(sptr_to_self.get(), this);
|
||||
CHECK_NE(name, symbol::tvm_module_main)
|
||||
<< "Device function do not have main";
|
||||
auto it = fmap_.find(name);
|
||||
if (it == fmap_.end()) return PackedFunc();
|
||||
const FunctionInfo& info = it->second;
|
||||
VulkanWrappedFunc f;
|
||||
size_t num_buffer_args = NumBufferArgs(info.arg_types);
|
||||
f.Init(this, sptr_to_self, name,
|
||||
num_buffer_args, info.arg_types.size() - num_buffer_args,
|
||||
info.thread_axis_tags);
|
||||
return PackFuncNonBufferArg(f, info.arg_types);
|
||||
}
|
||||
|
||||
Module VulkanModuleCreate(
|
||||
std::unordered_map<std::string, VulkanShader> smap,
|
||||
std::unordered_map<std::string, FunctionInfo> fmap,
|
||||
std::string source) {
|
||||
vulkan::VulkanWorkspace::Global()->Init();
|
||||
std::shared_ptr<VulkanModuleNode> n =
|
||||
std::make_shared<VulkanModuleNode>(smap, fmap, source);
|
||||
return Module(n);
|
||||
}
|
||||
|
||||
// Load a module from file.
Module VulkanModuleLoadFile(const std::string& file_name,
                            const std::string& format) {
  std::string data;
  std::unordered_map<std::string, VulkanShader> smap;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string fmt = GetFileFormat(file_name, format);
  std::string meta_file = GetMetaFilePath(file_name);
  LoadBinaryFromFile(file_name, &data);
  LoadMetaDataFromFile(meta_file, &fmap);
  dmlc::MemoryStringStream fs(&data);
  dmlc::Stream* stream = &fs;
  uint32_t magic;
  stream->Read(&magic);
  CHECK_EQ(magic, kVulkanModuleMagic)
      << "VulkanModule Magic mismatch";
  stream->Read(&smap);
  return VulkanModuleCreate(smap, fmap, "");
}

Module VulkanModuleLoadBinary(void* strm) {
  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
  std::unordered_map<std::string, VulkanShader> smap;
  std::unordered_map<std::string, FunctionInfo> fmap;

  std::string fmt;
  stream->Read(&fmt);
  stream->Read(&fmap);
  stream->Read(&smap);
  return VulkanModuleCreate(smap, fmap, "");
}

TVM_REGISTER_GLOBAL("module.loadfile_vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = VulkanModuleLoadFile(args[0], args[1]);
  });

TVM_REGISTER_GLOBAL("module.loadbinary_vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = VulkanModuleLoadBinary(args[0]);
  });
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_VULKAN_RUNTIME
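The two registrations above are what `tvm.module.load` dispatches to for `.vulkan` files. A hedged sketch of the save/load round trip (file and kernel names are hypothetical; `s`, `A`, `B`, `C` are a schedule and tensors as in the tests later in this patch, following the `check_module_save` pattern):

```python
import tvm
from tvm.contrib import util

# Hedged sketch of saving and reloading a Vulkan device module.
temp = util.tempdir()
fadd = tvm.build(s, [A, B, C], "vulkan", "llvm", name="myadd")
dev_module = fadd.imported_modules[0]          # the VulkanModuleNode
path_dev = temp.relpath("myadd.vulkan")
dev_module.save(path_dev)                      # SaveToFile: meta json + magic + shader map
loaded = tvm.module.load(path_dev)             # dispatches to module.loadfile_vulkan
```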
@ -0,0 +1,64 @@
/*!
 *  Copyright (c) 2017 by Contributors
 * \file vulkan_module.h
 * \brief Execution handling of Vulkan kernels
 */
#ifndef TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
#define TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_

#include <tvm/runtime/config.h>
#include <tvm/runtime/packed_func.h>
#include <dmlc/type_traits.h>
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "../meta_data.h"

namespace tvm {
namespace runtime {
/*! \brief Maximum number of GPUs supported in VulkanModule. */
static constexpr const int kVulkanMaxNumDevice = 8;

/*! \brief TVM Vulkan binary pack magic number */
static constexpr const int kVulkanModuleMagic = 0x02700027;

/*!
 * \brief A single VK shader program.
 *
 *  Due to the global resource declaration, current SPIR-V only allows
 *  one entry program per shader, making it less useful for a Module-like
 *  system.
 *
 *  Instead we pass in a map of str->VulkanShader until there is a
 *  native solution available.
 */
struct VulkanShader {
  /*! \brief header flag */
  uint32_t flag{0};
  /*! \brief Data segment */
  std::vector<uint32_t> data;

  void Save(dmlc::Stream *writer) const;
  bool Load(dmlc::Stream *reader);
};

/*!
 * \brief Create a Vulkan module from data.
 *
 * \param smap The shader program map.
 * \param fmap The function information map.
 * \param source Optional, source code.
 */
Module VulkanModuleCreate(
    std::unordered_map<std::string, VulkanShader> smap,
    std::unordered_map<std::string, FunctionInfo> fmap,
    std::string source);
}  // namespace runtime
}  // namespace tvm

namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::VulkanShader, true);
}  // namespace dmlc

#endif  // TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
@ -18,7 +18,8 @@ def test_exp():
    def check_device(device, host="stackvm"):
        if not tvm.module.enabled(host):
            return
        if not tvm.module.enabled(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            return
        fexp = tvm.build(s, [A, B],
                         device, host,

@ -33,6 +34,7 @@ def test_exp():
        b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)

    check_device("cuda", "llvm")
    check_device("vulkan")
    check_device("opencl")
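The same change repeats throughout the test updates below: tests are now gated on `ctx.exist` rather than on `tvm.module.enabled(device)` alone, so a build that enables a runtime but has no usable device still skips cleanly. A small helper capturing that pattern (illustrative only, not part of the patch):

```python
import tvm

def device_ready(device):
    """Return a context if the target runtime is enabled and a device exists."""
    if not tvm.module.enabled(device):
        return None
    ctx = tvm.context(device, 0)
    return ctx if ctx.exist else None

# Usage: ctx = device_ready("vulkan"); skip the test when it is None.
```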
|
@ -75,11 +77,12 @@ def test_popcount():
|
|||
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
|
||||
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("skip because %s is not enabled.." % device)
|
||||
return
|
||||
ctx = tvm.context(device, 0)
|
||||
if str(ctx).startswith('gpu'):
|
||||
target = tvm.target.create(device)
|
||||
if "cpu" not in target.keys:
|
||||
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
|
||||
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
|
||||
func = tvm.build(s, [A, B], device)
|
||||
|
@ -95,6 +98,8 @@ def test_popcount():
|
|||
check_device("cuda")
|
||||
check_device("opencl")
|
||||
check_device("metal")
|
||||
if dtype == "uint32":
|
||||
check_device("vulkan")
|
||||
run('uint32')
|
||||
run('uint64')
|
||||
|
||||
|
@ -121,14 +126,14 @@ def test_add():
|
|||
|
||||
# one line to build the function.
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("skip because %s is not enabled.." % device)
|
||||
return
|
||||
fadd = tvm.build(s, [A, B, C],
|
||||
device,
|
||||
name="myadd")
|
||||
print(fadd.imported_modules[0].get_source())
|
||||
ctx = tvm.context(device, 0)
|
||||
|
||||
# launch the kernel.
|
||||
n = 1024
|
||||
a = tvm.nd.array((np.random.uniform(size=n) * 256).astype(A.dtype), ctx)
|
||||
|
@ -142,6 +147,8 @@ def test_add():
|
|||
check_device("opencl")
|
||||
check_device("metal")
|
||||
check_device("cuda")
|
||||
check_device("vulkan")
|
||||
|
||||
run("float32")
|
||||
run("int32")
|
||||
run("int64")
|
||||
|
@ -149,7 +156,7 @@ def test_add():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_add()
|
||||
test_log_pow_llvm()
|
||||
test_exp()
|
||||
test_add()
|
||||
test_popcount()
|
||||
|
|
|
@ -2,6 +2,7 @@ import tvm
|
|||
import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
def test_gemm():
|
||||
# graph
|
||||
nn = 1024
|
||||
|
@ -64,13 +65,14 @@ def test_gemm():
|
|||
|
||||
# one line to build the function.
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("skip because %s is not enabled.." % device)
|
||||
return
|
||||
|
||||
with tvm.target.create(device):
|
||||
f = tvm.build(s, [A, B, C])
|
||||
ctx = tvm.context(device, 0)
|
||||
|
||||
# launch the kernel.
|
||||
n = nn
|
||||
m = n
|
||||
|
@ -86,12 +88,12 @@ def test_gemm():
|
|||
np.testing.assert_allclose(
|
||||
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
|
||||
|
||||
check_device("vulkan")
|
||||
check_device("nvptx -mcpu=sm_20")
|
||||
check_device("rocm")
|
||||
check_device("metal")
|
||||
check_device("opencl")
|
||||
check_device("cuda")
|
||||
#check_device("nvptx -mcpu=sm_20")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_gemm()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import tvm
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_reduce_prims():
|
||||
def test_prim(reducer, np_reducer):
|
||||
# graph
|
||||
|
@ -21,12 +22,12 @@ def test_reduce_prims():
|
|||
|
||||
# one line to build the function.
|
||||
def check_device(device, host="stackvm"):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not tvm.module.enabled(host):
|
||||
return
|
||||
if not tvm.module.enabled(device):
|
||||
if not ctx.exist:
|
||||
print("skip because %s is not enabled.." % device)
|
||||
return
|
||||
ctx = tvm.context(device, 0)
|
||||
freduce = tvm.build(s,
|
||||
args=[A, B],
|
||||
target=device, target_host=host,
|
||||
|
@ -44,6 +45,7 @@ def test_reduce_prims():
|
|||
np.testing.assert_allclose(npy, res, rtol=1e-4)
|
||||
|
||||
check_device("metal")
|
||||
check_device("vulkan")
|
||||
check_device("cuda")
|
||||
check_device("opencl")
|
||||
test_prim(tvm.sum, np.sum)
|
||||
|
@ -106,10 +108,11 @@ def test_rfactor_threads():
|
|||
|
||||
# one line to build the function.
|
||||
def check_target(device, host="stackvm"):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("skip because %s is not enabled.." % device)
|
||||
return
|
||||
ctx = tvm.context(device, 0)
|
||||
|
||||
fapi = tvm.lower(s, args=[A, B])
|
||||
fsum = tvm.build(fapi,
|
||||
target=device,
|
||||
|
@ -125,6 +128,7 @@ def test_rfactor_threads():
|
|||
np.testing.assert_allclose(
|
||||
b.asnumpy(), res, rtol=1e-4)
|
||||
|
||||
check_target("vulkan")
|
||||
check_target("cuda")
|
||||
check_target("metal")
|
||||
check_target("opencl")
|
||||
|
@ -159,15 +163,14 @@ def test_rfactor_elemwise_threads():
|
|||
|
||||
# one line to build the function.
|
||||
def check_target(device, host="stackvm"):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("skip because %s is not enabled.." % device)
|
||||
return
|
||||
ctx = tvm.context(device, 0)
|
||||
fapi = tvm.lower(s, args=[A, C])
|
||||
fsum = tvm.build(fapi,
|
||||
target=device,
|
||||
name="mysum")
|
||||
print(fsum.imported_modules[0].get_source())
|
||||
# launch the kernel.
|
||||
a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
|
||||
b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
|
||||
|
@ -176,6 +179,7 @@ def test_rfactor_elemwise_threads():
|
|||
np.testing.assert_allclose(
|
||||
b.asnumpy(), res, rtol=1e-4)
|
||||
|
||||
check_target("vulkan")
|
||||
check_target("cuda")
|
||||
check_target("metal")
|
||||
check_target("opencl")
|
||||
|
@ -264,10 +268,10 @@ def test_rfactor_argmax():
|
|||
s[B0].set_store_predicate(thread_x.var.equal(0))
|
||||
|
||||
def check_target(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("skip because %s is not enabled.." % device)
|
||||
return
|
||||
ctx = tvm.context(device, 0)
|
||||
fapi = tvm.lower(s, args=[A0, A1, B0, B1])
|
||||
fargmax = tvm.build(fapi,
|
||||
target=device,
|
||||
|
@ -285,6 +289,7 @@ def test_rfactor_argmax():
|
|||
np.testing.assert_allclose(np_res, nd_res0.asnumpy())
|
||||
|
||||
check_target("cuda")
|
||||
check_target("vulkan")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rfactor_elemwise_threads()
|
||||
|
|
|
@ -24,13 +24,13 @@ def test_scan():
|
|||
|
||||
# one line to build the function.
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("skip because %s is not enabled.." % device)
|
||||
return
|
||||
fscan = tvm.build(s, [X, res],
|
||||
device,
|
||||
name="myscan")
|
||||
ctx = tvm.context(device, 0)
|
||||
# launch the kernel.
|
||||
n = 1024
|
||||
m = 10
|
||||
|
@ -41,6 +41,7 @@ def test_scan():
|
|||
np.testing.assert_allclose(
|
||||
b.asnumpy(), np.cumsum(a_np, axis=0))
|
||||
|
||||
check_device("vulkan")
|
||||
check_device("cuda")
|
||||
check_device("metal")
|
||||
check_device("opencl")
|
||||
|
|
|
@ -13,12 +13,12 @@ def test_add_pipeline():
    # GPU schedule have to split by gridIdx and threadIdx
    num_thread = 256
    xo, xi = s[C].split(C.op.axis[0], factor=num_thread)
    s[C].bind(xo, tvm.thread_axis("threadIdx.x"))
    s[C].bind(xi, tvm.thread_axis("blockIdx.x"))
    s[C].bind(xi, tvm.thread_axis("threadIdx.x"))
    s[C].bind(xo, tvm.thread_axis("blockIdx.x"))

    xo, xi = s[D].split(D.op.axis[0], factor=num_thread)
    s[D].bind(xo, tvm.thread_axis("threadIdx.x"))
    s[D].bind(xi, tvm.thread_axis("blockIdx.x"))
    s[D].bind(xi, tvm.thread_axis("threadIdx.x"))
    s[D].bind(xo, tvm.thread_axis("blockIdx.x"))

    # compile to IR
    s = s.normalize()
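The fix above swaps the two bindings: the outer loop produced by `split` becomes `blockIdx.x` and the inner loop becomes `threadIdx.x`. The same pattern in isolation (a hedged sketch; the compute definition and sizes are placeholders):

```python
import tvm

# Illustrative binding sketch: outer axis -> blockIdx.x, inner axis -> threadIdx.x.
n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=256)
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
```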
@ -35,11 +35,11 @@ def test_add_pipeline():
|
|||
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
|
||||
|
||||
def check_target(device, host="stackvm"):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
return
|
||||
if not tvm.module.enabled(host):
|
||||
return
|
||||
if not tvm.module.enabled(device):
|
||||
return
|
||||
ctx = tvm.context(device, 0)
|
||||
mhost = tvm.codegen.build_module(fsplits[0], host)
|
||||
mdev = tvm.codegen.build_module(fsplits[1:], device)
|
||||
mhost.import_module(mdev)
|
||||
|
@ -55,12 +55,12 @@ def test_add_pipeline():
|
|||
d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)
|
||||
|
||||
def check_module_save(device, host="stackvm"):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
return
|
||||
if not tvm.module.enabled(host):
|
||||
return
|
||||
if not tvm.module.enabled(device):
|
||||
return
|
||||
ctx = tvm.context(device, 0)
|
||||
fmt = "ptx" if device == "cuda" else "cl"
|
||||
fmt = "ptx" if device == "cuda" else device
|
||||
mhost = tvm.codegen.build_module(fsplits[0], host)
|
||||
mdev = tvm.codegen.build_module(fsplits[1:], device)
|
||||
temp = util.tempdir()
|
||||
|
@ -82,7 +82,9 @@ def test_add_pipeline():
|
|||
check_target("cuda", host="llvm")
|
||||
check_module_save("cuda", host="stackvm")
|
||||
check_target("nvptx", host="llvm")
|
||||
check_target("vulkan", host="llvm")
|
||||
check_target("rocm", host="llvm")
|
||||
check_module_save("vulkan", host="stackvm")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -95,7 +95,7 @@ def test_device_module_dump():
|
|||
f = tvm.build(s, [A, B], device, "llvm", name=name)
|
||||
else:
|
||||
raise ValueError("Unsupported platform")
|
||||
|
||||
|
||||
path_dso = temp.relpath("dev_lib.so")
|
||||
f.export_library(path_dso)
|
||||
|
||||
|
@ -110,6 +110,7 @@ def test_device_module_dump():
|
|||
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
|
||||
|
||||
check_device("cuda")
|
||||
check_device("vulkan")
|
||||
check_device("opencl")
|
||||
check_device("metal")
|
||||
|
||||
|
@ -172,7 +173,7 @@ def test_combine_module_llvm():
|
|||
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
|
||||
mm['myadd2'](a, b)
|
||||
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
|
||||
|
||||
|
||||
if sys.platform != "win32":
|
||||
check_system_lib()
|
||||
check_llvm()
|
||||
|
|
|
@ -7,6 +7,7 @@ def enabled_ctx_list():
|
|||
('cl', tvm.opencl(0)),
|
||||
('metal', tvm.metal(0)),
|
||||
('rocm', tvm.rocm(0)),
|
||||
('vulkan', tvm.vulkan(0)),
|
||||
('vpi', tvm.vpi(0))]
|
||||
for k, v in ctx_list:
|
||||
assert tvm.context(k, 0) == v
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
import tvm
|
||||
import os
|
||||
from tvm.contrib import nvcc
|
||||
from tvm.contrib import spirv
|
||||
import numpy as np
|
||||
|
||||
TASK="gemm"
|
||||
|
@ -25,6 +26,7 @@ def tvm_callback_cuda_postproc(code):
|
|||
code = open("perf/%s_manual.cu" % TASK).read()
|
||||
return code
|
||||
|
||||
|
||||
def test_gemm():
|
||||
# graph
|
||||
nn = 2048
|
||||
|
@ -101,12 +103,12 @@ def test_gemm():
    s[BB].double_buffer()
    # correctness
    def check_device(device):
        print("Device %s" % device)
        if not tvm.module.enabled(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Device %s" % device)
        f = tvm.build(s, [A, B, C], device)
        ctx = tvm.context(device, 0)
        # launch the kernel.
        n, m, l = nn, nn, nn
        a_np = np.random.uniform(size=(n, l)).astype(A.dtype)

@ -126,7 +128,7 @@ def test_gemm():
    GFLOPS = num_flops / (t * 1e3) / 1e6
    print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))

    for device in ["cuda", "opencl", "rocm", "nvptx"]:
    for device in ["cuda", "opencl", "rocm", "nvptx", "vulkan"]:
        with tvm.build_config(auto_unroll_max_step=128,
                              unroll_explicit=(device != "cuda")):
            check_device(device)
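The GFLOPS figure above comes from timing the built function; a hedged sketch of the timing step (assumes `f`, `ctx`, the arrays `a`, `b`, `c`, and `num_runs`/`nn` from the surrounding test, and that the module's `time_evaluator` is available):

```python
# Hedged sketch of deriving t (seconds per run) for the GFLOPS calculation.
evaluator = f.time_evaluator(f.entry_name, ctx, number=num_runs)
t = evaluator(a, b, c).mean
num_flops = 2 * nn * nn * nn
GFLOPS = num_flops / (t * 1e3) / 1e6
print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))
```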
|
|
|
@ -9,13 +9,13 @@ def verify_broadcast_to_ele(in_shape, out_shape):
|
|||
A = tvm.placeholder(shape=in_shape, name="A")
|
||||
B = topi.broadcast_to(A, out_shape)
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
with tvm.target.create(device):
|
||||
s = topi.generic.schedule_broadcast(B)
|
||||
ctx = tvm.context(device, 0)
|
||||
foo = tvm.build(s, [A, B], device, name="broadcast_to")
|
||||
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
|
||||
out_npy = np.broadcast_to(data_npy, out_shape)
|
||||
|
@ -25,6 +25,7 @@ def verify_broadcast_to_ele(in_shape, out_shape):
|
|||
foo(data_nd, out_nd)
|
||||
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
|
||||
|
||||
check_device("vulkan")
|
||||
check_device("opencl")
|
||||
check_device("cuda")
|
||||
check_device("metal")
|
||||
|
@ -50,13 +51,13 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
|
|||
else:
|
||||
raise NotImplementedError
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
with tvm.target.create(device):
|
||||
s = topi.generic.schedule_broadcast(C)
|
||||
ctx = tvm.context(device, 0)
|
||||
foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
|
||||
lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
|
||||
rhs_npy = np.random.uniform(size=rhs_shape).astype(A.dtype)
|
||||
|
@ -82,6 +83,7 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
|
|||
foo(lhs_nd, rhs_nd, out_nd)
|
||||
np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
|
||||
|
||||
check_device("vulkan")
|
||||
check_device("opencl")
|
||||
check_device("cuda")
|
||||
check_device("metal")
|
||||
|
@ -105,5 +107,5 @@ def test_broadcast_binary():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_broadcast_to()
|
||||
test_broadcast_binary()
|
||||
test_broadcast_to()
|
||||
|
|
|
@ -31,11 +31,11 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
|
|||
a_np, w_np, b_np, c_np = get_ref_data()
|
||||
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
ctx = tvm.context(device, 0)
|
||||
a = tvm.nd.array(a_np, ctx)
|
||||
w = tvm.nd.array(w_np, ctx)
|
||||
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
|
||||
|
@ -49,7 +49,7 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
|
|||
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
|
||||
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
|
||||
|
||||
for device in ['cuda', 'opencl', 'metal', 'rocm']:
|
||||
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
|
||||
check_device(device)
|
||||
|
||||
|
||||
|
|
|
@ -29,14 +29,14 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
|
|||
a_np, w_np, b_np, c_np = get_ref_data()
|
||||
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
with tvm.target.create(device):
|
||||
s1 = topi.generic.schedule_conv2d_nchw([B])
|
||||
s2 = topi.generic.schedule_conv2d_nchw([C])
|
||||
ctx = tvm.context(device, 0)
|
||||
a = tvm.nd.array(a_np, ctx)
|
||||
w = tvm.nd.array(w_np, ctx)
|
||||
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
|
||||
|
@ -50,7 +50,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
|
|||
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
|
||||
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
|
||||
|
||||
for device in ['cuda', 'opencl', 'metal', 'rocm']:
|
||||
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
|
||||
check_device(device)
|
||||
|
||||
|
||||
|
|
|
@ -29,14 +29,14 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
|
|||
a_np, w_np, b_np, c_np = get_ref_data()
|
||||
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
with tvm.target.create(device):
|
||||
s1 = topi.generic.schedule_conv2d_transpose_nchw([B])
|
||||
s2 = topi.generic.schedule_conv2d_transpose_nchw([C])
|
||||
ctx = tvm.context(device, 0)
|
||||
a = tvm.nd.array(a_np, ctx)
|
||||
w = tvm.nd.array(w_np, ctx)
|
||||
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
|
||||
|
@ -50,7 +50,7 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
|
|||
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
|
||||
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
|
||||
|
||||
for device in ['cuda', 'opencl', 'metal', 'rocm']:
|
||||
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
|
||||
check_device(device)
|
||||
|
||||
|
||||
|
|
|
@ -29,13 +29,13 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
|
|||
a_np, b_np, c_np, d_np = get_ref_data()
|
||||
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
with tvm.target.create(device):
|
||||
s = topi.generic.schedule_dense(D)
|
||||
ctx = tvm.context(device, 0)
|
||||
a = tvm.nd.array(a_np, ctx)
|
||||
b = tvm.nd.array(b_np, ctx)
|
||||
c = tvm.nd.array(c_np, ctx)
|
||||
|
@ -44,7 +44,7 @@ def verify_dense(batch, in_dim, out_dim, use_bias=True):
|
|||
f(a, b, c, d)
|
||||
np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
|
||||
|
||||
for device in ['cuda', 'opencl', 'metal', 'rocm']:
|
||||
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
|
||||
check_device(device)
|
||||
|
||||
def test_dense():
|
||||
|
|
|
@ -23,7 +23,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
|
|||
|
||||
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
|
@ -32,7 +33,6 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
|
|||
s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
|
||||
s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift)
|
||||
s3 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)
|
||||
ctx = tvm.context(device, 0)
|
||||
# build the kernels
|
||||
f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
|
||||
f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
|
||||
|
@ -90,6 +90,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
|
|||
check_device("cuda")
|
||||
check_device("metal")
|
||||
check_device("rocm")
|
||||
check_device("vulkan")
|
||||
|
||||
|
||||
def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
|
||||
in_width = in_height
|
||||
|
@ -108,7 +110,8 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
|
|||
# schedule
|
||||
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
|
@ -117,7 +120,6 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
|
|||
s1 = topi.generic.schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
|
||||
s2 = topi.generic.schedule_depthwise_conv2d_nhwc(ScaleShift)
|
||||
s3 = topi.generic.schedule_depthwise_conv2d_nhwc(Relu)
|
||||
ctx = tvm.context(device, 0)
|
||||
# build the kernels
|
||||
f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
|
||||
f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
|
||||
|
@ -177,6 +179,7 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
|
|||
check_device("cuda")
|
||||
check_device("metal")
|
||||
check_device("rocm")
|
||||
check_device("vulkan")
|
||||
|
||||
def test_depthwise_conv2d():
|
||||
print("testing nchw")
|
||||
|
|
|
@ -32,11 +32,11 @@ def verify_depthwise_conv2d_back_input(batch, in_channel, in_h, channel_multipli
|
|||
schedule = schedule_depthwise_conv2d_backward_input_nhwc(In_grad)
|
||||
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
ctx = tvm.context(device, 0)
|
||||
# build the kernel
|
||||
f = tvm.build(schedule, [Filter, Out_grad, In_grad], device)
|
||||
# prepare pod type for test data closure
|
||||
|
@ -85,6 +85,7 @@ def verify_depthwise_conv2d_back_input(batch, in_channel, in_h, channel_multipli
|
|||
check_device("cuda")
|
||||
check_device("metal")
|
||||
check_device("rocm")
|
||||
check_device("vulkan")
|
||||
|
||||
def test_topi_depthwise_conv2d_backward_input_nhwc():
|
||||
verify_depthwise_conv2d_back_input(16, 256, 56, 1, 3, 1, 1)
|
||||
|
|
|
@ -32,11 +32,11 @@ def verify_depthwise_conv2d_back_weight(batch, in_channel, in_h, channel_multipl
|
|||
schedule = schedule_depthwise_conv2d_backward_weight_nhwc(Weight_grad)
|
||||
|
||||
def check_device(device):
|
||||
if not tvm.module.enabled(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
ctx = tvm.context(device, 0)
|
||||
# build the kernel
|
||||
f = tvm.build(schedule, [Input, Out_grad, Weight_grad], device)
|
||||
# prepare pod type for test data closure
|
||||
|
@ -78,6 +78,7 @@ def verify_depthwise_conv2d_back_weight(batch, in_channel, in_h, channel_multipl
|
|||
check_device("cuda")
|
||||
check_device("metal")
|
||||
check_device("rocm")
|
||||
check_device("vulkan")
|
||||
|
||||
def test_topi_depthwise_conv2d_backward_weight_nhwc():
|
||||
verify_depthwise_conv2d_back_weight(16, 256, 56, 1, 3, 1, 1)
|
||||
|
|
|
@@ -44,20 +44,21 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
b_np = np.maximum(b_np, 0.0)

def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_pool(B)
ctx = tvm.context(device, 0)

a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

for device in ['cuda', 'opencl', 'metal', 'rocm']:
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device)

def test_pool():

@@ -82,20 +83,20 @@ def verify_global_pool(n, c, h, w, pool_type):
b_np = np.maximum(b_np, 0.0)

def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_global_pool(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

for device in ['cuda', 'opencl', 'metal', 'rocm']:
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device)

def test_global_pool():
@@ -47,13 +47,14 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
raise NotImplementedError

def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_reduce(B)
ctx = tvm.context(device, 0)

foo = tvm.build(s, [A, B], device, name=type)
# Test
in_npy = np.random.uniform(size=in_shape).astype(np.float32)

@@ -90,7 +91,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
np.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1E-3, 1E-3)
else:
np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
for device in ["cuda", "opencl", "metal", "llvm", "rocm"]:
for device in ["cuda", "opencl", "metal", "llvm", "rocm", "vulkan"]:
check_device(device)
@@ -13,20 +13,21 @@ def verify_relu(m, n):
b_np = a_np * (a_np > 0)

def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_elemwise(B)
ctx = tvm.context(device, 0)

a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [A, B], device, name="relu")
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

for device in ['cuda', 'opencl', 'metal', 'rocm']:
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device)
@@ -17,20 +17,21 @@ def verify_softmax(m, n):
b_np = topi.testing.softmax_python(a_np)

def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_softmax(B)
ctx = tvm.context(device, 0)

a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [A, B], device, name="softmax")
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

for device in ['cuda', 'opencl', 'metal', 'rocm']:
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device)

def test_softmax():

@@ -48,20 +49,20 @@ def verify_log_softmax(m, n):
b_np = topi.testing.log_softmax_python(a_np)

def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_softmax(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [A, B], device, name="log_softmax")
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

for device in ["cuda", "opencl", "metal", "rocm"]:
for device in ["cuda", "opencl", "metal", "rocm", "vulkan"]:
check_device(device)
@@ -7,13 +7,13 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.expand_dims(A, axis, num_newaxis)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_broadcast(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="expand_dims")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = data_npy.reshape(out_shape)

@@ -22,7 +22,7 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
check_device(device)
@@ -30,13 +30,13 @@ def verify_tranpose(in_shape, axes):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.transpose(A, axes)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="tranpose")
data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)
out_npy = data_npy.transpose(axes)

@@ -45,7 +45,7 @@ def verify_tranpose(in_shape, axes):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
check_device(device)
@@ -53,13 +53,13 @@ def verify_reshape(src_shape, dst_shape):
A = tvm.placeholder(shape=src_shape, name="A")
B = topi.reshape(A, dst_shape)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="reshape")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npy = np.reshape(data_npy, newshape=dst_shape)

@@ -68,7 +68,7 @@ def verify_reshape(src_shape, dst_shape):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
check_device(device)
@@ -76,13 +76,14 @@ def verify_squeeze(src_shape, axis):
A = tvm.placeholder(shape=src_shape, name="A")
B = topi.squeeze(A, axis=axis)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
ctx = tvm.context(device, 0)

foo = tvm.build(s, [A, B], device, name="squeeze")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npy = np.squeeze(data_npy, axis=axis)

@@ -95,7 +96,7 @@ def verify_squeeze(src_shape, axis):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
check_device(device)

def verify_concatenate(shapes, axis):
@@ -104,13 +105,14 @@ def verify_concatenate(shapes, axis):
tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
out_tensor = topi.concatenate(a_tuple=tensor_l, axis=axis)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(out_tensor)
ctx = tvm.context(device, 0)

foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
out_npy = np.concatenate(data_npys, axis=axis)

@@ -119,7 +121,7 @@ def verify_concatenate(shapes, axis):
foo(*(data_nds + [out_nd]))
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
check_device(device)
@@ -127,13 +129,14 @@ def verify_split(src_shape, indices_or_sections, axis):
A = tvm.placeholder(shape=src_shape, name="A")
tensor_l = topi.split(A, indices_or_sections, axis=axis)
def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(tensor_l)
ctx = tvm.context(device, 0)

foo = tvm.build(s, [A] + tensor_l, device, name="split")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npys = np.split(data_npy, indices_or_sections, axis=axis)

@@ -143,7 +146,7 @@ def verify_split(src_shape, indices_or_sections, axis):
for out_nd, out_npy in zip(out_nds, out_npys):
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
check_device(device)
@@ -14,13 +14,13 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale):
b_np = topi.testing.upsampling_python(a_np, scale)

def check_device(device):
if not tvm.module.enabled(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)

@@ -28,7 +28,7 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale):

np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

for device in ['llvm', 'cuda']:
for device in ['llvm', 'cuda', 'vulkan']:
check_device(device)

def test_upsampling():