Rename dim_var to axis, update testcases
This commit is contained in:
Родитель
ff26cd68d0
Коммит
57a74936db
|
@ -129,12 +129,14 @@ TVM_DLL int TVMNodeFree(NodeHandle handle);
|
|||
* \param handle The node handle
|
||||
* \param key The attribute name
|
||||
* \param out_value The attribute value
|
||||
* \param out_typeid The typeif of the attribute.
|
||||
* \param out_typeid The typeid of the attribute.
|
||||
* \param out_success Whether get is successful.
|
||||
*/
|
||||
TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
|
||||
const char* key,
|
||||
ArgVariant* out_value,
|
||||
int* out_typeid);
|
||||
int* out_typeid,
|
||||
int* out_success);
|
||||
|
||||
/*!
|
||||
* \brief get attributes names in the node.
|
||||
|
|
|
@ -17,8 +17,8 @@ namespace tvm {
|
|||
*/
|
||||
class ComputeOpNode : public OperationNode {
|
||||
public:
|
||||
/*! \brief Iteration variables over the dimensions */
|
||||
Array<IterVar> dim_var;
|
||||
/*! \brief IterVar on each axis */
|
||||
Array<IterVar> axis;
|
||||
/*! \brief the compute expression */
|
||||
Expr body;
|
||||
/*! \brief constructor */
|
||||
|
@ -34,11 +34,11 @@ class ComputeOpNode : public OperationNode {
|
|||
|
||||
void VisitAttrs(AttrVisitor* v) final {
|
||||
v->Visit("name", &name);
|
||||
v->Visit("dim_var", &dim_var);
|
||||
v->Visit("axis", &axis);
|
||||
v->Visit("body", &body);
|
||||
}
|
||||
static Operation make(std::string name,
|
||||
Array<IterVar> dim_var,
|
||||
Array<IterVar> axis,
|
||||
Expr body);
|
||||
|
||||
static constexpr const char* _type_key = "ComputeOp";
|
||||
|
|
|
@ -72,10 +72,18 @@ class NodeBase(object):
|
|||
def __getattr__(self, name):
|
||||
ret_val = ArgVariant()
|
||||
ret_typeid = ctypes.c_int()
|
||||
ret_success = ctypes.c_int()
|
||||
check_call(_LIB.TVMNodeGetAttr(
|
||||
self.handle, c_str(name),
|
||||
ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
|
||||
return RET_SWITCH[ret_typeid.value](ret_val)
|
||||
ctypes.byref(ret_val),
|
||||
ctypes.byref(ret_typeid),
|
||||
ctypes.byref(ret_success)))
|
||||
value = RET_SWITCH[ret_typeid.value](ret_val)
|
||||
if not ret_success.value:
|
||||
raise AttributeError(
|
||||
"'%s' object has no attribute '%s'" % (str(type(self)), name))
|
||||
return value
|
||||
|
||||
|
||||
def __hash__(self):
|
||||
return _function_internal._raw_ptr(self)
|
||||
|
|
|
@ -37,6 +37,7 @@ using TVMAPINode = std::shared_ptr<Node>;
|
|||
struct APIAttrGetter : public AttrVisitor {
|
||||
std::string skey;
|
||||
APIVariantValue* ret;
|
||||
bool found_node_ref{false};
|
||||
|
||||
void Visit(const char* key, double* value) final {
|
||||
if (skey == key) *ret = value[0];
|
||||
|
@ -62,7 +63,10 @@ struct APIAttrGetter : public AttrVisitor {
|
|||
if (skey == key) *ret = value[0];
|
||||
}
|
||||
void Visit(const char* key, NodeRef* value) final {
|
||||
if (skey == key) *ret = value[0];
|
||||
if (skey == key) {
|
||||
*ret = value[0];
|
||||
found_node_ref = true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -198,7 +202,8 @@ int TVMNodeFree(NodeHandle handle) {
|
|||
int TVMNodeGetAttr(NodeHandle handle,
|
||||
const char* key,
|
||||
ArgVariant* ret_val,
|
||||
int* ret_typeid) {
|
||||
int* ret_typeid,
|
||||
int* ret_success) {
|
||||
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
|
||||
API_BEGIN();
|
||||
ret->ret_value.type_id = kNull;
|
||||
|
@ -209,11 +214,14 @@ int TVMNodeGetAttr(NodeHandle handle,
|
|||
if (getter.skey == "type_key") {
|
||||
ret_val->v_str = (*tnode)->type_key();
|
||||
*ret_typeid = kStr;
|
||||
*ret_success = 1;
|
||||
} else {
|
||||
(*tnode)->VisitAttrs(&getter);
|
||||
if (ret->ret_value.type_id != kNull) {
|
||||
ret->SetReturn(ret_val, ret_typeid);
|
||||
*ret_success = 1;
|
||||
} else {
|
||||
*ret_success = getter.found_node_ref ? 1 : 0;
|
||||
*ret_typeid = kNull;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,10 +13,10 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
|
|||
} // namespace dmlc
|
||||
|
||||
namespace tvm {
|
||||
|
||||
Range::Range(Expr begin, Expr end)
|
||||
: Range(std::make_shared<Halide::IR::RangeNode>(begin, end - begin)) {
|
||||
// TODO(tqchen) add simplify to end - begin
|
||||
: Range(std::make_shared<Halide::IR::RangeNode>(
|
||||
begin,
|
||||
is_zero(begin) ? end : (end - begin))) {
|
||||
}
|
||||
|
||||
Range Range::make_with_min_extent(Expr min, Expr extent) {
|
||||
|
|
|
@ -18,27 +18,27 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
|
|||
auto op_node = std::make_shared<ComputeOpNode>();
|
||||
// compute dimension.
|
||||
size_t ndim = shape.size();
|
||||
std::vector<IterVar> dim_var;
|
||||
std::vector<IterVar> axis;
|
||||
std::vector<Var> args;
|
||||
for (size_t i = 0; i < ndim; ++i) {
|
||||
std::ostringstream os;
|
||||
os << "dim_var" << i;
|
||||
dim_var.push_back(IterVar(Range(0, shape[i]), os.str()));
|
||||
args.push_back(dim_var.back()->var);
|
||||
os << "ax" << i;
|
||||
axis.emplace_back(IterVar(Range(0, shape[i]), os.str()));
|
||||
args.push_back(axis.back()->var);
|
||||
}
|
||||
|
||||
op_node->dim_var = Array<IterVar>(dim_var);
|
||||
op_node->axis = Array<IterVar>(axis);
|
||||
op_node->body = fcompute(args);
|
||||
op_node->name = name;
|
||||
return Operation(op_node).output(0);
|
||||
}
|
||||
|
||||
Operation ComputeOpNode::make(std::string name,
|
||||
Array<IterVar> dim_var,
|
||||
Array<IterVar> axis,
|
||||
Expr body) {
|
||||
auto n = std::make_shared<ComputeOpNode>();
|
||||
n->name = name;
|
||||
n->dim_var = dim_var;
|
||||
n->axis = axis;
|
||||
n->body = body;
|
||||
return Operation(n);
|
||||
}
|
||||
|
@ -54,7 +54,7 @@ Tensor Operation::output(size_t i) const {
|
|||
}
|
||||
|
||||
Array<IterVar> ComputeOpNode::root_iter_vars() const {
|
||||
return dim_var;
|
||||
return axis;
|
||||
}
|
||||
|
||||
std::string ComputeOpNode::output_name(size_t i) const {
|
||||
|
@ -70,8 +70,8 @@ Type ComputeOpNode::output_dtype(size_t i) const {
|
|||
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
|
||||
CHECK_EQ(i, 0U);
|
||||
std::vector<Expr> shape;
|
||||
for (size_t i = 0; i < dim_var.size(); ++i) {
|
||||
const Range& r = dim_var[i]->dom;
|
||||
for (size_t i = 0; i < axis.size(); ++i) {
|
||||
const Range& r = axis[i]->dom;
|
||||
shape.push_back(r->extent);
|
||||
}
|
||||
return Array<Expr>(shape);
|
||||
|
|
|
@ -30,7 +30,15 @@ def test_attr():
|
|||
stmt = tvm.make.AttrStmt(
|
||||
y, "stride", 10, tvm.make.Evaluate(x + 1));
|
||||
assert stmt.node == y
|
||||
print(stmt)
|
||||
|
||||
a = tvm.convert(1)
|
||||
assert a.value == 1
|
||||
try:
|
||||
a.no_field
|
||||
assert False
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def test_basic():
|
||||
a = tvm.Var('a')
|
||||
|
@ -48,7 +56,6 @@ def test_stmt():
|
|||
|
||||
if __name__ == "__main__":
|
||||
test_attr()
|
||||
|
||||
test_const()
|
||||
test_make()
|
||||
test_ir()
|
||||
|
|
|
@ -8,11 +8,11 @@ def test_bound1():
|
|||
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
|
||||
sA1 = tvm.Schedule(A1.op)
|
||||
sA2 = tvm.Schedule(A2.op)
|
||||
xo, xi = sA2.split(A2.op.dim_var[0], 8)
|
||||
xo, xi = sA2.split(A2.op.axis[0], 8)
|
||||
sA1.compute_at(sA2, xo)
|
||||
bounds = tvm.schedule.InferBound(sA2)
|
||||
assert isinstance(bounds, tvm.collections.Map)
|
||||
assert(bounds[A1.op.dim_var[0]].extent.value == 8)
|
||||
assert(bounds[A1.op.axis[0]].extent.value == 8)
|
||||
|
||||
def test_bound2():
|
||||
m = tvm.Var('m')
|
||||
|
@ -22,12 +22,12 @@ def test_bound2():
|
|||
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
|
||||
sA1 = tvm.Schedule(A1.op)
|
||||
sA2 = tvm.Schedule(A2.op)
|
||||
xo, yo, xi, yi = sA2.tile(A2.op.dim_var[0], A2.op.dim_var[1], 8, 8)
|
||||
xo, yo, xi, yi = sA2.tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
|
||||
sA1.compute_at(sA2, yo)
|
||||
bounds = tvm.schedule.InferBound(sA2)
|
||||
assert isinstance(bounds, tvm.collections.Map)
|
||||
assert(bounds[A1.op.dim_var[0]].extent.value == 8)
|
||||
assert(bounds[A1.op.dim_var[1]].extent.value == 8)
|
||||
assert(bounds[A1.op.axis[0]].extent.value == 8)
|
||||
assert(bounds[A1.op.axis[1]].extent.value == 8)
|
||||
|
||||
def test_bound3():
|
||||
m = tvm.Var('m')
|
||||
|
@ -38,16 +38,16 @@ def test_bound3():
|
|||
sA1 = tvm.Schedule(A1.op, scope="shared")
|
||||
sA2 = tvm.Schedule(A2.op)
|
||||
thread_x = tvm.IterVar((0, 16), thread_tag="threadIdx.x")
|
||||
xo, xi = sA2.split(A2.op.dim_var[0], 32)
|
||||
xo, xi = sA2.split(A2.op.axis[0], 32)
|
||||
xi0, xi1 = sA2.split(xi, outer=thread_x)
|
||||
yo, yi = sA2.split(A2.op.dim_var[1], 16)
|
||||
yo, yi = sA2.split(A2.op.axis[1], 16)
|
||||
sA2.reorder(xo, xi0, yo, xi1, yi)
|
||||
sA1.compute_at(sA2, yo)
|
||||
|
||||
bounds = tvm.schedule.InferBound(sA2)
|
||||
assert isinstance(bounds, tvm.collections.Map)
|
||||
assert(bounds[A1.op.dim_var[0]].extent.value==32)
|
||||
assert(bounds[A1.op.dim_var[1]].extent.value==16)
|
||||
assert(bounds[A1.op.axis[0]].extent.value==32)
|
||||
assert(bounds[A1.op.axis[1]].extent.value==16)
|
||||
|
||||
|
||||
def test_create_read_graph():
|
||||
|
|
|
@ -3,11 +3,10 @@ import tvm
|
|||
def test_inline():
|
||||
m = tvm.Var('m')
|
||||
A = tvm.placeholder((m,), name='A')
|
||||
T = tvm.compute((m,), lambda i,: A(i) + 10, name='T')
|
||||
X = T(100)
|
||||
stmt = tvm.make.Evaluate(T(10) + 11 * T(100))
|
||||
T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
|
||||
stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
|
||||
stmt = tvm.ir_pass.Inline(
|
||||
T, T.op.dim_var, T.op.body, stmt)
|
||||
T, [x.var for x in T.op.axis], T.op.body, stmt)
|
||||
print(stmt)
|
||||
assert(tvm.ir_pass.VerifySSA(stmt))
|
||||
|
||||
|
|
|
@ -12,14 +12,14 @@ def test_schedule_create():
|
|||
sch_T = tvm.Schedule(T.op, scope="shared")
|
||||
sch_A = tvm.Schedule(AA.op, scope="global")
|
||||
|
||||
xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
|
||||
xo, xi = sch_T.split(T.op.axis[0], factor=10)
|
||||
xi1, xi2 = sch_T.split(xi, factor=2)
|
||||
|
||||
sch_A.compute_at(sch_T, xi1)
|
||||
xo, xi = sch_A.split(AA.op.dim_var[0], factor=10)
|
||||
xo, xi = sch_A.split(AA.op.axis[0], factor=10)
|
||||
|
||||
sch_T.reorder(xi2, xi1)
|
||||
assert T.op.dim_var[1] in sch_T.leaf_iter_vars
|
||||
assert T.op.axis[1] in sch_T.leaf_iter_vars
|
||||
|
||||
def test_reorder():
|
||||
m = tvm.Var('m')
|
||||
|
@ -27,7 +27,7 @@ def test_reorder():
|
|||
T = tvm.compute(m, lambda i: A[i+1])
|
||||
|
||||
sch_T = tvm.Schedule(T.op, scope="shared")
|
||||
xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
|
||||
xo, xi = sch_T.split(T.op.axis[0], factor=10)
|
||||
xi1, xi2 = sch_T.split(xi, factor=2)
|
||||
order = (xi2, xi1, xo)
|
||||
assert tuple(sch_T.leaf_iter_vars) != order
|
||||
|
@ -40,7 +40,7 @@ def test_split():
|
|||
T = tvm.compute((m,), lambda i: A[i])
|
||||
|
||||
sT = tvm.Schedule(T.op)
|
||||
xo, xi = sT.split(T.op.dim_var[0], factor=10)
|
||||
xo, xi = sT.split(T.op.axis[0], factor=10)
|
||||
assert tuple(sT.leaf_iter_vars) == (xo, xi)
|
||||
|
||||
|
||||
|
@ -51,7 +51,7 @@ def test_tile():
|
|||
T = tvm.compute((m, n), lambda i, j: A[i, j])
|
||||
|
||||
sch_T = tvm.Schedule(T.op, scope="shared")
|
||||
xo, yo, xi, yi = sch_T.tile(T.op.dim_var[0], T.op.dim_var[1], x_factor=10, y_factor=5)
|
||||
xo, yo, xi, yi = sch_T.tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
|
||||
assert tuple(sch_T.leaf_iter_vars) == (xo, yo, xi, yi)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -10,7 +10,7 @@ def test_tensor():
|
|||
print(T)
|
||||
print(T.op.body)
|
||||
assert(tuple(T.shape) == (m, n, l))
|
||||
assert(A.source is None)
|
||||
assert(A.op is None)
|
||||
|
||||
def test_tensor_reduce():
|
||||
m = tvm.Var('m')
|
||||
|
|
Загрузка…
Ссылка в новой задаче