[BACKEND] initial llvm codegen for amdgpu (#402)
* added initial LLVM codegen for AMDGPU
* fixed whitespace
* fixed hsaco generation from IR
* fixed TargetMachine for ROCm and added GetSource for ROCm
* fixed whitespace issues
* changed statement to fit the 100-character line limit
* added intrinsics for workgroup - ROCm
* whitespace - newline error fix
* fixed error message for workitem/workgroup intrinsics
* added LLVM IR dump for ROCm codegen
* [ROCM] changed codegen to emit proper amdgpu kernel header
* fixed whitespace error
* fixed whitespace error - 2
* fixed AddFunction to not use an extra arg:
  1. changed AddFunctionInternal to not take an extra arg for the target type
  2. use Target from CodeGenLLVM to check for the AMDGPU target
* fixed whitespaces
* fixed whitespaces 2
* fixed codegen for AMDGPU - now generating valid IR
* fixed codegen per code review
* reviewed alignment for AMD devices
* added code to dump code object to file
* fixed cpplint errors
* print out IR after pass manager
* added code to dump asm and obj to file and std::string
* fixed whitespaces
* Update codegen_amdgpu.cc
* used registry for AMDGPU LLVM
* fixed whitespaces
* added code for calling linker
* fixed formatting errors
* added rocm link Python interface
* fixed pylint issues and added more body to the function
* added docstring
* added docstring for module
* fixed Python code after review, fixed LLVM object codegen
* fixed linker to generate code object
* removed dumping to output file and debug log output
* fixed lint for Python code
* added fault check after running linker
* removed print statement in rocm.py
* changed ROCm lld linker to raise RuntimeError rather than emitting an error log to stderr
* changed the way the linker command line is passed to subprocess.Popen
* removed redundant code and reused TVM utils
* removed commented-out code
* removed cloning of unused modules, and put IR into a string
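Taken together, the pieces listed above let a schedule be compiled for the "rocm" target from Python. Below is a minimal sketch of that flow, assuming a ROCm-capable GPU and a TVM build with LLVM and the ROCm runtime enabled; the tensor names, sizes, and kernel name are illustrative and not part of this commit.

# Hedged sketch: names (n, A, B, C, "myadd") are illustrative only.
import numpy as np
import tvm
from tvm.contrib import rocm  # registers tvm_callback_rocm_link

n = 1024
A = tvm.placeholder((n,), name="A")
B = tvm.placeholder((n,), name="B")
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
s = tvm.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

# "rocm" routes through the codegen.build_rocm entry added in this commit.
fadd = tvm.build(s, [A, B, C], "rocm", target_host="llvm", name="myadd")
ctx = tvm.context("rocm", 0)
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
fadd(a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())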
Parent: 5061a6da5e
Commit: 891e226bae
@@ -1 +1 @@
-Subproject commit 46886a6b47f660cda581e497378204ccc029a01e
+Subproject commit a527100d7d5001efc4954848a2fc6027e48c05f4
@@ -29,3 +29,4 @@ from .ndarray import register_extension
from .schedule import create_schedule
from .build_module import build, lower, build_config
from .tag import tag_scope
from .contrib import rocm as _rocm
@@ -59,6 +59,8 @@ def context(dev_type, dev_id=0):
    if dev_type not in TVMContext.STR2MASK:
        if dev_type.find("nvptx") != -1:
            dev_type = "cuda"
        if dev_type.find("rocm") != -1:
            dev_type = "rocm"
    if dev_type not in TVMContext.STR2MASK:
        raise ValueError("Unknown device type %s" % dev_type)
    dev_type = TVMContext.STR2MASK[dev_type]
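A small illustrative sketch of the remapping above, assuming a TVM build with the ROCm runtime enabled; the -mcpu flag is the same one used by the AMDGPU target machine later in this commit.

import tvm

# Plain device names resolve directly...
ctx = tvm.context("rocm", 0)
# ...and target strings that merely contain the device name are folded back to it,
# mirroring the existing nvptx -> cuda handling.
ctx2 = tvm.context("rocm -mcpu=gfx900", 0)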
@@ -0,0 +1,50 @@
"""Utility for ROCm backend"""
import subprocess
from . import util
from ..api import register_func


def rocm_link(in_file, out_file):
    """Link relocatable ELF object to shared ELF object using lld

    Parameters
    ----------
    in_file : str
        Input file name (relocatable ELF object file)

    out_file : str
        Output file name (shared ELF object file)
    """
    args = ["ld.lld", "-shared", in_file, "-o", out_file]
    proc = subprocess.Popen(
        args,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT)
    (out, _) = proc.communicate()

    if proc.returncode != 0:
        msg = "Linking error using ld.lld:\n"
        msg += str(out)
        raise RuntimeError(msg)


@register_func("tvm_callback_rocm_link")
def callback_rocm_link(obj_bin):
    """Links object file generated from LLVM to HSA Code Object

    Parameters
    ----------
    obj_bin : bytearray
        The object file

    Return
    ------
    cobj_bin : bytearray
        The HSA Code Object
    """
    tmp_dir = util.tempdir()
    tmp_obj = tmp_dir.relpath("rocm_kernel.o")
    tmp_cobj = tmp_dir.relpath("rocm_kernel.co")
    with open(tmp_obj, "wb") as out_file:
        out_file.write(bytes(obj_bin))
    rocm_link(tmp_obj, tmp_cobj)
    cobj_bin = bytearray(open(tmp_cobj, "rb").read())
    return cobj_bin
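Importing this module registers the linker callback in TVM's global function registry, which is how the C++ backend below finds it. A minimal, hypothetical sketch of looking it up directly (the object bytes are elided because they would come from the LLVM object emission pass):

import tvm
import tvm.contrib.rocm  # runs @register_func("tvm_callback_rocm_link")

# BuildAMDGPU fetches the same callback via
# tvm::runtime::Registry::Get("tvm_callback_rocm_link").
link = tvm.get_global_func("tvm_callback_rocm_link")

# Illustrative only: obj_bytes would be a relocatable AMDGPU ELF object.
# hsaco = link(bytearray(obj_bytes))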
@@ -0,0 +1,188 @@
/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_amdgpu.cc
 * \brief AMDGPU code generator.
 */
#ifdef TVM_LLVM_VERSION
#if TVM_ROCM_RUNTIME

#include <tvm/runtime/device_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include "./codegen_llvm.h"
#include "../build_common.h"
#include "../../pass/ir_util.h"
#include "../../runtime/rocm/rocm_module.h"

namespace tvm {
namespace codegen {

// AMDGPU code generator.
class CodeGenAMDGPU : public CodeGenLLVM {
 public:
  void AddFunction(const LoweredFunc& f) final {
    // add function as void return value
    CodeGenLLVM::AddFunctionInternal(f, true);
    function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
  }

  void VisitStmt_(const Allocate* op) final {
    CHECK(!is_zero(op->condition));
    llvm::Value* buf = nullptr;
    if (op->new_expr.defined()) {
      CHECK_EQ(op->free_function, "nop");
      buf = MakeValue(op->new_expr);
    } else {
      int32_t constant_size = op->constant_allocation_size();
      CHECK_GT(constant_size, 0)
          << "Can only handle constant size stack allocation in GPU";
      StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
      if (constant_size % 4 == 0 && info.alignment == 0) {
        info.alignment = GetTempAllocaAlignment(op->type, constant_size);
      }
      // maximum necessary alignment in the AMD devices
      if (info.alignment > 16) {
        info.alignment = 16;
      }
      if (info.scope.rank == 2) {
        // const int local_address_space = 5;
        // TODO(tqchen): for higher version of LLVM, local address space can be set.
        llvm::AllocaInst* alloca = builder_->CreateAlloca(
            LLVMType(op->type), ConstInt32(constant_size));
        if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
          alloca->setAlignment(info.alignment);
        }
        buf = alloca;
      } else {
        CHECK_EQ(info.scope.rank, 1)
            << "Can only allocate shared or local memory inside kernel";
        // Shared memory: address space == 3
        const unsigned shared_address_space = 3;
        llvm::Type* type = llvm::ArrayType::get(LLVMType(op->type), constant_size);
        // Allocate shared memory in global, address_space = 3
        llvm::GlobalVariable *global = new llvm::GlobalVariable(
            *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
            nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
        global->setAlignment(info.alignment);
        buf = global;
      }
    }
    buf = builder_->CreatePointerCast(
        buf, LLVMType(op->type)->getPointerTo(
            buf->getType()->getPointerAddressSpace()));
    CHECK(!var_map_.count(op->buffer_var.get()));
    var_map_[op->buffer_var.get()] = buf;
    this->VisitStmt(op->body);
  }

  // Return the thread index via intrinsics.
  llvm::Value* GetThreadIndex(const IterVar& iv) final {
    runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
    llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x;
    if (ts.rank == 1) {
      switch (ts.dim_index) {
        case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; break;
        case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_y; break;
        case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_z; break;
        default: LOG(FATAL) << "unknown workitem idx";
      }
    } else {
      CHECK_EQ(ts.rank, 0);
      switch (ts.dim_index) {
        case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_x; break;
        case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_y; break;
        case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_z; break;
        default: LOG(FATAL) << "unknown workgroup idx";
      }
    }
    llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
    return builder_->CreateCall(f, {});
  }

  llvm::Value* CreateStorageSync(const Call* op) final {
    const std::string& sync = op->args[0].as<StringImm>()->value;
    if (sync == "warp") {
      // TODO(tqchen) warp sync in CUDA9
      return nullptr;
    } else if (sync == "shared") {
      llvm::Function* f = llvm::Intrinsic::getDeclaration(
          module_.get(),
          ::llvm::Intrinsic::amdgcn_s_barrier);
      return builder_->CreateCall(f, {});
    } else {
      LOG(FATAL) << "Do not support sync " << sync;
      return nullptr;
    }
  }

  void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final {
    // Additional optimization hook to tweak the builder.
  }

  unsigned GetGlobalAddressSpace() {
    return 1;
  }

 protected:
  void InitTarget(llvm::TargetMachine* tm) final {
    // Maximum vector lane = float4
    native_vector_bits_ = 4 * 32;
    CodeGenLLVM::InitTarget(tm);
  }
};

runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
  CHECK(target.length() >= 4 &&
        target.substr(0, 4) == "rocm");
  llvm::TargetMachine* tm = \
      GetLLVMTargetMachine("-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx900" + \
                           target.substr(4, target.length() - 4));

  std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());
  std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
  cg->Init(funcs[0]->name, tm, ctx.get(), false, false);
  for (LoweredFunc f : funcs) {
    cg->AddFunction(f);
  }

  std::unique_ptr<llvm::Module> module = cg->Finish();

  llvm::SmallString<8> dataObj, data_ll, dataAsm;
  llvm::raw_svector_ostream destObj(dataObj), dest_ll(data_ll), destAsm(dataAsm);
  destObj.SetUnbuffered();
  dest_ll.SetUnbuffered();
  destAsm.SetUnbuffered();
  module->print(dest_ll, nullptr);
  std::unique_ptr<llvm::Module> mAsm = llvm::CloneModule(module.get());
  std::unique_ptr<llvm::Module> mObj = llvm::CloneModule(module.get());
  llvm::legacy::PassManager pass;

  CHECK(tm->addPassesToEmitFile(
      pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0)
      << "Cannot emit target CGFT_ObjectFile";
  pass.run(*mObj);
  std::string obj(dataObj.begin(), dataObj.end());

  const auto* f = tvm::runtime::Registry::Get("tvm_callback_rocm_link");
  CHECK(f != nullptr) << "Require tvm_callback_rocm_link to exist, do import tvm.contrib.rocm";

  TVMByteArray arr;
  arr.data = &obj[0];
  arr.size = obj.length();

  std::string hsaco = (*f)(arr);
  std::string ll(data_ll.begin(), data_ll.end());

  return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll);
}

TVM_REGISTER_API("codegen.build_rocm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = BuildAMDGPU(args[0], args[1]);
  });

}  // namespace codegen
}  // namespace tvm
#endif  // TVM_ROCM_RUNTIME
#endif  // TVM_LLVM_VERSION
@@ -100,7 +100,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
    Type t = arg.type();
    if (t.is_handle() && f->handle_data_type.count(arg)) {
      arg_type.push_back(
-         LLVMType(f->handle_data_type[arg].type())->getPointerTo());
+         LLVMType(f->handle_data_type[arg].type())->getPointerTo(GetGlobalAddressSpace()));
      if (!is_restricted_) {
        alias_var_set_.insert(arg.get());
      }
@@ -555,6 +555,10 @@ int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
  return native_vector_bits_;
}

unsigned CodeGenLLVM::GetGlobalAddressSpace() {
  return 0;
}

void CodeGenLLVM::GetAlignment(
    Type t, const Variable* buf_var, const Expr& index,
    int* p_alignment, int* p_native_bits) {
@@ -23,6 +23,7 @@ namespace codegen {

using namespace ir;


/*!
 * \brief A base class to generate a LLVM.
 */
@@ -148,6 +149,9 @@ class CodeGenLLVM :
  virtual void Optimize();
  // Get the maximum storage align bits of buffer pointer given storage scope.
  virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
  // Get correct address space depending on the backend
  virtual unsigned GetGlobalAddressSpace();

  void AddFunctionInternal(const LoweredFunc& f, bool ret_void);
  // Create extern call
  llvm::CallInst* CreateCallExtern(llvm::Type* ret,
@@ -125,6 +125,8 @@ bool RuntimeEnabled(const std::string& target) {
    f_name = "device_api.vpi";
  } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
    f_name = "codegen.build_nvptx";
  } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
    f_name = "codegen.build_rocm";
  } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
    const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
    if (pf == nullptr) return false;
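The extra branch above is what makes a Python-side capability check for the new target succeed, since for "rocm" it tests whether the codegen.build_rocm registry entry is present. A hedged sketch of how a test might guard on it, using the enabled-check API of this era (assumed here, not shown in the diff):

import tvm

# True only when codegen.build_rocm is registered in this build.
if not tvm.module.enabled("rocm"):
    print("skip: rocm codegen not enabled in this build")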
@@ -59,10 +59,17 @@ class ROCMModuleNode : public runtime::ModuleNode {
    stream->Write(data_);
  }

  std::string GetSource(const std::string& format) final {
    if (format == fmt_) { return data_; }
    if (fmt_ == "hsaco") { return data_; }
    return "";
  }

  // get a CUfunction from primary context in device_id
  hipFunction_t GetFunc(int device_id, const std::string& func_name) {
    std::lock_guard<std::mutex> lock(mutex_);
    // must recheck under the lock scope

    if (module_[device_id] == nullptr) {
      ROCM_DRIVER_CALL(hipModuleLoadData(&(module_[device_id]), data_.c_str()));
    }
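With GetSource wired up, the device module produced by BuildAMDGPU can be inspected from Python. A hypothetical inspection sketch, assuming fadd was built for the rocm target as in the sketch after the commit message (names are illustrative):

# fadd is the host module returned by tvm.build(..., "rocm", ...);
# its first imported module is the ROCMModuleNode created by BuildAMDGPU.
dev_mod = fadd.imported_modules[0]
print(dev_mod.get_source())  # answered by ROCMModuleNode::GetSource above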
@@ -140,7 +147,9 @@ class ROCMWrappedFunc {
    if (fcache_[device_id] == nullptr) {
      fcache_[device_id] = m_->GetFunc(device_id, func_name_);
    }

    hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);

    ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
    void* config[] = {
      HIP_LAUNCH_PARAM_BUFFER_POINTER, &packed_args,
@@ -181,7 +190,6 @@ PackedFunc ROCMModuleNode::GetFunction(
  CHECK_EQ(sptr_to_self.get(), this);
  CHECK_NE(name, symbol::tvm_module_main)
      << "Device function do not have main";

  auto it = fmap_.find(name);
  if (it == fmap_.end()) return PackedFunc();
  const FunctionInfo& info = it->second;
@@ -85,6 +85,8 @@ def test_gemm():
        np.testing.assert_allclose(
            c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)

    check_device("nvptx -mcpu=sm_20")
    check_device("rocm")
    check_device("metal")
    check_device("opencl")
    check_device("cuda")
@@ -82,6 +82,7 @@ def test_add_pipeline():
    check_target("cuda", host="llvm")
    check_module_save("cuda", host="stackvm")
    check_target("nvptx", host="llvm")
    check_target("rocm", host="llvm")

if __name__ == "__main__":
    test_add_pipeline()