Add shuffle support to TVM (#3633)
This commit is contained in:
Parent: 9ae01e0b07
Commit: a279dd0e58
@@ -199,6 +199,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
     IR_EXPR_FUNCTOR_DISPATCH(Not);
     IR_EXPR_FUNCTOR_DISPATCH(Select);
     IR_EXPR_FUNCTOR_DISPATCH(Ramp);
+    IR_EXPR_FUNCTOR_DISPATCH(Shuffle);
     IR_EXPR_FUNCTOR_DISPATCH(Broadcast);
     IR_EXPR_FUNCTOR_DISPATCH(IntImm);
     IR_EXPR_FUNCTOR_DISPATCH(UIntImm);
@@ -131,6 +131,7 @@ class TVM_DLL IRVisitor {
   virtual void Visit_(const Not* op);
   virtual void Visit_(const Select* op);
   virtual void Visit_(const Ramp* op);
+  virtual void Visit_(const Shuffle* op);
   virtual void Visit_(const Broadcast* op);
   virtual void Visit_(const AssertStmt* op);
   virtual void Visit_(const ProducerConsumer* op);
@@ -26,6 +26,7 @@
 #define TVM_CODEGEN_BUILD_COMMON_H_
 
 #include <tvm/codegen.h>
+#include <tvm/ir.h>
 #include <unordered_map>
 #include <string>
 #include "../runtime/meta_data.h"
@@ -728,6 +728,10 @@ void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) {  // NOLINT(*)
   os << "))";
 }
 
+void CodeGenC::VisitExpr_(const Shuffle* op, std::ostream& os) {
+  LOG(FATAL) << "Shuffle: not supported ";
+}
+
 void CodeGenC::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
   LOG(FATAL) << "Broadcast: not supported ";
 }
@@ -126,6 +126,7 @@ class CodeGenC :
   void VisitExpr_(const Not* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Select* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Ramp* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const Shuffle* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Broadcast* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const IntImm* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const UIntImm* op, std::ostream& os) override;  // NOLINT(*)
@@ -205,7 +205,7 @@ void CodeGenCUDA::PrintVecBinaryOp(
 
 void CodeGenCUDA::PrintVecElemLoad(
     const std::string& vec, Type t, int i, std::ostream& os) {  // NOLINT(*)
-  const char access[] = {'x', 'y', 'z', 'w'};
+  static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < 4);
   os << vec << "." << access[i];
 }
@@ -213,7 +213,7 @@ void CodeGenCUDA::PrintVecElemLoad(
 void CodeGenCUDA::PrintVecElemStore(
     const std::string& vec, Type t, int i, const std::string& value) {
   this->PrintIndent();
-  const char access[] = {'x', 'y', 'z', 'w'};
+  static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < 4);
   stream << vec << "." << access[i] << " = " << value << ";\n";
 }
@@ -308,7 +308,7 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLIN
   std::string v = PrintExpr(op->value);
   os << "make_";
   PrintType(op->type, os);
-  os << "(";
+  os << '(';
   for (int i = 0; i < op->lanes; ++i) {
     if (i != 0) os << ", ";
     os << v;
@@ -316,6 +316,23 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLIN
   os << ')';
 }
 
+void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) {
+  std::vector<std::string> to_shuffle(op->vectors.size());
+  for (int i = 0, e = op->vectors.size(); i < e; ++i) {
+    CHECK(op->vectors[i].type().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
+    to_shuffle[i] = PrintExpr(op->vectors[i]);
+  }
+  os << "make_";
+  PrintType(op->type, os);
+  os << '(';
+  for (int i = 0, e = op->indices.size(); i < e; ++i) {
+    const int64_t *val = as_const_int(op->indices[i]);
+    CHECK(val && *val >= 0 && (int) *val < (int) to_shuffle.size());
+    if (i != 0) os << ", ";
+    os << to_shuffle[*val];
+  }
+  os << ')';
+}
 
 inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) {  // NOLINT(*)
   switch (op->type.bits()) {
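
Note: the CUDA visitor above only accepts shuffles whose operands are scalar expressions and whose indices are constant integers; it prints a make_<type>(...) constructor that picks the operands in index order. A minimal sketch of the kind of node it handles, built with the same tvm.make.Shuffle API the new test uses (names and values here are illustrative, not part of the commit):

import tvm

# Four scalar operands and the constant index vector [3, 2, 1, 0]; for a result
# type of int32x4, CodeGenCUDA would print roughly make_int4(x3, x2, x1, x0).
xs = [tvm.const(i * 10, 'int32') for i in range(4)]
rev_idx = [tvm.const(3 - i, 'int32') for i in range(4)]
shuffled = tvm.make.Shuffle(xs, rev_idx)
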
@@ -57,6 +57,7 @@ class CodeGenCUDA final : public CodeGenC {
   void BindThreadIndex(const IterVar& iv) final;  // NOLINT(*)
   // overload visitor
   void VisitExpr_(const Ramp* op, std::ostream& os) final;  // NOLINT(*)
+  void VisitExpr_(const Shuffle* op, std::ostream& os) final;  // NOLINT(*)
   void VisitExpr_(const Broadcast* op, std::ostream& os) final;  // NOLINT(*)
   void VisitExpr_(const FloatImm *op, std::ostream& os) final;
   void VisitStmt_(const Evaluate *op) final;
@@ -30,6 +30,7 @@
 
 #include "codegen_llvm.h"
 #include "codegen_cpu.h"
 #include "../build_common.h"
 #include "../../pass/ir_util.h"
+#include "../../arithmetic/compute_expr.h"
 
@@ -446,6 +447,7 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
 llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
   int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
   if (extent == num_elems && begin == 0) return vec;
+  CHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n";
   std::vector<llvm::Constant*> indices;
   indices.reserve(extent);
   for (int i = 0; i < extent; ++i) {
@@ -481,6 +483,7 @@ llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
 llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
+  // concat vector, tree shape reduction
   int total_lanes = 0;
 
   for (llvm::Value* v : vecs) {
     total_lanes += static_cast<int>(
         v->getType()->getVectorNumElements());
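
The comment added above refers to concatenating the operand vectors pairwise in a tree rather than folding them one by one, which keeps the depth of the generated shuffle chain logarithmic in the number of inputs. A language-neutral sketch of that reduction shape (illustrative only; the actual work operates on llvm::Value* in C++):

def concat_tree(vecs, concat_pair):
    # Repeatedly concatenate adjacent pairs until one vector remains; an odd
    # leftover element is carried unchanged into the next round.
    while len(vecs) > 1:
        nxt = [concat_pair(vecs[i], vecs[i + 1]) for i in range(0, len(vecs) - 1, 2)]
        if len(vecs) % 2 == 1:
            nxt.append(vecs[-1])
        vecs = nxt
    return vecs[0]
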
@@ -652,12 +655,14 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
     CHECK_GE(op->args.size(), 2U);
     llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
         op->args[0].as<UIntImm>()->value);
-    uint64_t num_signature = op->args[1].as<UIntImm>()->value;
+    const uint64_t *num_signature = as_const_uint(op->args[1]);
+    CHECK(num_signature) << "The second argument should be a uint represents number of arguments, "
+                         << "but " << op->args[1] << " got!\n";
     std::vector<llvm::Value*> arg_value;
     std::vector<llvm::Type*> sig_type;
     for (size_t i = 2; i < op->args.size(); ++i) {
       arg_value.push_back(MakeValue(op->args[i]));
-      if (i - 2 < num_signature) {
+      if (i - 2 < *num_signature) {
         sig_type.push_back(arg_value.back()->getType());
       }
     }
@@ -1002,6 +1007,26 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
   return vec;
 }
 
+llvm::Value* CodeGenLLVM::VisitExpr_(const Shuffle* op) {
+  std::vector<llvm::Value *> vecs(op->vectors.size());
+  int total_lanes = 0;
+  for (int i = 0, e = op->vectors.size(); i < e; ++i) {
+    vecs[i] = VisitExpr(op->vectors[i]);
+    total_lanes += op->vectors[i].type().lanes();
+  }
+  llvm::Value* v0 = CreateVecConcat(vecs);
+  std::vector<uint32_t> idx(op->indices.size());
+  for (int i = 0, e = op->indices.size(); i < e; ++i) {
+    const int64_t *val = as_const_int(op->indices[i]);
+    CHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, "
+                                                  << "but get " << op->indices[i] << "\n";
+    idx[i] = *val;
+  }
+  llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx);
+  auto res = builder_->CreateShuffleVector(v0, llvm::UndefValue::get(v0->getType()), mask);
+  return res;
+}
+
 llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
   return CreateBroadcast(MakeValue(op->value), op->lanes);
 }
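
The LLVM path first concatenates all operand vectors with CreateVecConcat and then emits a single shufflevector whose constant mask comes from the compile-time-constant indices. A small sketch of a node this visitor lowers, mirroring the new test further down (assumes a TVM build that includes this commit):

import tvm

# A Ramp <0, 1, ..., 7> shuffled by the constant indices 7..0 reverses the lanes;
# CodeGenLLVM turns this into one shufflevector over the concatenated input vector.
idx = tvm.make.Ramp(tvm.const(0, 'int32'), tvm.const(1, 'int32'), 8)
rev = tvm.make.Shuffle([idx], [tvm.const(i, 'int32') for i in range(7, -1, -1)])
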
@@ -131,6 +131,7 @@ class CodeGenLLVM :
   llvm::Value* VisitExpr_(const Load* op) override;
   llvm::Value* VisitExpr_(const Call* op) override;
   llvm::Value* VisitExpr_(const Ramp* op) override;
+  llvm::Value* VisitExpr_(const Shuffle* op) override;
   llvm::Value* VisitExpr_(const Broadcast* op) override;
   // stmt
   void VisitStmt_(const Store* op) override;
@@ -177,6 +177,13 @@ void IRVisitor::Visit_(const Ramp *op) {
   this->Visit(op->stride);
 }
 
+void IRVisitor::Visit_(const Shuffle *op) {
+  for (const auto &elem : op->indices)
+    this->Visit(elem);
+  for (const auto &elem : op->vectors)
+    this->Visit(elem);
+}
+
 void IRVisitor::Visit_(const Broadcast *op) {
   this->Visit(op->value);
 }
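
Because the visitor recurses into both the indices and the vectors, generic IR traversals now reach everything inside a Shuffle. A hedged sketch of checking that from Python with tvm.ir_pass.PostOrderVisit; it assumes the node is also exposed as tvm.expr.Shuffle, which is why the lookup is guarded:

import tvm

def count_shuffles(stmt):
    # PostOrderVisit drives the C++ visitor from Python; with the new Visit_
    # overload, Shuffle nodes and their operands are visited like any other expr.
    shuffle_cls = getattr(tvm.expr, 'Shuffle', None)
    counter = [0]
    def fvisit(node):
        if shuffle_cls is not None and isinstance(node, shuffle_cls):
            counter[0] += 1
    tvm.ir_pass.PostOrderVisit(stmt, fvisit)
    return counter[0]
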
@@ -269,6 +276,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
 .DISPATCH_TO_VISIT(Not)
 .DISPATCH_TO_VISIT(Select)
 .DISPATCH_TO_VISIT(Ramp)
+.DISPATCH_TO_VISIT(Shuffle)
 .DISPATCH_TO_VISIT(Broadcast)
 .DISPATCH_TO_VISIT(AssertStmt)
 .DISPATCH_TO_VISIT(ProducerConsumer)
@@ -154,9 +154,53 @@ def test_cuda_inf_nan():
     check_inf_nan(ctx, 1, float('nan'), 'float64')
 
 
+def test_cuda_shuffle():
+    if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+
+    a = tvm.placeholder((64, ), 'int32')
+    b = tvm.placeholder((64, ), 'int32')
+    c = tvm.compute((64, ), lambda x: a[x] + b[x - (x % 4) + (3 - x % 4)])
+    sch = tvm.create_schedule(c.op)
+    x = c.op.axis[0]
+    xo, xi = sch[c].split(x, 4)
+    thrx = tvm.thread_axis("threadIdx.x")
+    sch[c].bind(xo, thrx)
+    sch[c].vectorize(xi)
+
+    def my_vectorize(stmt):
+        def vectorizer(op):
+            if op.for_type == tvm.stmt.For.Vectorized:
+                four = tvm.const(4, 'int32')
+                idx = tvm.make.Ramp(thrx.var * four, tvm.const(1, 'int32'), 4)
+                all_ones = tvm.const(1, 'int32x4')
+                store = op.body
+                value = store.value
+                new_a = tvm.make.Load('int32x4', value.a.buffer_var, idx, all_ones)
+                bs, ids = [], []
+                for i in range(4):
+                    bs.append(tvm.make.Load('int32', value.b.buffer_var, thrx.var * four + tvm.const(i, 'int32')))
+                    ids.append(tvm.const(3 - i, 'int32'))
+                new_b = tvm.make.Shuffle(bs, ids)
+                return tvm.make.Store(store.buffer_var, new_a + new_b, idx, all_ones)
+            return None
+        return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
+
+    with tvm.build_config(add_lower_pass=[(1, my_vectorize)]):
+        module = tvm.build(sch, [a, b, c], target='cuda')
+        a_ = np.array(list(range(64)), dtype='int32')
+        b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
+        c_ = np.zeros((64, ), dtype='int32')
+        ref = a_ + np.array((list(range(4))) * 16, dtype='int32')
+        nda, ndb, ndc = [tvm.ndarray.array(i, tvm.gpu(0)) for i in [a_, b_, c_]]
+        module(nda, ndb, ndc)
+        tvm.testing.assert_allclose(ndc.asnumpy(), ref)
+
+
 if __name__ == "__main__":
     test_cuda_vectorize_add()
     test_cuda_multiply_add()
     test_cuda_vectorize_load()
     test_cuda_make_int8x4()
     test_cuda_inf_nan()
+    test_cuda_shuffle()
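
To see the effect of the new CUDA visitor, the generated kernel source can be dumped from the built module; the shuffled b operand should appear as a make_int4(...) constructor. Sketch only, intended to be called on the module built inside test_cuda_shuffle:

def dump_cuda_source(module):
    # The device kernel is the imported module of the host module built for target='cuda'.
    print(module.imported_modules[0].get_source())
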
@@ -548,6 +548,37 @@ def test_dwarf_debug_information():
     check_llvm_object()
     check_llvm_ir()
 
 
+def test_llvm_shuffle():
+    a = tvm.placeholder((8, ), 'int32')
+    b = tvm.placeholder((8, ), 'int32')
+    c = tvm.compute((8, ), lambda x: a[x] + b[7-x])
+    sch = tvm.create_schedule(c.op)
+
+    def my_vectorize(stmt):
+
+        def vectorizer(op):
+            store = op.body
+            idx = tvm.make.Ramp(tvm.const(0, 'int32'), tvm.const(1, 'int32'), 8)
+            all_ones = tvm.const(1, 'int32x8')
+            value = store.value
+            b_idx = tvm.make.Shuffle([idx], [tvm.const(i, 'int32') for i in range(7, -1, -1)])
+            new_a = tvm.make.Load('int32x8', value.a.buffer_var, idx, all_ones)
+            new_b = tvm.make.Load('int32x8', value.b.buffer_var, b_idx, all_ones)
+            value = new_a + new_b
+            return tvm.make.Store(store.buffer_var, new_a + new_b, idx, all_ones)
+
+        return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
+
+    with tvm.build_config(add_lower_pass=[(1, my_vectorize)]):
+        ir = tvm.lower(sch, [a, b, c], simple_mode=True)
+        module = tvm.build(sch, [a, b, c])
+        a_ = tvm.ndarray.array(np.arange(1, 9, dtype='int32'))
+        b_ = tvm.ndarray.array(np.arange(8, 0, -1, dtype='int32'))
+        c_ = tvm.ndarray.array(np.zeros((8, ), dtype='int32'))
+        module(a_, b_, c_)
+        tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
+
+
 if __name__ == "__main__":
     test_llvm_import()
     test_alignment()
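
Similarly, the LLVM-side result can be inspected by printing the built module's IR; the reversal of b should show up as a single shufflevector with the constant mask <7, 6, ..., 0>. Sketch only, to be called on the module built inside test_llvm_shuffle:

def dump_llvm_ir(module):
    # For the default llvm target, get_source() returns the textual LLVM IR.
    print(module.get_source())
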
@@ -567,3 +598,4 @@ if __name__ == "__main__":
     test_llvm_div()
     test_llvm_fp_math()
     test_dwarf_debug_information()
+    test_llvm_shuffle()