[CODEGEN] Enable inline llvm asm code (#1486)
This commit is contained in:
Родитель
55a08deca0
Коммит
f7d05b7ce2
|
@ -8,6 +8,12 @@ tvm.contrib.cblas
|
|||
:members:
|
||||
|
||||
|
||||
tvm.contrib.clang
|
||||
~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: tvm.contrib.clang
|
||||
:members:
|
||||
|
||||
|
||||
tvm.contrib.cc
|
||||
~~~~~~~~~~~~~~
|
||||
.. automodule:: tvm.contrib.cc
|
||||
|
|
|
@ -179,6 +179,8 @@ constexpr const char* loop_scope = "loop_scope";
|
|||
constexpr const char* reduce_scope = "reduce_scope";
|
||||
/*! \brief Mark that the region is guarded by the pragma extension */
|
||||
constexpr const char* pragma_scope_prefix = "pragma_";
|
||||
/*! \brief Import llvm source or file into the final code gen module */
|
||||
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
|
||||
/*!
|
||||
* \brief Mark of prefetch scope, value=offset,
|
||||
* run prefetch of Tensor on the current loop scope
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
"""Util to invoke clang in the system."""
|
||||
# pylint: disable=invalid-name
|
||||
from __future__ import absolute_import as _abs
|
||||
import subprocess
|
||||
|
||||
from .._ffi.base import py_str
|
||||
from .. import codegen
|
||||
from . import util
|
||||
|
||||
|
||||
def find_clang(required=True):
    """Find clang in system.

    Parameters
    ----------
    required : bool
        Whether it is required,
        runtime error will be raised if the compiler is required
        and none of the candidates can be found.

    Returns
    -------
    valid_list : list of str
        List of possible paths, in search order.
        Empty when nothing is found and ``required`` is False.

    Raises
    ------
    RuntimeError
        If ``required`` is true and no candidate executable is found.

    Note
    ----
    This function will first search for a clang that
    matches the major llvm version that tvm was built with,
    and fall back to the unversioned ``clang``.
    """
    cc_list = []
    if hasattr(codegen, "llvm_version_major"):
        major = codegen.llvm_version_major()
        # Versioned binaries are installed either as "clang-<major>.0"
        # or as plain "clang-<major>"; probe both spellings, the
        # historical ".0" form first.
        cc_list += ["clang-%d.0" % major, "clang-%d" % major]
    cc_list += ["clang"]
    valid_list = [util.which(x) for x in cc_list]
    # util.which returns None for candidates that cannot be located.
    valid_list = [x for x in valid_list if x]
    if not valid_list and required:
        raise RuntimeError(
            "cannot find clang, candidates are: " + str(cc_list))
    return valid_list
|
||||
|
||||
|
||||
def create_llvm(inputs,
                output=None,
                options=None,
                cc=None):
    """Create llvm text ir.

    Parameters
    ----------
    inputs : list of str
        List of input files name or code source.
        A single string is accepted and treated as a one-element list.

    output : str, optional
        Output file, if it is none
        a temporary file is created

    options : list
        The list of additional options string.

    cc : str, optional
        The clang compiler, if not specified,
        we will try to guess the matched clang version.

    Returns
    -------
    code : str
        The generated llvm text IR.

    Raises
    ------
    RuntimeError
        If clang exits with a non-zero status (the message carries the
        combined stdout/stderr of the compiler), or if ``cc`` is not
        given and ``find_clang`` cannot locate a compiler.
    """
    cc = cc if cc else find_clang()[0]
    cmd = [cc]
    cmd += ["-S", "-emit-llvm"]
    temp = util.tempdir()
    output = output if output else temp.relpath("output.ll")
    inputs = [inputs] if isinstance(inputs, str) else inputs
    input_files = []
    for i, code in enumerate(inputs):
        if util.is_source_path(code):
            input_files.append(code)
        else:
            # Not a path: dump the source text into a temporary file
            # so that it can be handed to clang on the command line.
            temp_path = temp.relpath("input%d.cc" % i)
            with open(temp_path, "w") as source_file:
                source_file.write(code)
            input_files.append(temp_path)
    if options:
        cmd += options
    cmd += ["-o", output]
    cmd += input_files
    # stderr is folded into stdout so one stream holds all diagnostics.
    proc = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    (out, _) = proc.communicate()
    if proc.returncode != 0:
        msg = "Compilation error:\n"
        msg += py_str(out)
        raise RuntimeError(msg)

    # Close the file deterministically instead of leaking the handle.
    with open(output) as ll_file:
        return ll_file.read()
|
|
@ -86,6 +86,7 @@ class FileLock(object):
|
|||
self.lock_file.close()
|
||||
self.lock_file = None
|
||||
|
||||
|
||||
def filelock(path):
|
||||
"""Create a file lock which locks on path
|
||||
|
||||
|
@ -99,3 +100,45 @@ def filelock(path):
|
|||
lock : File lock object
|
||||
"""
|
||||
return FileLock(path)
|
||||
|
||||
|
||||
def is_source_path(path):
    """Check if path is source code path.

    Parameters
    ----------
    path : str
        A possible path

    Returns
    -------
    valid : bool
        Whether path is a possible source path
    """
    # Anything that already exists on disk counts as a path.
    if os.path.exists(path):
        return True
    # Text containing a newline is never treated as a path.
    if "\n" in path:
        return False
    # Otherwise require a "." whose trailing suffix carries no
    # leading or trailing whitespace.
    _, dot, suffix = path.rpartition(".")
    return dot == "." and suffix == suffix.strip()
|
||||
|
||||
|
||||
def which(exec_name):
    """Try to find full path of exec_name

    Parameters
    ----------
    exec_name : str
        The executable name

    Returns
    -------
    path : str
        The full path of executable if found, otherwise returns None
    """
    # Search order: the name as given, then "/bin", then every PATH entry.
    search_dirs = ["", "/bin"]
    search_dirs.extend(os.environ.get("PATH", "").split(os.pathsep))
    candidates = (os.path.join(folder, exec_name) for folder in search_dirs)
    # The first regular file with the executable bit wins.
    return next(
        (candidate for candidate in candidates
         if os.path.isfile(candidate) and os.access(candidate, os.X_OK)),
        None)
|
||||
|
|
|
@ -603,6 +603,8 @@ class Stage(NodeBase):
|
|||
:code:`for (int i = task_id; i < end; i += num_task)`
|
||||
|
||||
"""
|
||||
if isinstance(pragma_value, string_types):
|
||||
pragma_value = convert(pragma_value)
|
||||
_api_internal._StagePragma(self, var, pragma_type, pragma_value)
|
||||
|
||||
def prefetch(self, tensor, var, offset):
|
||||
|
|
|
@ -701,6 +701,11 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
|
|||
builder_->CreateCall(
|
||||
RuntimeTVMParallelBarrier(),
|
||||
{MakeValue(parallel_env_.task_id), parallel_env_.penv});
|
||||
} else if (op->attr_key == ir::attr::pragma_import_llvm) {
|
||||
const StringImm* value = op->value.as<StringImm>();
|
||||
CHECK(value != nullptr);
|
||||
this->HandleImport(value->value);
|
||||
this->VisitStmt(op->body);
|
||||
} else {
|
||||
LOG(WARNING) << "Unknown pragma " << op->attr_key;
|
||||
this->VisitStmt(op->body);
|
||||
|
|
|
@ -159,6 +159,42 @@ std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
|
|||
return std::move(module_);
|
||||
}
|
||||
|
||||
|
||||
void CodeGenLLVM::HandleImport(const std::string& code) {
|
||||
std::unique_ptr<llvm::Module> mlib;
|
||||
llvm::SMDiagnostic err;
|
||||
if (code.length() >= 3 &&
|
||||
(code.substr(code.length() - 3) == ".ll" ||
|
||||
code.substr(code.length() - 3) == ".bc")) {
|
||||
mlib = llvm::parseIRFile(code, err, *ctx_);
|
||||
if (mlib.get() == nullptr) {
|
||||
std::string msg = err.getMessage();
|
||||
LOG(FATAL) << "Fail to load bitcode file " << code << "\n"
|
||||
<< "line " << err.getLineNo() << ":" << msg;
|
||||
}
|
||||
} else {
|
||||
std::unique_ptr<llvm::MemoryBuffer> buf =
|
||||
llvm::MemoryBuffer::getMemBuffer(code);
|
||||
mlib = llvm::parseIR(*buf, err, *ctx_);
|
||||
if (mlib.get() == nullptr) {
|
||||
std::string msg = err.getMessage();
|
||||
LOG(FATAL) << "Fail to load llvm ir "
|
||||
<< "line " << err.getLineNo() << ":" << msg
|
||||
<< "\ncontent:\n" << code;
|
||||
}
|
||||
}
|
||||
mlib->setTargetTriple(target_machine_->getTargetTriple().str());
|
||||
mlib->setDataLayout(target_machine_->createDataLayout());
|
||||
// mark all the functions as force inline
|
||||
for (llvm::Function &f : mlib->functions()) {
|
||||
f.removeFnAttr(llvm::Attribute::NoInline);
|
||||
f.addFnAttr(llvm::Attribute::AlwaysInline);
|
||||
f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
|
||||
}
|
||||
// add to linker libraries.
|
||||
this->AddLinkModule(std::move(mlib));
|
||||
}
|
||||
|
||||
// Take ownership of mod and append it to link_modules_.
void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) {
  link_modules_.emplace_back(std::move(mod));
}
|
||||
|
|
|
@ -178,6 +178,8 @@ class CodeGenLLVM :
|
|||
// do a scalarize call with f
|
||||
llvm::Value* CreateScalarizedCall(
|
||||
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
|
||||
// handle module import
|
||||
void HandleImport(const std::string& code);
|
||||
// cast operator
|
||||
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
|
||||
// comparison op
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
#include <llvm/Transforms/IPO.h>
|
||||
|
||||
#include <llvm/Support/FileSystem.h>
|
||||
#include <llvm/Support/MemoryBuffer.h>
|
||||
#include <llvm/Support/raw_ostream.h>
|
||||
#include <llvm/Support/Casting.h>
|
||||
#include <llvm/Support/TargetRegistry.h>
|
||||
|
|
|
@ -298,6 +298,13 @@ TVM_REGISTER_API("codegen.build_llvm")
|
|||
*rv = runtime::Module(n);
|
||||
});
|
||||
|
||||
// Registers "codegen.llvm_version_major", which returns an integer
// derived from TVM_LLVM_VERSION.
TVM_REGISTER_API("codegen.llvm_version_major")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    // NOTE(review): presumably TVM_LLVM_VERSION encodes major * 10 + minor,
    // so integer division by 10 yields the major version -- confirm.
    int major = TVM_LLVM_VERSION / 10;
    *rv = major;
  });
|
||||
|
||||
TVM_REGISTER_API("module.loadfile_ll")
|
||||
.set_body([](TVMArgs args, TVMRetValue* rv) {
|
||||
std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import tvm
|
||||
from tvm.contrib import util, clang
|
||||
import numpy as np
|
||||
import ctypes
|
||||
|
||||
|
@ -17,6 +18,47 @@ def test_llvm_intrin():
|
|||
func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
|
||||
fcode = tvm.build(func, None, "llvm")
|
||||
|
||||
|
||||
def test_llvm_import():
    """Build a kernel calling an extern function imported via "import_llvm".

    The imported LLVM IR is produced from C++ source by clang and is
    supplied to the pragma both as a file path and as IR text.
    """
    # extern "C" is necessary to get the correct signature
    cc_code = """
extern "C" float my_add(float x, float y) {
return x + y;
}
"""
    n = 10
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(
        (n,),
        lambda *i: tvm.call_pure_extern("float32", "my_add", A(*i), 1.0),
        name='B')

    def check_llvm(use_file):
        # Nothing to check without the llvm backend.
        if not tvm.module.enabled("llvm"):
            return
        if not clang.find_clang(required=False):
            print("skip because clang is not available")
            return
        temp = util.tempdir()
        ll_path = temp.relpath("temp.ll")
        ll_code = clang.create_llvm(cc_code, output=ll_path)
        s = tvm.create_schedule(B.op)
        # Hand the pragma either the path of the .ll file or its text.
        imported = ll_path if use_file else ll_code
        s[B].pragma(s[B].op.axis[0], "import_llvm", imported)
        # BUILD and invoke the kernel.
        f = tvm.build(s, [A, B], "llvm")
        ctx = tvm.cpu(0)
        # launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
        f(a, b)
        np.testing.assert_allclose(
            b.asnumpy(), a.asnumpy() + 1.0)

    for use_file in (True, False):
        check_llvm(use_file=use_file)
|
||||
|
||||
|
||||
|
||||
def test_llvm_lookup_intrin():
|
||||
ib = tvm.ir_builder.create()
|
||||
m = tvm.var("m")
|
||||
|
@ -322,6 +364,7 @@ def test_alignment():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_llvm_import()
|
||||
test_alignment()
|
||||
test_rank_zero()
|
||||
test_llvm_bool()
|
||||
|
|
Загрузка…
Ссылка в новой задаче