[IR] Update new version of HalideIR (#116)

Parent: d3c8256b5a
Commit: 330d49f81c

Changed: HalideIR (submodule)
@@ -1 +1 @@
-Subproject commit 398edacd956c6de82185821ffd9f482598182e51
+Subproject commit 4fffc62c124651c1cde18f31957db413b677d601
@@ -171,6 +171,14 @@ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";

 /*! \brief namespace of TVM Intrinsic functions */
 namespace intrinsic {
+/*!
+ * \brief See pesudo code
+ *
+ *  Handle tvm_address_of(Load *op) {
+ *     return &op->buffer_var[index];
+ *  }
+ */
+constexpr const char* tvm_address_of = "tvm_address_of";
 /*!
  * \brief See pesudo code
  *
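The new intrinsic replaces Call::address_of and always wraps a Load that now carries a predicate. A minimal sketch of the intended construction, mirroring the AddressOffset helper updated later in this commit (the helper name is illustrative; it assumes <tvm/ir.h> and the usual tvm / tvm::ir using-declarations):

// Sketch only: build &buffer_var[index] as a tvm_address_of call over a predicated Load.
Expr MakeAddressOf(Var buffer_var, Type dtype, Expr index) {
  // const_true(lanes) reproduces the old, unpredicated behaviour.
  Expr load = Load::make(dtype, buffer_var, index, const_true(dtype.lanes()));
  return Call::make(Handle(), intrinsic::tvm_address_of,
                    {load}, Call::PureIntrinsic);
}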
@@ -355,6 +363,7 @@ using Halide::Internal::Realize;
 using Halide::Internal::Block;
 using Halide::Internal::IfThenElse;
 using Halide::Internal::Evaluate;
+using Halide::Internal::Shuffle;
 // ir functions
 using Halide::Internal::is_const_power_of_two_integer;
@@ -98,6 +98,7 @@ class IRMutator {
   virtual Expr Mutate_(const UIntImm* op, const Expr& e);
   virtual Expr Mutate_(const FloatImm* op, const Expr& e);
   virtual Expr Mutate_(const StringImm* op, const Expr& e);
+  virtual Expr Mutate_(const Shuffle* op, const Expr& e);
 };

 } // namespace ir
@@ -10,7 +10,7 @@
 #define TVM_IR_PASS_H_

 #include <ir/IREquality.h>
-#include <pass/Simplify.h>
+#include <arithmetic/Simplify.h>
 #include <tvm/ir_functor.h>
 #include <unordered_map>
 #include <vector>
@@ -26,6 +26,26 @@ TVM_REGISTER_API("make.For")
                      args[5]);
   });

+TVM_REGISTER_API("make.Load")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    Type t = args[0];
+    if (args.size() == 3) {
+      *ret = Load::make(t, args[1], args[2], const_true(t.lanes()));
+    } else {
+      *ret = Load::make(t, args[1], args[2], args[3]);
+    }
+  });
+
+TVM_REGISTER_API("make.Store")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    Expr value = args[1];
+    if (args.size() == 3) {
+      *ret = Store::make(args[0], value, args[2], const_true(value.type().lanes()));
+    } else {
+      *ret = Store::make(args[0], value, args[2], args[3]);
+    }
+  });
+
 TVM_REGISTER_API("make.Realize")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = Realize::make(args[0],
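These registrations keep the frontend compatible: with three arguments the predicate defaults to an all-true value of the matching lane count. The same convention in plain C++ (helper names are illustrative, not part of the diff):

// Sketch only: the default predicates make.Load / make.Store fall back to.
Expr UnpredicatedLoad(Type t, Var buffer_var, Expr index) {
  return Load::make(t, buffer_var, index, const_true(t.lanes()));
}
Stmt UnpredicatedStore(Var buffer_var, Expr value, Expr index) {
  return Store::make(buffer_var, value, index, const_true(value.type().lanes()));
}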
@@ -47,15 +67,6 @@ TVM_REGISTER_API("make.Call")
                    args[5]);
   });

-TVM_REGISTER_API("make.Allocate")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = Allocate::make(args[0],
-                          args[1],
-                          args[2],
-                          args[3],
-                          args[4]);
-  });
-
 TVM_REGISTER_API("make.CommReducer")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = CommReducerNode::make(args[0], args[1], args[2]);
@@ -87,6 +98,12 @@ TVM_REGISTER_API("make.CommReducer")
      *ret = Node::make(args[0], args[1], args[2], args[3]); \
    }) \

+#define REGISTER_MAKE5(Node) \
+  TVM_REGISTER_API("make."#Node) \
+  .set_body([](TVMArgs args, TVMRetValue *ret) { \
+      *ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \
+    }) \
+
 #define REGISTER_MAKE_BINARY_OP(Node) \
   TVM_REGISTER_API("make."#Node) \
   .set_body([](TVMArgs args, TVMRetValue *ret) { \
@@ -125,8 +142,7 @@ REGISTER_MAKE3(Let);
 REGISTER_MAKE3(LetStmt);
 REGISTER_MAKE2(AssertStmt);
 REGISTER_MAKE3(ProducerConsumer);
-REGISTER_MAKE3(Load);
-REGISTER_MAKE3(Store);
+REGISTER_MAKE5(Allocate);
 REGISTER_MAKE4(Provide);
 REGISTER_MAKE1(Free);
 REGISTER_MAKE2(Block);
@@ -8,7 +8,7 @@
 #define TVM_ARITHMETIC_COMPUTE_EXPR_H_

 #include <tvm/ir.h>
-#include <pass/Interval.h>
+#include <arithmetic/Interval.h>
 #include <limits>

 namespace tvm {
@@ -6,7 +6,7 @@
 #include <tvm/ir.h>
 #include <tvm/ir_pass.h>
 #include <tvm/arithmetic.h>
-#include <pass/Interval.h>
+#include <arithmetic/Interval.h>
 #include <unordered_map>
 #include "./compute_expr.h"
 #include "./int_set_internal.h"
@@ -471,7 +471,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
     PrintBinaryIntrinsitc(op, " << ", os, this);
   } else if (op->is_intrinsic(Call::shift_right)) {
     PrintBinaryIntrinsitc(op, " >> ", os, this);
-  } else if (op->is_intrinsic(Call::address_of)) {
+  } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
     const Load *l = op->args[0].as<Load>();
     CHECK(op->args.size() == 1 && l);
     os << "((";
@@ -535,6 +535,8 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
     std::string ref = GetBufferRef(op->type, op->buffer_var.get(), op->index);
     os << ref;
   } else {
+    CHECK(is_one(op->predicate))
+        << "predicated load is not supported";
     Expr base;
     if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
       std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base);
@@ -575,6 +577,8 @@ void CodeGenC::VisitStmt_(const Store* op) {
     this->PrintIndent();
     stream << ref << " = " << value << ";\n";
   } else {
+    CHECK(is_one(op->predicate))
+        << "Predicated store is not supported";
     Expr base;
     if (TryGetRamp1Base(op->index, t.lanes(), &base)) {
       std::string value = this->PrintExpr(op->value);
@@ -702,7 +702,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
       return builder_->CreateLShr(
           MakeValue(op->args[0]), MakeValue(op->args[1]));
     }
-  } else if (op->is_intrinsic(Call::address_of)) {
+  } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
     const Load *l = op->args[0].as<Load>();
     CHECK(op->args.size() == 1 && l);
     return CreateBufferPtr(
@@ -752,7 +752,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
     } else {
       LOG(FATAL) << "Unknown stack alloca type " << type;
     }
-  } else if (op->is_intrinsic(Call::null_handle)) {
+  } else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
     return llvm::Constant::getNullValue(t_void_p_);
   } else {
     LOG(FATAL) << "Unknown intrinstic " << op->name;
@@ -1077,6 +1077,8 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(
 }

 llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
+  CHECK(is_one(op->predicate))
+      << "Predicated Load is not supported";
   Type t = op->type;
   const Ramp* ramp = op->index.as<Ramp>();
   llvm::Value* buf = GetVarValue(op->buffer_var.get());
@@ -1135,12 +1137,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
           t, op->buffer_var,
           Ramp::make(arith::ComputeExpr<Add>(
               ramp->base, make_const(bt, first_shift)),
-                     make_const(bt, 1), ramp->lanes)));
+                     make_const(bt, 1), ramp->lanes),
+          const_true(t.lanes())));
       llvm::Value* next = MakeValue(Load::make(
           t, op->buffer_var,
           Ramp::make(arith::ComputeExpr<Add>(
               ramp->base, make_const(bt, ramp->lanes + next_shift)),
-                     make_const(bt, 1), ramp->lanes)));
+                     make_const(bt, 1), ramp->lanes),
+          const_true(t.lanes())));
       // shuffle
       std::vector<llvm::Constant*> indices;
       int target_index = 0;
@@ -1170,7 +1174,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
         make_const(ramp->base.type(), 1),
         lanes);
     // load value then flip
-    llvm::Value* v = MakeValue(Load::make(t, op->buffer_var, neg_ramp));
+    llvm::Value* v = MakeValue(
+        Load::make(t, op->buffer_var, neg_ramp, const_true(t.lanes())));
     return CreateVecFlip(v);
   } else {
     llvm::Value* ret = llvm::UndefValue::get(LLVMType(t));
@@ -1187,6 +1192,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {

 // stmts
 void CodeGenLLVM::VisitStmt_(const Store* op) {
+  CHECK(is_one(op->predicate))
+      << "Predicated Load is not supported";
   llvm::Value* value = MakeValue(op->value);
   Type t = op->value.type();
   const Ramp* ramp = op->index.as<Ramp>();
@@ -121,7 +121,7 @@ void CodeGenStackVM::VisitStmt_(const Allocate* op) {
 }

 void CodeGenStackVM::VisitExpr_(const Call* op) {
-  if (op->is_intrinsic(Call::address_of)) {
+  if (op->is_intrinsic(intrinsic::tvm_address_of)) {
     const Load *l = op->args[0].as<Load>();
     CHECK(op->args.size() == 1 && l);
     this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get()));
@@ -129,8 +129,8 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
     this->PushOp(StackVM::PUSH_I64, l->type.element_of().bytes());
     this->PushOp(StackVM::MUL_I64);
     this->PushOp(StackVM::ADDR_ADD);
-  } else if (op->is_intrinsic(Call::null_handle)) {
-    this->PushOp(StackVM::PUSH_I64, 0);
+  } else if (op->is_intrinsic(Call::reinterpret)) {
+    this->Push(op->args[0]);
   } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
     CHECK_EQ(op->args.size(), 3U);
     int kind = op->args[2].as<IntImm>()->value;
@@ -217,11 +217,13 @@ class PipelineExtractor: public IRVisitor {
     if (is_zero(op->index) && load) {
       compute->body = Store::make(
           op->buffer_var,
-          Load::make(load->type, load->buffer_var, repl.Mutate(load->index)),
-          op->index);
+          Load::make(load->type, load->buffer_var,
+                     repl.Mutate(load->index), op->predicate),
+          op->index, op->predicate);
     } else {
       compute->body = Store::make(
-          op->buffer_var, repl.Mutate(op->value), repl.Mutate(op->index));
+          op->buffer_var, repl.Mutate(op->value),
+          repl.Mutate(op->index), op->predicate);
     }
     compute->inputs = repl.inputs_;
     pipeline_->stages.push_back(ComputeBlock(compute));
@@ -49,13 +49,16 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {

 Expr Buffer::MakeLoad(Array<Expr> index) const {
   const BufferNode* n = operator->();
-  return ir::Load::make(n->dtype, n->data, BufferOffset(n, index));
+  return ir::Load::make(
+      n->dtype, n->data, BufferOffset(n, index),
+      const_true(n->dtype.lanes()));
 }

 Stmt Buffer::MakeStore(Array<Expr> index, Expr value) const {
   const BufferNode* n = operator->();
   CHECK_EQ(value.type(), n->dtype);
-  return ir::Store::make(n->data, value, BufferOffset(n, index));
+  return ir::Store::make(n->data, value, BufferOffset(n, index),
+                         const_true(n->dtype.lanes()));
 }

 Buffer BufferNode::make(std::string name,
@@ -254,19 +254,21 @@ Stmt MakeCrossThreadReduction(
       }
     }
   }
+  Type t = reduce->type;
+  Expr pred = const_true(t.lanes());
   Stmt reduce_body = Store::make(res_handle,
     Call::make(
       reduce->type,
       ir::intrinsic::tvm_thread_allreduce,
       freduce_args, Call::Intrinsic),
-    0);
+    0, pred);
   reduce_body = AttrStmt::make(
     reduce->combiner,
     attr::reduce_scope,
     make_zero(reduce->type),
     reduce_body);
   Stmt assign_body = Provide::make(
-    stage->op, 0, Load::make(reduce->type, res_handle, 0), args);
+    stage->op, 0, Load::make(reduce->type, res_handle, 0, pred), args);
   assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
   assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
   Stmt body = Allocate::make(
@@ -152,11 +152,7 @@ class VTInjector : public IRMutator {
     return e;
   }
   Expr RewriteIndex(Expr index, Expr alloc_extent) const {
-    if (index_rewrite_strategy_ == 0) {
-      return index * num_threads_ + var_;
-    } else {
-      return index + var_ * alloc_extent;
-    }
+    return index + var_ * alloc_extent;
   }
   // Load
   Expr Mutate_(const Load* op, const Expr& e) final {
@@ -168,7 +164,8 @@ class VTInjector : public IRMutator {
     auto it = touched_alloc_.find(op->buffer_var.get());
     if (it != touched_alloc_.end()) {
       return Load::make(op->type, op->buffer_var,
-                        RewriteIndex(op->index, it->second));
+                        RewriteIndex(op->index, it->second),
+                        op->predicate);
     } else {
       return expr;
     }
@@ -184,7 +181,8 @@ class VTInjector : public IRMutator {
     if (it != touched_alloc_.end()) {
       return Store::make(op->buffer_var,
                          op->value,
-                         RewriteIndex(op->index, it->second));
+                         RewriteIndex(op->index, it->second),
+                         op->predicate);
     } else {
       return stmt;
     }
@@ -307,6 +305,9 @@ class VTInjector : public IRMutator {
     for (size_t i = 1; i < extents.size(); ++i) {
       stride = arith::ComputeExpr<Mul>(stride, extents[i]);
     }
+    if (op->type.lanes() != 0) {
+      stride = stride * op->type.lanes();
+    }
     Array<Expr> other;
     other.push_back(num_threads_);
     for (Expr e : extents) {
@@ -368,8 +369,6 @@ class VTInjector : public IRMutator {
   Var var_;
   // the threads/lanes
   int num_threads_;
-  // Index rewriting strategy
-  int index_rewrite_strategy_{1};
   // whethe the loop is already injected.
   bool vt_loop_injected_{false};
   // whether current expression get touched.
@@ -143,10 +143,11 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
 Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
   Expr value = this->Mutate(op->value);
   Expr index = this->Mutate(op->index);
-  if (value.same_as(op->value) && index.same_as(op->index)) {
+  Expr pred = this->Mutate(op->predicate);
+  if (value.same_as(op->value) && index.same_as(op->index) && pred.same_as(op->predicate)) {
     return s;
   } else {
-    return Store::make(op->buffer_var, value, index);
+    return Store::make(op->buffer_var, value, index, pred);
   }
 }

@@ -263,10 +264,11 @@ Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {

 Expr IRMutator::Mutate_(const Load *op, const Expr& e) {
   Expr index = this->Mutate(op->index);
-  if (index.same_as(op->index)) {
+  Expr pred = this->Mutate(op->predicate);
+  if (index.same_as(op->index) && pred.same_as(op->predicate)) {
     return e;
   } else {
-    return Load::make(op->type, op->buffer_var, index);
+    return Load::make(op->type, op->buffer_var, index, pred);
   }
 }

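Since Mutate_ now rebuilds the predicate as well, downstream mutators that reconstruct a Load or Store have to forward it, as the passes further below do. A hedged sketch of that pattern (class and member names are illustrative, not from this diff):

// Sketch only: rebuild a Load while preserving its predicate.
class IndexShifter : public IRMutator {
 public:
  explicit IndexShifter(Expr offset) : offset_(offset) {}
  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Load>();
    // Forward op->predicate instead of dropping it.
    return Load::make(op->type, op->buffer_var,
                      op->index + offset_, op->predicate);
  }

 private:
  Expr offset_;
};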
@@ -383,6 +385,15 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
   }
 }

+Expr IRMutator::Mutate_(const Shuffle *op, const Expr& e) {
+  auto new_vec = MutateArray(op->vectors, this);
+  if (new_vec.same_as(op->vectors)) {
+    return e;
+  } else {
+    return Shuffle::make(new_vec, op->indices);
+  }
+}
+
 #define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \
   Expr IRMutator::Mutate_(const OP *op, const Expr& e) { \
     return e; \
@@ -422,7 +433,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
 .DISPATCH_TO_MUTATE_EXPR(IntImm)
 .DISPATCH_TO_MUTATE_EXPR(UIntImm)
 .DISPATCH_TO_MUTATE_EXPR(FloatImm)
-.DISPATCH_TO_MUTATE_EXPR(StringImm);
+.DISPATCH_TO_MUTATE_EXPR(StringImm)
+.DISPATCH_TO_MUTATE_EXPR(Shuffle);

 } // namespace ir
 } // namespace tvm
@@ -111,8 +111,10 @@ inline Expr TVMStructGet(
  */
 inline Expr AddressOffset(Var handle, Type dtype, int offset) {
   return Call::make(
-      Handle(), Call::address_of,
-      {Load::make(dtype, handle, make_const(Int(32), offset))}, Call::PureIntrinsic);
+      Handle(), intrinsic::tvm_address_of,
+      {Load::make(dtype, handle, make_const(Int(32), offset * dtype.lanes()),
+                  const_true(dtype.lanes()))},
+      Call::PureIntrinsic);
 }

 /*!
@@ -81,11 +81,13 @@ void IRVisitor::Visit_(const Allocate *op) {

 void IRVisitor::Visit_(const Load *op) {
   this->Visit(op->index);
+  this->Visit(op->predicate);
 }

 void IRVisitor::Visit_(const Store *op) {
   this->Visit(op->value);
   this->Visit(op->index);
+  this->Visit(op->predicate);
 }

 void IRVisitor::Visit_(const IfThenElse *op) {
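With the predicate included in the visit order, IRVisitor-based analyses see it like any other operand. A small illustrative visitor (not part of this diff) that flags predicated accesses:

// Sketch only: detect loads/stores whose predicate is not a constant true.
class PredicatedAccessFinder : public IRVisitor {
 public:
  void Visit_(const Load* op) final {
    if (!is_one(op->predicate)) found = true;
    IRVisitor::Visit_(op);
  }
  void Visit_(const Store* op) final {
    if (!is_one(op->predicate)) found = true;
    IRVisitor::Visit_(op);
  }
  bool found{false};
};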
@@ -99,7 +99,7 @@ class PackedCallBuilder : public IRMutator {
     for (size_t i = 0; i < op->args.size(); ++i) {
       prep_seq_.emplace_back(
           Store::make(stack_shape_, Convert(Int(64), op->args[i]),
-                      ConstInt32(stack_begin +i)));
+                      ConstInt32(stack_begin +i), const_true(1)));
     }
     return AddressOffset(stack_shape_, Int(64), stack_begin);
   }
@@ -169,7 +169,7 @@ class PackedCallBuilder : public IRMutator {
       prep_seq_.emplace_back(
           Store::make(stack_tcode_,
                       ConstInt32(arg_tcode),
-                      stack_index));
+                      stack_index, const_true(1)));
     }
     // UPDATE stack value
     max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
@@ -143,9 +143,10 @@ class ThreadAllreduceBuilder : public IRMutator {
     int threadx_extent = 1;
     Expr reduce_index = FlattenThread(vred, &reduce_extent);
     Expr group_index = FlattenThread(vpar, &group_extent);
+    Expr pred = const_true(value.type().lanes());
     if (reduce_extent == 1) {
       // special case, no reduction is needed.
-      return Store::make(op->buffer_var, value, 0);
+      return Store::make(op->buffer_var, value, 0, pred);
     }
     // Whether the threadIdx.x is involved in reduction.
     if (vred[0].scope.dim_index == 0) {
@@ -155,7 +156,7 @@ class ThreadAllreduceBuilder : public IRMutator {
     std::vector<Stmt> seq;
     seq.emplace_back(Store::make(
         shared_buf, value,
-        BufIndex(reduce_index, group_index, reduce_extent)));
+        BufIndex(reduce_index, group_index, reduce_extent), pred));
     seq.emplace_back(SyncThread("shared"));
     seq.emplace_back(MakeBufAllreduce(
         combiner, value.type(), shared_buf,
@@ -164,11 +165,12 @@ class ThreadAllreduceBuilder : public IRMutator {
     load_remap_[op->buffer_var.get()] =
         Load::make(
             value.type(), shared_buf,
-            BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent));
+            BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent),
+            pred);
     alloc_remap_[op->buffer_var.get()] =
         Allocate::make(shared_buf, value.type(),
                        {Expr(group_extent), Expr(reduce_extent)},
-                       const_true(), Evaluate::make(0));
+                       pred, Evaluate::make(0));
     return MergeSeq(seq);
   }
   // make allreduce.
@@ -192,9 +194,9 @@ class ThreadAllreduceBuilder : public IRMutator {
     auto freduce = [&](int offset) {
       Expr b = Load::make(
           type, shared_buf,
-          BufIndex(reduce_index + offset, group_index, reduce_extent));
-      Expr a = Load::make(type, shared_buf, buf_index);
-      return Store::make(shared_buf, (*combiner)(a, b), buf_index);
+          BufIndex(reduce_index + offset, group_index, reduce_extent), const_true());
+      Expr a = Load::make(type, shared_buf, buf_index, const_true());
+      return Store::make(shared_buf, (*combiner)(a, b), buf_index, const_true());
     };
     // Step one, check for
     if (reduce_align > reduce_extent) {
@@ -122,7 +122,8 @@ LoweredFunc MakeAPI(Stmt body,
     Var tcode(v_arg->name_hint + ".code", Int(32));
     seq_init.emplace_back(LetStmt::make(
         tcode, Load::make(
-            Int(32), v_packed_arg_type_ids, IntImm::make(Int(32), i)), nop));
+            Int(32), v_packed_arg_type_ids, IntImm::make(Int(32), i), const_true(1)),
+        nop));
     Type t = v_arg.type();
     if (t.is_handle()) {
       std::ostringstream msg;
@@ -191,7 +192,7 @@ LoweredFunc MakeAPI(Stmt body,
       f_push(buf->shape[k],
              cast(buf->shape[k].type(),
                   Load::make(tvm_shape_type, v_shape,
-                             IntImm::make(Int(32), k))),
+                             IntImm::make(Int(32), k), const_true(1))),
              field_name.str());
     }
     // strides field
@@ -212,7 +213,7 @@ LoweredFunc MakeAPI(Stmt body,
       f_push(buf->strides[k],
              cast(buf->shape[k].type(),
                   Load::make(tvm_shape_type, v_strides,
-                             IntImm::make(Int(32), k))),
+                             IntImm::make(Int(32), k), const_true(1))),
              field_name.str());
     }
   }
@@ -75,7 +75,8 @@ class ChannelAccessIndexRewriter : public IRMutator {
     op = expr.as<Load>();
     if (read_access_ && buf_var_ == op->buffer_var.get()) {
       return Load::make(
-          op->type, op->buffer_var, ir::Simplify(op->index - min_));
+          op->type, op->buffer_var, ir::Simplify(op->index - min_),
+          op->predicate);
     } else {
       return expr;
     }
@@ -85,7 +86,8 @@ class ChannelAccessIndexRewriter : public IRMutator {
     op = stmt.as<Store>();
     if (!read_access_ && buf_var_ == op->buffer_var.get()) {
       return Store::make(
-          op->buffer_var, op->value, ir::Simplify(op->index - min_));
+          op->buffer_var, op->value, ir::Simplify(op->index - min_),
+          op->predicate);
     } else {
       return stmt;
     }
@@ -170,12 +170,13 @@ class StageSplitter : public IRMutator {
     Expr index = Mutate(op->index);
     Stmt provide = Store::make(
         ch->handle_var,
-        Load::make(op->type, op->buffer_var, index), 0);
+        Load::make(op->type, op->buffer_var, index, op->predicate),
+        0, op->predicate);
     Stmt temp = nest_.back(); nest_.pop_back();
     stages_.emplace_back(BuildStage(provide, ch));
     nest_.push_back(temp);
     fifo_map_[ch->handle_var.get()] = ch;
-    return Load::make(op->type, ch->handle_var, 0);
+    return Load::make(op->type, ch->handle_var, 0, op->predicate);
   }

   Stmt Split(Stmt stmt, const ProducerConsumer* env) {
@@ -33,7 +33,7 @@ class StorageFlattener : public IRMutator {
     op = stmt.as<Store>();
     auto it = extern_buf_remap_.find(op->buffer_var.get());
     if (it != extern_buf_remap_.end()) {
-      return Store::make(it->second, op->value, op->index);
+      return Store::make(it->second, op->value, op->index, op->predicate);
     } else {
       return stmt;
     }
@@ -115,7 +115,7 @@ class StorageFlattener : public IRMutator {
     op = expr.as<Load>();
     auto it = extern_buf_remap_.find(op->buffer_var.get());
     if (it != extern_buf_remap_.end()) {
-      return Load::make(op->type, it->second, op->index);
+      return Load::make(op->type, it->second, op->index, op->predicate);
     } else {
       return expr;
     }
@@ -194,14 +194,14 @@ class StoragePlanRewriter : public IRMutator {
     op = stmt.as<Store>();
     auto it = alloc_map_.find(op->buffer_var.get());
     if (it == alloc_map_.end()) return stmt;
-    return Store::make(it->second->alloc_var, op->value, op->index);
+    return Store::make(it->second->alloc_var, op->value, op->index, op->predicate);
   }
   Expr Mutate_(const Load* op, const Expr& e) final {
     Expr expr = IRMutator::Mutate_(op, e);
     op = expr.as<Load>();
     auto it = alloc_map_.find(op->buffer_var.get());
     if (it == alloc_map_.end()) return expr;
-    return Load::make(op->type, it->second->alloc_var, op->index);
+    return Load::make(op->type, it->second->alloc_var, op->index, op->predicate);
   }
   Expr Mutate_(const Variable* op, const Expr& e) final {
     auto it = alloc_map_.find(op);
@@ -100,7 +100,7 @@ class StorageSyncPlanner : public IRVisitor {
     }
   }
   void Visit_(const Call* op) final {
-    if (op->is_intrinsic(Call::address_of)) {
+    if (op->is_intrinsic(intrinsic::tvm_address_of)) {
       const Load *l = op->args[0].as<Load>();
       IRVisitor::Visit_(l);
     } else {
@@ -34,7 +34,8 @@ class VecAllocAccess : public IRMutator {
     op = expr.as<Load>();
     if (op->buffer_var.get() == buf_) {
       return Load::make(op->type, op->buffer_var,
-                        op->index * var_lanes_ + var_);
+                        op->index * var_lanes_ + var_,
+                        op->predicate);
     } else {
       return expr;
     }
@@ -46,7 +47,8 @@ class VecAllocAccess : public IRMutator {
     if (op->buffer_var.get() == buf_) {
       return Store::make(op->buffer_var,
                          op->value,
-                         op->index * var_lanes_ + var_);
+                         op->index * var_lanes_ + var_,
+                         op->predicate);
     } else {
       return stmt;
     }
@@ -160,11 +162,16 @@ class Vectorizer : public IRMutator {
   // Load
   Expr Mutate_(const Load* op, const Expr& e) final {
     Expr index = this->Mutate(op->index);
-    if (index.same_as(op->index)) {
+    Expr pred = this->Mutate(op->predicate);
+    if (index.same_as(op->index) && pred.same_as(op->predicate)) {
       return e;
     } else {
-      return Load::make(op->type.with_lanes(index.type().lanes()),
-                        op->buffer_var, index);
+      int lanes = std::max(index.type().lanes(), pred.type().lanes());
+      return Load::make(
+          op->type.with_lanes(lanes),
+          op->buffer_var,
+          BroadcastTo(index, lanes),
+          BroadcastTo(pred, lanes));
     }
   }
   // Let
@@ -201,13 +208,16 @@ class Vectorizer : public IRMutator {
   Stmt Mutate_(const Store* op, const Stmt& s) final {
     Expr value = this->Mutate(op->value);
     Expr index = this->Mutate(op->index);
+    Expr pred = this->Mutate(op->predicate);
     if (value.same_as(op->value) && index.same_as(op->index)) {
       return s;
     } else {
       int lanes = std::max(value.type().lanes(), index.type().lanes());
+      lanes = std::max(lanes, pred.type().lanes());
       return Store::make(op->buffer_var,
                          BroadcastTo(value, lanes),
-                         BroadcastTo(index, lanes));
+                         BroadcastTo(index, lanes),
+                         BroadcastTo(pred, lanes));
     }
   }
   // For
@@ -1,15 +0,0 @@
-#include <dmlc/logging.h>
-#include <gtest/gtest.h>
-#include <tvm/tvm.h>
-#include <pass/CSE.h>
-
-TEST(IR_PASS, CSE) {
-  using namespace Halide::Internal;
-  cse_test();
-}
-
-int main(int argc, char ** argv) {
-  testing::InitGoogleTest(&argc, argv);
-  testing::FLAGS_gtest_death_test_style = "threadsafe";
-  return RUN_ALL_TESTS();
-}
@@ -1,7 +1,7 @@
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
 #include <tvm/tvm.h>
-#include <pass/Simplify.h>
+#include <arithmetic/Simplify.h>

 TEST(IRSIMPLIFY, Basic) {
   using namespace Halide::Internal;