[CODEGEN] NVPTX backend. (#392)
* [CODEGEN] NVPTX backend. * Fix pylint * use fix
This commit is contained in:
Родитель
efafa1a0dd
Коммит
0560e1569e
|
@ -55,7 +55,10 @@ def context(dev_type, dev_id=0):
|
|||
assert tvm.context("cuda", 0) == tvm.gpu(0)
|
||||
"""
|
||||
if isinstance(dev_type, string_types):
|
||||
if not dev_type in TVMContext.STR2MASK:
|
||||
if dev_type not in TVMContext.STR2MASK:
|
||||
if dev_type.find("nvptx") != -1:
|
||||
dev_type = "cuda"
|
||||
if dev_type not in TVMContext.STR2MASK:
|
||||
raise ValueError("Unknown device type %s" % dev_type)
|
||||
dev_type = TVMContext.STR2MASK[dev_type]
|
||||
return TVMContext(dev_type, dev_id)
|
||||
|
|
|
@ -62,6 +62,7 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
|
|||
module_->setTargetTriple(tm->getTargetTriple().str());
|
||||
module_->setDataLayout(tm->createDataLayout());
|
||||
data_layout_.reset(new llvm::DataLayout(module_.get()));
|
||||
target_machine_ = tm;
|
||||
// initialize native vector bits
|
||||
std::string target = tm->getTarget().getName();
|
||||
if (target == "x86-64") {
|
||||
|
@ -86,6 +87,10 @@ void CodeGenLLVM::InitFuncState() {
|
|||
}
|
||||
|
||||
void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
|
||||
AddFunctionInternal(f, false);
|
||||
}
|
||||
|
||||
void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
|
||||
this->InitFuncState();
|
||||
is_restricted_ = f->is_restricted;
|
||||
CHECK(!module_->getFunction(f->name))
|
||||
|
@ -103,7 +108,9 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
|
|||
arg_type.push_back(LLVMType(t));
|
||||
}
|
||||
}
|
||||
llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, arg_type, false);
|
||||
|
||||
llvm::FunctionType* ftype = llvm::FunctionType::get(
|
||||
ret_void ? t_void_ : t_int_, arg_type, false);
|
||||
// setup the function.
|
||||
function_ = llvm::cast<llvm::Function>(module_->getOrInsertFunction(f->name, ftype));
|
||||
function_->setCallingConv(llvm::CallingConv::C);
|
||||
|
@ -129,8 +136,13 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
|
|||
|
||||
llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
|
||||
builder_->SetInsertPoint(block);
|
||||
|
||||
this->VisitStmt(f->body);
|
||||
builder_->CreateRet(ConstInt32(0));
|
||||
if (ret_void) {
|
||||
builder_->CreateRetVoid();
|
||||
} else {
|
||||
builder_->CreateRet(ConstInt32(0));
|
||||
}
|
||||
}
|
||||
|
||||
void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
|
||||
|
@ -155,6 +167,9 @@ class MPassManager : public llvm::legacy::PassManager {
|
|||
}
|
||||
};
|
||||
|
||||
void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {
|
||||
}
|
||||
|
||||
void CodeGenLLVM::Optimize() {
|
||||
// place optimization pass
|
||||
llvm::PassManagerBuilder builder;
|
||||
|
@ -167,6 +182,12 @@ void CodeGenLLVM::Optimize() {
|
|||
#endif
|
||||
builder.LoopVectorize = true;
|
||||
builder.SLPVectorize = true;
|
||||
this->InitPassManagerBuilder(&builder);
|
||||
|
||||
#if TVM_LLVM_VERSION >= 50
|
||||
target_machine_->adjustPassManager(builder);
|
||||
#endif
|
||||
|
||||
// pass manager
|
||||
FPassManager fpass(module_.get());
|
||||
MPassManager mpass;
|
||||
|
@ -313,23 +334,29 @@ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
|
|||
}
|
||||
}
|
||||
|
||||
llvm::CallInst* CodeGenLLVM::CreateCallExtern(
|
||||
llvm::Type* ret,
|
||||
const std::string& name,
|
||||
const std::vector<llvm::Value*>& arg_values) {
|
||||
std::vector<llvm::Type*> arg_types;
|
||||
for (llvm::Value* v : arg_values) {
|
||||
arg_types.push_back(v->getType());
|
||||
}
|
||||
llvm::FunctionType* ftype = llvm::FunctionType::get(ret, arg_types, false);
|
||||
llvm::Function* f = module_->getFunction(name);
|
||||
if (f == nullptr) {
|
||||
f = llvm::Function::Create(
|
||||
ftype, llvm::Function::ExternalLinkage, name, module_.get());
|
||||
}
|
||||
return builder_->CreateCall(f, arg_values);
|
||||
}
|
||||
|
||||
llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
|
||||
std::vector<llvm::Value*> arg_values(op->args.size());
|
||||
for (size_t i = 0; i < op->args.size(); ++i) {
|
||||
arg_values[i] = MakeValue(op->args[i]);
|
||||
}
|
||||
std::vector<llvm::Type*> arg_types;
|
||||
for (llvm::Value* v : arg_values) {
|
||||
arg_types.push_back(v->getType());
|
||||
}
|
||||
llvm::FunctionType* ftype = llvm::FunctionType::get(
|
||||
LLVMType(op->type), arg_types, false);
|
||||
llvm::Function* f = module_->getFunction(op->name);
|
||||
if (f == nullptr) {
|
||||
f = llvm::Function::Create(
|
||||
ftype, llvm::Function::ExternalLinkage, op->name, module_.get());
|
||||
}
|
||||
return builder_->CreateCall(f, arg_values);
|
||||
return CreateCallExtern(LLVMType(op->type), op->name, arg_values);
|
||||
}
|
||||
|
||||
llvm::Value* CodeGenLLVM::CreateScalarizedCall(
|
||||
|
@ -437,6 +464,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
|
|||
auto id = static_cast<llvm::Intrinsic::ID>(op->args[0].as<UIntImm>()->value);
|
||||
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id);
|
||||
return builder_->CreateCall(f, arg_values);
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
|
||||
return CreateStorageSync(op);
|
||||
} else if (op->is_intrinsic(Call::bitwise_and)) {
|
||||
CHECK_EQ(op->args.size(), 2U);
|
||||
return builder_->CreateAnd(
|
||||
|
@ -510,7 +539,18 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
int CodeGenLLVM::NativeVectorBits(const std::string& storage_scope) const {
|
||||
// Get the corresponding thread index
|
||||
llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) {
|
||||
LOG(FATAL) << "Donot support threading " << iv;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
llvm::Value* CodeGenLLVM::CreateStorageSync(const Call* op) {
|
||||
LOG(FATAL) << "Donot support storage sync in CPU mode";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) const {
|
||||
// By default, we ask the buffer to be aligned to 64 bytes
|
||||
return native_vector_bits_;
|
||||
}
|
||||
|
@ -855,7 +895,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
|
|||
ramp->base, make_const(ramp->base.type(), offset));
|
||||
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, MakeValue(base));
|
||||
llvm::Type* vtype = llvm::VectorType::get(
|
||||
LLVMType(t.element_of()), lanes)->getPointerTo();
|
||||
LLVMType(t.element_of()), lanes)->getPointerTo(
|
||||
ptr->getType()->getPointerAddressSpace());
|
||||
llvm::LoadInst* inst = builder_->CreateAlignedLoad(
|
||||
builder_->CreatePointerCast(ptr, vtype), alignment);
|
||||
AddAliasInfo(inst, op->buffer_var.get(),
|
||||
|
@ -971,7 +1012,8 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
|
|||
ramp->base, make_const(ramp->base.type(), offset));
|
||||
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, MakeValue(base));
|
||||
llvm::Type* vtype = llvm::VectorType::get(
|
||||
LLVMType(t.element_of()), lanes)->getPointerTo();
|
||||
LLVMType(t.element_of()), lanes)->getPointerTo(
|
||||
ptr->getType()->getPointerAddressSpace());
|
||||
llvm::StoreInst* inst = builder_->CreateAlignedStore(
|
||||
CreateVecSlice(value, offset, lanes),
|
||||
builder_->CreatePointerCast(ptr, vtype), alignment);
|
||||
|
@ -1069,17 +1111,28 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
|
|||
}
|
||||
info.alignment = alloca->getAlignment();
|
||||
}
|
||||
buf = builder_->CreatePointerCast(buf, LLVMType(op->type)->getPointerTo());
|
||||
buf = builder_->CreatePointerCast(
|
||||
buf, LLVMType(op->type)->getPointerTo(
|
||||
buf->getType()->getPointerAddressSpace()));
|
||||
CHECK(!var_map_.count(op->buffer_var.get()));
|
||||
var_map_[op->buffer_var.get()] = buf;
|
||||
this->VisitStmt(op->body);
|
||||
}
|
||||
|
||||
void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
|
||||
if (op->attr_key == ir::attr::storage_scope) {
|
||||
if (op->attr_key == ir::attr::thread_extent) {
|
||||
IterVar iv(op->node.node_);
|
||||
if (iv->thread_tag.length() != 0) {
|
||||
if (!var_map_.count(iv->var.get())) {
|
||||
var_map_[iv->var.get()] = GetThreadIndex(iv);
|
||||
}
|
||||
}
|
||||
this->VisitStmt(op->body);
|
||||
} else if (op->attr_key == ir::attr::storage_scope) {
|
||||
const Variable* v = op->node.as<Variable>();
|
||||
CHECK(v);
|
||||
alloc_storage_info_[v].scope = op->value.as<StringImm>()->value;
|
||||
alloc_storage_info_[v].scope = runtime::StorageScope::make(
|
||||
op->value.as<StringImm>()->value);
|
||||
this->VisitStmt(op->body);
|
||||
} else if (op->attr_key == ir::attr::storage_alignment) {
|
||||
const Variable* v = op->node.as<Variable>();
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include "./llvm_common.h"
|
||||
#include "../../runtime/thread_storage_scope.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace codegen {
|
||||
|
@ -116,22 +117,29 @@ class CodeGenLLVM :
|
|||
void VisitStmt_(const Block* op) override;
|
||||
void VisitStmt_(const Evaluate* op) override;
|
||||
void VisitStmt_(const ProducerConsumer* op) override;
|
||||
// create intrinstic given call
|
||||
virtual llvm::Value* CreateIntrinsic(const Call* op);
|
||||
// create extern function call
|
||||
virtual llvm::Value* CreateCallExtern(const Call* op);
|
||||
// Scalarize e by iterating elements of e.
|
||||
// f is a callback that takes index and v.
|
||||
virtual void Scalarize(const Expr& e,
|
||||
std::function<void(int i, llvm::Value* v)> f);
|
||||
|
||||
protected:
|
||||
/*! \brief The storage information */
|
||||
struct StorageInfo {
|
||||
/*! \brief The storage scope */
|
||||
std::string scope;
|
||||
runtime::StorageScope scope;
|
||||
/*! \brief The alignment of allocation */
|
||||
int alignment{0};
|
||||
};
|
||||
// create intrinstic given call
|
||||
virtual llvm::Value* CreateIntrinsic(const Call* op);
|
||||
// create extern function call
|
||||
virtual llvm::Value* CreateCallExtern(const Call* op);
|
||||
// Get the corresponding thread index
|
||||
virtual llvm::Value* GetThreadIndex(const IterVar& iv);
|
||||
// Get the corresponding thread index
|
||||
virtual llvm::Value* CreateStorageSync(const Call* op);
|
||||
// apply optimization on the module.
|
||||
virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
|
||||
// Scalarize by iterating elements of e.
|
||||
// f is a callback that takes index and v.
|
||||
virtual void Scalarize(const Expr& e,
|
||||
std::function<void(int i, llvm::Value* v)> f);
|
||||
// Initialize target
|
||||
virtual void InitTarget(llvm::TargetMachine* tm);
|
||||
// Add module startup function if needed.
|
||||
|
@ -139,7 +147,12 @@ class CodeGenLLVM :
|
|||
// apply optimization on the module.
|
||||
virtual void Optimize();
|
||||
// Get the maximim storage align bits of buffer pointer given storage scope.
|
||||
virtual int NativeVectorBits(const std::string& storage_scope) const;
|
||||
virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
|
||||
void AddFunctionInternal(const LoweredFunc& f, bool ret_void);
|
||||
// Create extern call
|
||||
llvm::CallInst* CreateCallExtern(llvm::Type* ret,
|
||||
const std::string& name,
|
||||
const std::vector<llvm::Value*>& value);
|
||||
/*!
|
||||
* \param t The original type.
|
||||
* \return LLVM type of t
|
||||
|
@ -192,6 +205,8 @@ class CodeGenLLVM :
|
|||
std::unique_ptr<llvm::DataLayout> data_layout_;
|
||||
// Internal metabuilder
|
||||
std::unique_ptr<llvm::MDBuilder> md_builder_;
|
||||
// llvm target machine
|
||||
llvm::TargetMachine* target_machine_{nullptr};
|
||||
// llvm context
|
||||
llvm::LLVMContext* ctx_{nullptr};
|
||||
// helpful data types
|
||||
|
|
|
@ -0,0 +1,167 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file codegen_nvptx.cc
|
||||
* \brief NVPTX code generator.
|
||||
*/
|
||||
#ifdef TVM_LLVM_VERSION
|
||||
#if TVM_CUDA_RUNTIME
|
||||
|
||||
#include <tvm/runtime/device_api.h>
|
||||
#include "./codegen_llvm.h"
|
||||
#include "../build_common.h"
|
||||
#include "../../pass/ir_util.h"
|
||||
#include "../../runtime/cuda/cuda_module.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace codegen {
|
||||
|
||||
// NVPTX code generator.
|
||||
class CodeGenNVPTX : public CodeGenLLVM {
|
||||
public:
|
||||
void AddFunction(const LoweredFunc& f) final {
|
||||
// add function as void return value
|
||||
CodeGenLLVM::AddFunctionInternal(f, true);
|
||||
// annotate as kernel function
|
||||
module_->getOrInsertNamedMetadata("nvvm.annotations")
|
||||
->addOperand(llvm::MDNode::get(*ctx_, {
|
||||
llvm::ValueAsMetadata::get(function_),
|
||||
llvm::MDString::get(*ctx_, "kernel"),
|
||||
llvm::ValueAsMetadata::get(ConstInt32(1)) }));
|
||||
}
|
||||
|
||||
void VisitStmt_(const Allocate* op) final {
|
||||
CHECK(!is_zero(op->condition));
|
||||
llvm::Value* buf = nullptr;
|
||||
if (op->new_expr.defined()) {
|
||||
CHECK_EQ(op->free_function, "nop");
|
||||
buf = MakeValue(op->new_expr);
|
||||
} else {
|
||||
int32_t constant_size = op->constant_allocation_size();
|
||||
CHECK_GT(constant_size, 0)
|
||||
<< "Can only handle constant size stack allocation in GPU";
|
||||
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
|
||||
if (constant_size % 4 == 0 && info.alignment == 0) {
|
||||
info.alignment = GetTempAllocaAlignment(op->type, constant_size);
|
||||
}
|
||||
// maximum necessary alignment in the NV devices
|
||||
if (info.alignment > 16) {
|
||||
info.alignment = 16;
|
||||
}
|
||||
if (info.scope.rank == 2) {
|
||||
// const int local_address_space = 5;
|
||||
// TODO(tqchen): for higher version of LLVM, local address space can be set.
|
||||
llvm::AllocaInst* alloca = builder_->CreateAlloca(
|
||||
LLVMType(op->type), ConstInt32(constant_size));
|
||||
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
|
||||
alloca->setAlignment(info.alignment);
|
||||
}
|
||||
buf = alloca;
|
||||
} else {
|
||||
CHECK_EQ(info.scope.rank, 1)
|
||||
<< "Can only allocate shared or local memory inside kernel";
|
||||
// Shared memory: address space == 3
|
||||
const unsigned shared_address_space = 3;
|
||||
llvm::Type* type = llvm::ArrayType::get(LLVMType(op->type), constant_size);
|
||||
// Allocate shared memory in global, address_space = 3
|
||||
llvm::GlobalVariable *global = new llvm::GlobalVariable(
|
||||
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
|
||||
nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
|
||||
global->setAlignment(info.alignment);
|
||||
buf = global;
|
||||
}
|
||||
}
|
||||
buf = builder_->CreatePointerCast(
|
||||
buf, LLVMType(op->type)->getPointerTo(
|
||||
buf->getType()->getPointerAddressSpace()));
|
||||
CHECK(!var_map_.count(op->buffer_var.get()));
|
||||
var_map_[op->buffer_var.get()] = buf;
|
||||
this->VisitStmt(op->body);
|
||||
}
|
||||
|
||||
// Return the thread index via intrinsics.
|
||||
llvm::Value* GetThreadIndex(const IterVar& iv) final {
|
||||
runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
|
||||
llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
|
||||
if (ts.rank == 1) {
|
||||
switch (ts.dim_index) {
|
||||
case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; break;
|
||||
case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y; break;
|
||||
case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z; break;
|
||||
default: LOG(FATAL) << "unknown thread idx";
|
||||
}
|
||||
} else {
|
||||
CHECK_EQ(ts.rank, 0);
|
||||
switch (ts.dim_index) {
|
||||
case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x; break;
|
||||
case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y; break;
|
||||
case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z; break;
|
||||
default: LOG(FATAL) << "unknown thread idx";
|
||||
}
|
||||
}
|
||||
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
|
||||
return builder_->CreateCall(f, {});
|
||||
}
|
||||
|
||||
llvm::Value* CreateStorageSync(const Call* op) final {
|
||||
const std::string& sync = op->args[0].as<StringImm>()->value;
|
||||
if (sync == "warp") {
|
||||
// TODO(tqchen) warp sync in CUDA9
|
||||
return nullptr;
|
||||
} else if (sync == "shared") {
|
||||
llvm::Function* f = llvm::Intrinsic::getDeclaration(
|
||||
module_.get(),
|
||||
::llvm::Intrinsic::nvvm_barrier0);
|
||||
return builder_->CreateCall(f, {});
|
||||
} else {
|
||||
LOG(FATAL) << "Do not support sync " << sync;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void InitPassManagerBuilder(llvm::PassManagerBuilder* builder) final {
|
||||
// Additional optimization hook to tweak the builder.
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitTarget(llvm::TargetMachine* tm) final {
|
||||
// Maximum vector lane = float4
|
||||
native_vector_bits_ = 4 * 32;
|
||||
CodeGenLLVM::InitTarget(tm);
|
||||
}
|
||||
};
|
||||
|
||||
runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
|
||||
CHECK(target.length(
|
||||
) >= 5 &&
|
||||
target.substr(0, 5) == "nvptx");
|
||||
llvm::TargetMachine* tm = GetLLVMTargetMachine(
|
||||
"-mtriple=nvptx64-nvidia-cuda -mcpu=sm_20" +
|
||||
target.substr(5, target.length() - 5));
|
||||
std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
|
||||
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
|
||||
cg->Init(funcs[0]->name, tm, ctx.get(), false, false);
|
||||
for (LoweredFunc f : funcs) {
|
||||
cg->AddFunction(f);
|
||||
}
|
||||
std::unique_ptr<llvm::Module> module = cg->Finish();
|
||||
llvm::SmallString<8> data;
|
||||
llvm::raw_svector_ostream dest(data);
|
||||
dest.SetUnbuffered();
|
||||
llvm::legacy::PassManager pass;
|
||||
CHECK(tm->addPassesToEmitFile(
|
||||
pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
|
||||
<< "Cannot emit target CGFT_ObjectFile";
|
||||
pass.run(*module);
|
||||
std::string ptx(data.begin(), data.end());
|
||||
return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(funcs), "");
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("codegen.build_nvptx")
|
||||
.set_body([](TVMArgs args, TVMRetValue* rv) {
|
||||
*rv = BuildNVPTX(args[0], args[1]);
|
||||
});
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace tvm
|
||||
#endif // TVM_CUDA_RUNTIME
|
||||
#endif // TVM_LLVM_VERSION
|
|
@ -37,53 +37,56 @@ void InitializeLLVM() {
|
|||
}
|
||||
|
||||
llvm::TargetMachine*
|
||||
GetLLVMTargetMachine(const std::string& target_str, bool allow_null) {
|
||||
GetLLVMTargetMachine(const std::string& target_str,
|
||||
bool allow_null) {
|
||||
// setup target triple
|
||||
CHECK(target_str.length() >= 4 &&
|
||||
target_str.substr(0, 4) == "llvm")
|
||||
<< "llvm target must starts with llvm";
|
||||
size_t start = 0;
|
||||
if (target_str.length() >= 4 &&
|
||||
target_str.substr(0, 4) == "llvm") {
|
||||
start = 4;
|
||||
}
|
||||
// simple parser
|
||||
std::string target_triple = "";
|
||||
std::string cpu = "generic";
|
||||
std::string attr = "";
|
||||
bool soft_float_abi = false;
|
||||
std::string key, value;
|
||||
if (target_str.length() > 5) {
|
||||
std::istringstream is(target_str.substr(5, target_str.length() - 5));
|
||||
while (is >> key) {
|
||||
if (key == "--system-lib" || key == "-system-lib") {
|
||||
continue;
|
||||
}
|
||||
size_t pos = key.find('=');
|
||||
if (pos != std::string::npos) {
|
||||
CHECK_GE(key.length(), pos + 1)
|
||||
<< "inavlid argument " << key;
|
||||
value = key.substr(pos + 1, key.length() - 1);
|
||||
key = key.substr(0, pos);
|
||||
std::istringstream is(target_str.substr(start, target_str.length() - start));
|
||||
|
||||
while (is >> key) {
|
||||
if (key == "--system-lib" || key == "-system-lib") {
|
||||
continue;
|
||||
}
|
||||
size_t pos = key.find('=');
|
||||
if (pos != std::string::npos) {
|
||||
CHECK_GE(key.length(), pos + 1)
|
||||
<< "inavlid argument " << key;
|
||||
value = key.substr(pos + 1, key.length() - 1);
|
||||
key = key.substr(0, pos);
|
||||
} else {
|
||||
CHECK(is >> value)
|
||||
<< "Unspecified value for option " << key;
|
||||
}
|
||||
if (key == "-target" ||
|
||||
key == "-mtriple") {
|
||||
target_triple = value;
|
||||
} else if (key == "-mcpu") {
|
||||
cpu = value;
|
||||
} else if (key == "-mattr") {
|
||||
attr = value;
|
||||
} else if (key == "-mfloat-abi") {
|
||||
if (value == "hard") {
|
||||
soft_float_abi = false;
|
||||
} else if (value == "soft") {
|
||||
soft_float_abi = true;
|
||||
} else {
|
||||
CHECK(is >> value)
|
||||
<< "Unspecified value for option " << key;
|
||||
}
|
||||
if (key == "-target" ||
|
||||
key == "-mtriple") {
|
||||
target_triple = value;
|
||||
} else if (key == "-mcpu") {
|
||||
cpu = value;
|
||||
} else if (key == "-mattr") {
|
||||
attr = value;
|
||||
} else if (key == "-mfloat-abi") {
|
||||
if (value == "hard") {
|
||||
soft_float_abi = false;
|
||||
} else if (value == "soft") {
|
||||
soft_float_abi = true;
|
||||
} else {
|
||||
LOG(FATAL) << "invalid -mfloat-abi option " << value;
|
||||
}
|
||||
} else {
|
||||
LOG(FATAL) << "unknown option " << key;
|
||||
LOG(FATAL) << "invalid -mfloat-abi option " << value;
|
||||
}
|
||||
} else {
|
||||
LOG(FATAL) << "unknown option " << key;
|
||||
}
|
||||
}
|
||||
|
||||
if (target_triple.length() == 0 ||
|
||||
target_triple == "default") {
|
||||
target_triple = llvm::sys::getDefaultTargetTriple();
|
||||
|
@ -109,9 +112,8 @@ GetLLVMTargetMachine(const std::string& target_str, bool allow_null) {
|
|||
} else {
|
||||
opt.FloatABIType = llvm::FloatABI::Hard;
|
||||
}
|
||||
auto rmodel = llvm::Reloc::PIC_;
|
||||
llvm::TargetMachine* tm =
|
||||
target->createTargetMachine(target_triple, cpu, attr, opt, rmodel);
|
||||
llvm::TargetMachine* tm = target->createTargetMachine(
|
||||
target_triple, cpu, attr, opt, llvm::Reloc::PIC_);
|
||||
return tm;
|
||||
}
|
||||
|
||||
|
|
|
@ -112,6 +112,8 @@ bool RuntimeEnabled(const std::string& target) {
|
|||
f_name = "device_api.rpc";
|
||||
} else if (target == "vpi" || target == "verilog") {
|
||||
f_name = "device_api.vpi";
|
||||
} else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
|
||||
f_name = "codegen.build_nvptx";
|
||||
} else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
|
||||
const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
|
||||
if (pf == nullptr) return false;
|
||||
|
|
|
@ -84,6 +84,7 @@ def test_gemm():
|
|||
np.testing.assert_allclose(
|
||||
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
|
||||
|
||||
check_device("nvptx -mcpu=sm_20")
|
||||
check_device("metal")
|
||||
check_device("opencl")
|
||||
check_device("cuda")
|
||||
|
|
|
@ -81,7 +81,7 @@ def test_add_pipeline():
|
|||
check_target("cuda", host="stackvm")
|
||||
check_target("cuda", host="llvm")
|
||||
check_module_save("cuda", host="stackvm")
|
||||
|
||||
check_target("nvptx", host="llvm")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_add_pipeline()
|
||||
|
|
Загрузка…
Ссылка в новой задаче