[CODEGEN] Enable inline llvm asm code (#1486)

This commit is contained in:
Tianqi Chen 2018-07-25 09:30:23 -07:00 коммит произвёл GitHub
Родитель 55a08deca0
Коммит f7d05b7ce2
11 изменённых файлов: 243 добавлений и 0 удалений

Просмотреть файл

@ -8,6 +8,12 @@ tvm.contrib.cblas
:members: :members:
tvm.contrib.clang
~~~~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.clang
:members:
tvm.contrib.cc tvm.contrib.cc
~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.cc .. automodule:: tvm.contrib.cc

Просмотреть файл

@ -179,6 +179,8 @@ constexpr const char* loop_scope = "loop_scope";
constexpr const char* reduce_scope = "reduce_scope"; constexpr const char* reduce_scope = "reduce_scope";
/*! \brief Mark region is guarded by the pragma extension */ /*! \brief Mark region is guarded by the pragma extension */
constexpr const char* pragma_scope_prefix = "pragma_"; constexpr const char* pragma_scope_prefix = "pragma_";
/*! \brief Import llvm source or file into the final code gen module */
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
/*! /*!
* \brief Mark of prefetch scope, value=offset, * \brief Mark of prefetch scope, value=offset,
* run prefetch of Tensor on the current loop scope * run prefetch of Tensor on the current loop scope

Просмотреть файл

@ -0,0 +1,96 @@
"""Util to invoke clang in the system."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import subprocess
from .._ffi.base import py_str
from .. import codegen
from . import util
def find_clang(required=True):
"""Find clang in system.
Parameters
----------
required : bool
Whether it is required,
runtime error will be raised if the compiler is required.
Returns
-------
valid_list : list of str
List of possible paths.
Note
----
This function will first search clang that
matches the major llvm version that built with tvm
"""
cc_list = []
if hasattr(codegen, "llvm_version_major"):
cc_list += ["clang-%d.0" % codegen.llvm_version_major()]
cc_list += ["clang"]
valid_list = [util.which(x) for x in cc_list]
valid_list = [x for x in valid_list if x]
if not valid_list and required:
raise RuntimeError(
"cannot find clang, candidates are: " + str(cc_list))
return valid_list
def create_llvm(inputs,
output=None,
options=None,
cc=None):
"""Create llvm text ir.
Parameters
----------
inputs : list of str
List of input files name or code source.
output : str, optional
Output file, if it is none
a temporary file is created
options : list
The list of additional options string.
cc : str, optional
The clang compiler, if not specified,
we will try to guess the matched clang version.
Returns
-------
code : str
The generated llvm text IR.
"""
cc = cc if cc else find_clang()[0]
cmd = [cc]
cmd += ["-S", "-emit-llvm"]
temp = util.tempdir()
output = output if output else temp.relpath("output.ll")
inputs = [inputs] if isinstance(inputs, str) else inputs
input_files = []
for i, code in enumerate(inputs):
if util.is_source_path(code):
input_files.append(code)
else:
temp_path = temp.relpath("input%d.cc" % i)
with open(temp_path, "w") as output_file:
output_file.write(code)
input_files.append(temp_path)
if options:
cmd += options
cmd += ["-o", output]
cmd += input_files
proc = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "Compilation error:\n"
msg += py_str(out)
raise RuntimeError(msg)
return open(output).read()

Просмотреть файл

@ -86,6 +86,7 @@ class FileLock(object):
self.lock_file.close() self.lock_file.close()
self.lock_file = None self.lock_file = None
def filelock(path): def filelock(path):
"""Create a file lock which locks on path """Create a file lock which locks on path
@ -99,3 +100,45 @@ def filelock(path):
lock : File lock object lock : File lock object
""" """
return FileLock(path) return FileLock(path)
def is_source_path(path):
"""Check if path is source code path.
Parameters
----------
path : str
A possible path
Returns
-------
valid : bool
Whether path is a possible source path
"""
if os.path.exists(path):
return True
if path.find("\n") != -1:
return False
spath = path.rsplit(".", 1)
return len(spath) == 2 and spath[1].strip() == spath[1]
def which(exec_name):
"""Try to find full path of exec_name
Parameters
----------
exec_name : str
The executable name
Returns
-------
path : str
The full path of executable if found, otherwise returns None
"""
base_list = ["", "/bin"] + os.environ.get("PATH", "").split(os.pathsep)
for path in base_list:
full_path = os.path.join(path, exec_name)
if os.path.isfile(full_path) and os.access(full_path, os.X_OK):
return full_path
return None

Просмотреть файл

@ -603,6 +603,8 @@ class Stage(NodeBase):
:code:`for (int i = task_id; i < end; i += num_task)` :code:`for (int i = task_id; i < end; i += num_task)`
""" """
if isinstance(pragma_value, string_types):
pragma_value = convert(pragma_value)
_api_internal._StagePragma(self, var, pragma_type, pragma_value) _api_internal._StagePragma(self, var, pragma_type, pragma_value)
def prefetch(self, tensor, var, offset): def prefetch(self, tensor, var, offset):

Просмотреть файл

@ -701,6 +701,11 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
builder_->CreateCall( builder_->CreateCall(
RuntimeTVMParallelBarrier(), RuntimeTVMParallelBarrier(),
{MakeValue(parallel_env_.task_id), parallel_env_.penv}); {MakeValue(parallel_env_.task_id), parallel_env_.penv});
} else if (op->attr_key == ir::attr::pragma_import_llvm) {
const StringImm* value = op->value.as<StringImm>();
CHECK(value != nullptr);
this->HandleImport(value->value);
this->VisitStmt(op->body);
} else { } else {
LOG(WARNING) << "Unknown pragma " << op->attr_key; LOG(WARNING) << "Unknown pragma " << op->attr_key;
this->VisitStmt(op->body); this->VisitStmt(op->body);

Просмотреть файл

@ -159,6 +159,42 @@ std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
return std::move(module_); return std::move(module_);
} }
void CodeGenLLVM::HandleImport(const std::string& code) {
std::unique_ptr<llvm::Module> mlib;
llvm::SMDiagnostic err;
if (code.length() >= 3 &&
(code.substr(code.length() - 3) == ".ll" ||
code.substr(code.length() - 3) == ".bc")) {
mlib = llvm::parseIRFile(code, err, *ctx_);
if (mlib.get() == nullptr) {
std::string msg = err.getMessage();
LOG(FATAL) << "Fail to load bitcode file " << code << "\n"
<< "line " << err.getLineNo() << ":" << msg;
}
} else {
std::unique_ptr<llvm::MemoryBuffer> buf =
llvm::MemoryBuffer::getMemBuffer(code);
mlib = llvm::parseIR(*buf, err, *ctx_);
if (mlib.get() == nullptr) {
std::string msg = err.getMessage();
LOG(FATAL) << "Fail to load llvm ir "
<< "line " << err.getLineNo() << ":" << msg
<< "\ncontent:\n" << code;
}
}
mlib->setTargetTriple(target_machine_->getTargetTriple().str());
mlib->setDataLayout(target_machine_->createDataLayout());
// mark all the functions as force inline
for (llvm::Function &f : mlib->functions()) {
f.removeFnAttr(llvm::Attribute::NoInline);
f.addFnAttr(llvm::Attribute::AlwaysInline);
f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage);
}
// add to linker libraries.
this->AddLinkModule(std::move(mlib));
}
void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) { void CodeGenLLVM::AddLinkModule(std::unique_ptr<llvm::Module>&& mod) {
link_modules_.emplace_back(std::move(mod)); link_modules_.emplace_back(std::move(mod));
} }

Просмотреть файл

@ -178,6 +178,8 @@ class CodeGenLLVM :
// do a scalarize call with f // do a scalarize call with f
llvm::Value* CreateScalarizedCall( llvm::Value* CreateScalarizedCall(
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args); const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
// handle module import
void HandleImport(const std::string& code);
// cast operatpr // cast operatpr
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value); llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
// comparison op // comparison op

Просмотреть файл

@ -34,6 +34,7 @@
#include <llvm/Transforms/IPO.h> #include <llvm/Transforms/IPO.h>
#include <llvm/Support/FileSystem.h> #include <llvm/Support/FileSystem.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/raw_ostream.h> #include <llvm/Support/raw_ostream.h>
#include <llvm/Support/Casting.h> #include <llvm/Support/Casting.h>
#include <llvm/Support/TargetRegistry.h> #include <llvm/Support/TargetRegistry.h>

Просмотреть файл

@ -298,6 +298,13 @@ TVM_REGISTER_API("codegen.build_llvm")
*rv = runtime::Module(n); *rv = runtime::Module(n);
}); });
TVM_REGISTER_API("codegen.llvm_version_major")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::ostringstream os;
int major = TVM_LLVM_VERSION / 10;
*rv = major;
});
TVM_REGISTER_API("module.loadfile_ll") TVM_REGISTER_API("module.loadfile_ll")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>(); std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();

Просмотреть файл

@ -1,4 +1,5 @@
import tvm import tvm
from tvm.contrib import util, clang
import numpy as np import numpy as np
import ctypes import ctypes
@ -17,6 +18,47 @@ def test_llvm_intrin():
func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True) func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
fcode = tvm.build(func, None, "llvm") fcode = tvm.build(func, None, "llvm")
def test_llvm_import():
# extern "C" is necessary to get the correct signature
cc_code = """
extern "C" float my_add(float x, float y) {
return x + y;
}
"""
n = 10
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda *i:
tvm.call_pure_extern("float32", "my_add", A(*i), 1.0),
name='B')
def check_llvm(use_file):
if not tvm.module.enabled("llvm"):
return
if not clang.find_clang(required=False):
print("skip because clang is not available")
return
temp = util.tempdir()
ll_path = temp.relpath("temp.ll")
ll_code = clang.create_llvm(cc_code, output=ll_path)
s = tvm.create_schedule(B.op)
if use_file:
s[B].pragma(s[B].op.axis[0], "import_llvm", ll_path)
else:
s[B].pragma(s[B].op.axis[0], "import_llvm", ll_code)
# BUILD and invoke the kernel.
f = tvm.build(s, [A, B], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
f(a, b)
np.testing.assert_allclose(
b.asnumpy(), a.asnumpy() + 1.0)
check_llvm(use_file=True)
check_llvm(use_file=False)
def test_llvm_lookup_intrin(): def test_llvm_lookup_intrin():
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
m = tvm.var("m") m = tvm.var("m")
@ -322,6 +364,7 @@ def test_alignment():
if __name__ == "__main__": if __name__ == "__main__":
test_llvm_import()
test_alignment() test_alignment()
test_rank_zero() test_rank_zero()
test_llvm_bool() test_llvm_bool()