[CODEGEN] Enable inline llvm asm code (#1486)
This commit is contained in:
Родитель
55a08deca0
Коммит
f7d05b7ce2
|
@ -8,6 +8,12 @@ tvm.contrib.cblas
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
tvm.contrib.clang
|
||||||
|
~~~~~~~~~~~~~~~~~
|
||||||
|
.. automodule:: tvm.contrib.clang
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
tvm.contrib.cc
|
tvm.contrib.cc
|
||||||
~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~
|
||||||
.. automodule:: tvm.contrib.cc
|
.. automodule:: tvm.contrib.cc
|
||||||
|
|
|
@ -179,6 +179,8 @@ constexpr const char* loop_scope = "loop_scope";
|
||||||
constexpr const char* reduce_scope = "reduce_scope";
|
constexpr const char* reduce_scope = "reduce_scope";
|
||||||
/*! \brief Mark region is guarded by the pragma extension */
|
/*! \brief Mark region is guarded by the pragma extension */
|
||||||
constexpr const char* pragma_scope_prefix = "pragma_";
|
constexpr const char* pragma_scope_prefix = "pragma_";
|
||||||
|
/*! \brief Import llvm source or file into the final code gen module */
|
||||||
|
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
|
||||||
/*!
|
/*!
|
||||||
* \brief Mark of prefetch scope, value=offset,
|
* \brief Mark of prefetch scope, value=offset,
|
||||||
* run prefetch of Tensor on the current loop scope
|
* run prefetch of Tensor on the current loop scope
|
||||||
|
|
|
@ -0,0 +1,96 @@
|
||||||
|
"""Util to invoke clang in the system."""
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
from __future__ import absolute_import as _abs
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from .._ffi.base import py_str
|
||||||
|
from .. import codegen
|
||||||
|
from . import util
|
||||||
|
|
||||||
|
|
||||||
|
def find_clang(required=True):
|
||||||
|
"""Find clang in system.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
required : bool
|
||||||
|
Whether it is required,
|
||||||
|
runtime error will be raised if the compiler is required.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
valid_list : list of str
|
||||||
|
List of possible paths.
|
||||||
|
|
||||||
|
Note
|
||||||
|
----
|
||||||
|
This function will first search clang that
|
||||||
|
matches the major llvm version that built with tvm
|
||||||
|
"""
|
||||||
|
cc_list = []
|
||||||
|
if hasattr(codegen, "llvm_version_major"):
|
||||||
|
cc_list += ["clang-%d.0" % codegen.llvm_version_major()]
|
||||||
|
cc_list += ["clang"]
|
||||||
|
valid_list = [util.which(x) for x in cc_list]
|
||||||
|
valid_list = [x for x in valid_list if x]
|
||||||
|
if not valid_list and required:
|
||||||
|
raise RuntimeError(
|
||||||
|
"cannot find clang, candidates are: " + str(cc_list))
|
||||||
|
return valid_list
|
||||||
|
|
||||||
|
|
||||||
|
def create_llvm(inputs,
|
||||||
|
output=None,
|
||||||
|
options=None,
|
||||||
|
cc=None):
|
||||||
|
"""Create llvm text ir.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
inputs : list of str
|
||||||
|
List of input files name or code source.
|
||||||
|
|
||||||
|
output : str, optional
|
||||||
|
Output file, if it is none
|
||||||
|
a temporary file is created
|
||||||
|
|
||||||
|
options : list
|
||||||
|
The list of additional options string.
|
||||||
|
|
||||||
|
cc : str, optional
|
||||||
|
The clang compiler, if not specified,
|
||||||
|
we will try to guess the matched clang version.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
code : str
|
||||||
|
The generated llvm text IR.
|
||||||
|
"""
|
||||||
|
cc = cc if cc else find_clang()[0]
|
||||||
|
cmd = [cc]
|
||||||
|
cmd += ["-S", "-emit-llvm"]
|
||||||
|
temp = util.tempdir()
|
||||||
|
output = output if output else temp.relpath("output.ll")
|
||||||
|
inputs = [inputs] if isinstance(inputs, str) else inputs
|
||||||
|
input_files = []
|
||||||
|
for i, code in enumerate(inputs):
|
||||||
|
if util.is_source_path(code):
|
||||||
|
input_files.append(code)
|
||||||
|
else:
|
||||||
|
temp_path = temp.relpath("input%d.cc" % i)
|
||||||
|
with open(temp_path, "w") as output_file:
|
||||||
|
output_file.write(code)
|
||||||
|
input_files.append(temp_path)
|
||||||
|
if options:
|
||||||
|
cmd += options
|
||||||
|
cmd += ["-o", output]
|
||||||
|
cmd += input_files
|
||||||
|
proc = subprocess.Popen(
|
||||||
|
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||||
|
(out, _) = proc.communicate()
|
||||||
|
if proc.returncode != 0:
|
||||||
|
msg = "Compilation error:\n"
|
||||||
|
msg += py_str(out)
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
|
return open(output).read()
|
|
@ -86,6 +86,7 @@ class FileLock(object):
|
||||||
self.lock_file.close()
|
self.lock_file.close()
|
||||||
self.lock_file = None
|
self.lock_file = None
|
||||||
|
|
||||||
|
|
||||||
def filelock(path):
|
def filelock(path):
|
||||||
"""Create a file lock which locks on path
|
"""Create a file lock which locks on path
|
||||||
|
|
||||||
|
@ -99,3 +100,45 @@ def filelock(path):
|
||||||
lock : File lock object
|
lock : File lock object
|
||||||
"""
|
"""
|
||||||
return FileLock(path)
|
return FileLock(path)
|
||||||
|
|
||||||
|
|
||||||
|
def is_source_path(path):
|
||||||
|
"""Check if path is source code path.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : str
|
||||||
|
A possible path
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
valid : bool
|
||||||
|
Whether path is a possible source path
|
||||||
|
"""
|
||||||
|
if os.path.exists(path):
|
||||||
|
return True
|
||||||
|
if path.find("\n") != -1:
|
||||||
|
return False
|
||||||
|
spath = path.rsplit(".", 1)
|
||||||
|
return len(spath) == 2 and spath[1].strip() == spath[1]
|
||||||
|
|
||||||
|
|
||||||
|
def which(exec_name):
|
||||||
|
"""Try to find full path of exec_name
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
exec_name : str
|
||||||
|
The executable name
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
path : str
|
||||||
|
The full path of executable if found, otherwise returns None
|
||||||
|
"""
|
||||||
|
base_list = ["", "/bin"] + os.environ.get("PATH", "").split(os.pathsep)
|
||||||
|
for path in base_list:
|
||||||
|
full_path = os.path.join(path, exec_name)
|
||||||
|
if os.path.isfile(full_path) and os.access(full_path, os.X_OK):
|
||||||
|
return full_path
|
||||||
|
return None
|
||||||
|
|
|
@ -603,6 +603,8 @@ class Stage(NodeBase):
|
||||||
:code:`for (int i = task_id; i < end; i += num_task)`
|
:code:`for (int i = task_id; i < end; i += num_task)`
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(pragma_value, string_types):
|
||||||
|
pragma_value = convert(pragma_value)
|
||||||
_api_internal._StagePragma(self, var, pragma_type, pragma_value)
|
_api_internal._StagePragma(self, var, pragma_type, pragma_value)
|
||||||
|
|
||||||
def prefetch(self, tensor, var, offset):
|
def prefetch(self, tensor, var, offset):
|
||||||
|
|
|
@ -701,6 +701,11 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
|
||||||
builder_->CreateCall(
|
builder_->CreateCall(
|
||||||
RuntimeTVMParallelBarrier(),
|
RuntimeTVMParallelBarrier(),
|
||||||
{MakeValue(parallel_env_.task_id), parallel_env_.penv});
|
{MakeValue(parallel_env_.task_id), parallel_env_.penv});
|
||||||
|
} else if (op->attr_key == ir::attr::pragma_import_llvm) {
|
||||||
|
const StringImm* value = op->value.as<StringImm>();
|
||||||
|
CHECK(value != nullptr);
|
||||||
|
this->HandleImport(value->value);
|
||||||
|
this->VisitStmt(op->body);
|
||||||
} else {
|
} else {
|
||||||
LOG(WARNING) << "Unknown pragma " << op->attr_key;
|
LOG(WARNING) << "Unknown pragma " << op->attr_key;
|
||||||
this->VisitStmt(op->body);
|
this->VisitStmt(op->body);
|
||||||
|
|
|
@ -159,6 +159,42 @@ std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
|
||||||
return std::move(module_);
|
return std::move(module_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void CodeGenLLVM::HandleImport(const std::string& code) {
|
||||||
|
std::unique_ptr<llvm::Module> mlib;
|
||||||
|
llvm::SMDiagnostic err;
|
||||||
|
if (code.length() >= 3 &&
|
||||||
|
(code.substr(code.length() - 3) == ".ll" ||
|
||||||
|
code.substr(code.length() - 3) == ".bc")) {
|
||||||
|
mlib = llvm::parseIRFile(code, err, *ctx_);
|
||||||
|
if (mlib.get() == nullptr) {
|
||||||
|
std::string msg = err.getMessage();
|
||||||
|
LOG(FATAL) << "Fail to load bitcode file " << code << "\n"
|
||||||
|
<< "line " << err.getLineNo() << ":" << msg;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
std::unique_ptr<llvm::MemoryBuffer> buf =
|
||||||
|
llvm::MemoryBuffer::getMemBuffer(code);
|
||||||
|
mlib = llvm::parseIR(*buf, err, *ctx_);
|
||||||
|
if (mlib.get() == nullptr) {
|
||||||
|
std::string msg = err.getMessage();
|
||||||
|
LOG(FATAL) << "Fail to load llvm ir "
|
||||||
|
<< "line " << err.getLineNo() << ":" << msg
|
||||||
|
<< "\ncontent:\n" << code;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mlib->setTargetTriple(target_machine_->getTargetTriple().str());
|
||||||
|
mlib->setDataLayout(target_machine_->createDataLayout());
|
||||||
|
// mark all the functions as force inline
|
||||||
|
for (llvm::Function &f : mlib->functions()) {
|
||||||
|
f.removeFnAttr(llvm::Attribute::NoInline);
|
||||||
|
f.addFnAttr(llvm::Attribute::AlwaysInline);
|
||||||
|
f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
|
||||||
|
}
|
||||||
|
// add to linker libraries.
|
||||||
|
this->AddLinkModule(std::move(mlib));
|
||||||
|
}
|
||||||
|
|
||||||
void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) {
|
void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) {
|
||||||
link_modules_.emplace_back(std::move(mod));
|
link_modules_.emplace_back(std::move(mod));
|
||||||
}
|
}
|
||||||
|
|
|
@ -178,6 +178,8 @@ class CodeGenLLVM :
|
||||||
// do a scalarize call with f
|
// do a scalarize call with f
|
||||||
llvm::Value* CreateScalarizedCall(
|
llvm::Value* CreateScalarizedCall(
|
||||||
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
|
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
|
||||||
|
// handle module import
|
||||||
|
void HandleImport(const std::string& code);
|
||||||
// cast operatpr
|
// cast operatpr
|
||||||
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
|
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
|
||||||
// comparison op
|
// comparison op
|
||||||
|
|
|
@ -34,6 +34,7 @@
|
||||||
#include <llvm/Transforms/IPO.h>
|
#include <llvm/Transforms/IPO.h>
|
||||||
|
|
||||||
#include <llvm/Support/FileSystem.h>
|
#include <llvm/Support/FileSystem.h>
|
||||||
|
#include <llvm/Support/MemoryBuffer.h>
|
||||||
#include <llvm/Support/raw_ostream.h>
|
#include <llvm/Support/raw_ostream.h>
|
||||||
#include <llvm/Support/Casting.h>
|
#include <llvm/Support/Casting.h>
|
||||||
#include <llvm/Support/TargetRegistry.h>
|
#include <llvm/Support/TargetRegistry.h>
|
||||||
|
|
|
@ -298,6 +298,13 @@ TVM_REGISTER_API("codegen.build_llvm")
|
||||||
*rv = runtime::Module(n);
|
*rv = runtime::Module(n);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API("codegen.llvm_version_major")
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue* rv) {
|
||||||
|
std::ostringstream os;
|
||||||
|
int major = TVM_LLVM_VERSION / 10;
|
||||||
|
*rv = major;
|
||||||
|
});
|
||||||
|
|
||||||
TVM_REGISTER_API("module.loadfile_ll")
|
TVM_REGISTER_API("module.loadfile_ll")
|
||||||
.set_body([](TVMArgs args, TVMRetValue* rv) {
|
.set_body([](TVMArgs args, TVMRetValue* rv) {
|
||||||
std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
|
std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import tvm
|
import tvm
|
||||||
|
from tvm.contrib import util, clang
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import ctypes
|
import ctypes
|
||||||
|
|
||||||
|
@ -17,6 +18,47 @@ def test_llvm_intrin():
|
||||||
func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
|
func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
|
||||||
fcode = tvm.build(func, None, "llvm")
|
fcode = tvm.build(func, None, "llvm")
|
||||||
|
|
||||||
|
|
||||||
|
def test_llvm_import():
|
||||||
|
# extern "C" is necessary to get the correct signature
|
||||||
|
cc_code = """
|
||||||
|
extern "C" float my_add(float x, float y) {
|
||||||
|
return x + y;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
n = 10
|
||||||
|
A = tvm.placeholder((n,), name='A')
|
||||||
|
B = tvm.compute((n,), lambda *i:
|
||||||
|
tvm.call_pure_extern("float32", "my_add", A(*i), 1.0),
|
||||||
|
name='B')
|
||||||
|
def check_llvm(use_file):
|
||||||
|
if not tvm.module.enabled("llvm"):
|
||||||
|
return
|
||||||
|
if not clang.find_clang(required=False):
|
||||||
|
print("skip because clang is not available")
|
||||||
|
return
|
||||||
|
temp = util.tempdir()
|
||||||
|
ll_path = temp.relpath("temp.ll")
|
||||||
|
ll_code = clang.create_llvm(cc_code, output=ll_path)
|
||||||
|
s = tvm.create_schedule(B.op)
|
||||||
|
if use_file:
|
||||||
|
s[B].pragma(s[B].op.axis[0], "import_llvm", ll_path)
|
||||||
|
else:
|
||||||
|
s[B].pragma(s[B].op.axis[0], "import_llvm", ll_code)
|
||||||
|
# BUILD and invoke the kernel.
|
||||||
|
f = tvm.build(s, [A, B], "llvm")
|
||||||
|
ctx = tvm.cpu(0)
|
||||||
|
# launch the kernel.
|
||||||
|
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
|
||||||
|
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
|
||||||
|
f(a, b)
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
b.asnumpy(), a.asnumpy() + 1.0)
|
||||||
|
check_llvm(use_file=True)
|
||||||
|
check_llvm(use_file=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_llvm_lookup_intrin():
|
def test_llvm_lookup_intrin():
|
||||||
ib = tvm.ir_builder.create()
|
ib = tvm.ir_builder.create()
|
||||||
m = tvm.var("m")
|
m = tvm.var("m")
|
||||||
|
@ -322,6 +364,7 @@ def test_alignment():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
test_llvm_import()
|
||||||
test_alignment()
|
test_alignment()
|
||||||
test_rank_zero()
|
test_rank_zero()
|
||||||
test_llvm_bool()
|
test_llvm_bool()
|
||||||
|
|
Загрузка…
Ссылка в новой задаче