Rename dim_var to axis, update testcases

tqchen 2017-01-06 15:08:36 -08:00
Parent ff26cd68d0
Commit 57a74936db
11 changed files with 69 additions and 45 deletions
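
The rename is mechanical in user code: every dim_var access becomes axis. A before/after sketch using the schedule API exercised by the updated tests below:

    xo, xi = sA2.split(A2.op.dim_var[0], 8)   # before this commit
    xo, xi = sA2.split(A2.op.axis[0], 8)      # after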

View file

@@ -129,12 +129,14 @@ TVM_DLL int TVMNodeFree(NodeHandle handle);
* \param handle The node handle
* \param key The attribute name
* \param out_value The attribute value
- * \param out_typeid The typeif of the attribute.
+ * \param out_typeid The typeid of the attribute.
+ * \param out_success Whether get is successful.
*/
TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
const char* key,
ArgVariant* out_value,
- int* out_typeid);
+ int* out_typeid,
+ int* out_success);
/*!
* \brief get attributes names in the node.

View file

@@ -17,8 +17,8 @@ namespace tvm {
*/
class ComputeOpNode : public OperationNode {
public:
- /*! \brief Iteration variables over the dimensions */
- Array<IterVar> dim_var;
+ /*! \brief IterVar on each axis */
+ Array<IterVar> axis;
/*! \brief the compute expression */
Expr body;
/*! \brief constructor */
@@ -34,11 +34,11 @@ class ComputeOpNode : public OperationNode {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
- v->Visit("dim_var", &dim_var);
+ v->Visit("axis", &axis);
v->Visit("body", &body);
}
static Operation make(std::string name,
- Array<IterVar> dim_var,
+ Array<IterVar> axis,
Expr body);
static constexpr const char* _type_key = "ComputeOp";
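
Since axis is registered in VisitAttrs, Python attribute access picks up the new field name directly. A minimal sketch, using only constructs from the updated tests:

    import tvm
    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')
    T = tvm.compute((m,), lambda i: A[i] + 10, name='T')
    print(T.op.axis)   # Array of IterVar, one per output dimension
    print(T.op.body)   # the compute expression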

View file

@@ -72,10 +72,18 @@ class NodeBase(object):
def __getattr__(self, name):
ret_val = ArgVariant()
ret_typeid = ctypes.c_int()
+ ret_success = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
self.handle, c_str(name),
- ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
- return RET_SWITCH[ret_typeid.value](ret_val)
+ ctypes.byref(ret_val),
+ ctypes.byref(ret_typeid),
+ ctypes.byref(ret_success)))
+ value = RET_SWITCH[ret_typeid.value](ret_val)
+ if not ret_success.value:
+ raise AttributeError(
+ "'%s' object has no attribute '%s'" % (str(type(self)), name))
+ return value
def __hash__(self):
return _function_internal._raw_ptr(self)
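
With the extra out parameter, __getattr__ can tell a failed lookup apart from a successful one and raise AttributeError accordingly. A minimal sketch of the resulting behavior, mirroring the updated test_attr below:

    a = tvm.convert(1)
    assert a.value == 1      # attribute exists: value is returned
    try:
        a.no_field           # attribute missing: ret_success stays 0
        assert False
    except AttributeError:
        pass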

View file

@@ -37,6 +37,7 @@ using TVMAPINode = std::shared_ptr<Node>;
struct APIAttrGetter : public AttrVisitor {
std::string skey;
APIVariantValue* ret;
+ bool found_node_ref{false};
void Visit(const char* key, double* value) final {
if (skey == key) *ret = value[0];
@@ -62,7 +63,10 @@ struct APIAttrGetter : public AttrVisitor {
if (skey == key) *ret = value[0];
}
void Visit(const char* key, NodeRef* value) final {
- if (skey == key) *ret = value[0];
+ if (skey == key) {
+ *ret = value[0];
+ found_node_ref = true;
+ }
}
};
@@ -198,7 +202,8 @@ int TVMNodeFree(NodeHandle handle) {
int TVMNodeGetAttr(NodeHandle handle,
const char* key,
ArgVariant* ret_val,
- int* ret_typeid) {
+ int* ret_typeid,
+ int* ret_success) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_value.type_id = kNull;
@@ -209,11 +214,14 @@ int TVMNodeGetAttr(NodeHandle handle,
if (getter.skey == "type_key") {
ret_val->v_str = (*tnode)->type_key();
*ret_typeid = kStr;
+ *ret_success = 1;
} else {
(*tnode)->VisitAttrs(&getter);
if (ret->ret_value.type_id != kNull) {
ret->SetReturn(ret_val, ret_typeid);
+ *ret_success = 1;
} else {
+ *ret_success = getter.found_node_ref ? 1 : 0;
*ret_typeid = kNull;
}
}
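
found_node_ref handles attributes that exist but hold a null NodeRef: ret_value.type_id stays kNull after the visit, yet the lookup should still count as a success so the Python side returns None instead of raising. The updated test_tensor below relies on exactly this; a sketch:

    import tvm
    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')
    # 'op' is a NodeRef field that is null on a placeholder tensor; with
    # found_node_ref the getter reports success with a kNull typeid, so
    # the attribute reads as None rather than raising AttributeError
    assert A.op is None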

View file

@@ -13,10 +13,10 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
} // namespace dmlc
namespace tvm {
Range::Range(Expr begin, Expr end)
- : Range(std::make_shared<Halide::IR::RangeNode>(begin, end - begin)) {
- // TODO(tqchen) add simplify to end - begin
+ : Range(std::make_shared<Halide::IR::RangeNode>(
+ begin,
+ is_zero(begin) ? end : (end - begin))) {
}
Range Range::make_with_min_extent(Expr min, Expr extent) {
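
The zero-begin special case keeps the common Range(0, extent) from carrying the unsimplified end - 0 subtraction that the old TODO pointed at. A hedged Python-side check, assuming the dom field is exposed through attribute access like the other node fields:

    import tvm
    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')
    T = tvm.compute((m,), lambda i: A[i] + 1, name='T')
    # the axis domain is [0, m); its extent is now stored as m, not (m - 0)
    print(T.op.axis[0].dom.extent)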

View file

@@ -18,27 +18,27 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) {
auto op_node = std::make_shared<ComputeOpNode>();
// compute dimension.
size_t ndim = shape.size();
- std::vector<IterVar> dim_var;
+ std::vector<IterVar> axis;
std::vector<Var> args;
for (size_t i = 0; i < ndim; ++i) {
std::ostringstream os;
- os << "dim_var" << i;
- dim_var.push_back(IterVar(Range(0, shape[i]), os.str()));
- args.push_back(dim_var.back()->var);
+ os << "ax" << i;
+ axis.emplace_back(IterVar(Range(0, shape[i]), os.str()));
+ args.push_back(axis.back()->var);
}
- op_node->dim_var = Array<IterVar>(dim_var);
+ op_node->axis = Array<IterVar>(axis);
op_node->body = fcompute(args);
op_node->name = name;
return Operation(op_node).output(0);
}
Operation ComputeOpNode::make(std::string name,
- Array<IterVar> dim_var,
+ Array<IterVar> axis,
Expr body) {
auto n = std::make_shared<ComputeOpNode>();
n->name = name;
- n->dim_var = dim_var;
+ n->axis = axis;
n->body = body;
return Operation(n);
}
@@ -54,7 +54,7 @@ Tensor Operation::output(size_t i) const {
}
Array<IterVar> ComputeOpNode::root_iter_vars() const {
- return dim_var;
+ return axis;
}
std::string ComputeOpNode::output_name(size_t i) const {
@@ -70,8 +70,8 @@ Type ComputeOpNode::output_dtype(size_t i) const {
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
std::vector<Expr> shape;
- for (size_t i = 0; i < dim_var.size(); ++i) {
- const Range& r = dim_var[i]->dom;
+ for (size_t i = 0; i < axis.size(); ++i) {
+ const Range& r = axis[i]->dom;
shape.push_back(r->extent);
}
return Array<Expr>(shape);
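
Since output_shape is rebuilt from each axis' domain extent, a tensor's shape round-trips through the new axis array. A minimal check in the style of the test_tensor test:

    import tvm
    m = tvm.Var('m')
    n = tvm.Var('n')
    A = tvm.placeholder((m, n), name='A')
    T = tvm.compute((m, n), lambda i, j: A[i, j], name='T')
    assert tuple(T.shape) == (m, n)   # the extents of axis[0] and axis[1]
    assert len(T.op.axis) == 2        # iteration variables ax0 and ax1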

View file

@@ -30,7 +30,15 @@ def test_attr():
stmt = tvm.make.AttrStmt(
y, "stride", 10, tvm.make.Evaluate(x + 1));
assert stmt.node == y
+ print(stmt)
+ a = tvm.convert(1)
+ assert a.value == 1
+ try:
+ a.no_field
+ assert False
+ except AttributeError:
+ pass
def test_basic():
a = tvm.Var('a')
@@ -48,7 +56,6 @@ def test_stmt():
if __name__ == "__main__":
test_attr()
test_const()
- test_make()
test_ir()

View file

@@ -8,11 +8,11 @@ def test_bound1():
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op)
- xo, xi = sA2.split(A2.op.dim_var[0], 8)
+ xo, xi = sA2.split(A2.op.axis[0], 8)
sA1.compute_at(sA2, xo)
bounds = tvm.schedule.InferBound(sA2)
assert isinstance(bounds, tvm.collections.Map)
- assert(bounds[A1.op.dim_var[0]].extent.value == 8)
+ assert(bounds[A1.op.axis[0]].extent.value == 8)
def test_bound2():
m = tvm.Var('m')
@@ -22,12 +22,12 @@ def test_bound2():
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op)
- xo, yo, xi, yi = sA2.tile(A2.op.dim_var[0], A2.op.dim_var[1], 8, 8)
+ xo, yo, xi, yi = sA2.tile(A2.op.axis[0], A2.op.axis[1], 8, 8)
sA1.compute_at(sA2, yo)
bounds = tvm.schedule.InferBound(sA2)
assert isinstance(bounds, tvm.collections.Map)
- assert(bounds[A1.op.dim_var[0]].extent.value == 8)
- assert(bounds[A1.op.dim_var[1]].extent.value == 8)
+ assert(bounds[A1.op.axis[0]].extent.value == 8)
+ assert(bounds[A1.op.axis[1]].extent.value == 8)
def test_bound3():
m = tvm.Var('m')
@@ -38,16 +38,16 @@ def test_bound3():
sA1 = tvm.Schedule(A1.op, scope="shared")
sA2 = tvm.Schedule(A2.op)
thread_x = tvm.IterVar((0, 16), thread_tag="threadIdx.x")
- xo, xi = sA2.split(A2.op.dim_var[0], 32)
+ xo, xi = sA2.split(A2.op.axis[0], 32)
xi0, xi1 = sA2.split(xi, outer=thread_x)
- yo, yi = sA2.split(A2.op.dim_var[1], 16)
+ yo, yi = sA2.split(A2.op.axis[1], 16)
sA2.reorder(xo, xi0, yo, xi1, yi)
sA1.compute_at(sA2, yo)
bounds = tvm.schedule.InferBound(sA2)
assert isinstance(bounds, tvm.collections.Map)
- assert(bounds[A1.op.dim_var[0]].extent.value==32)
- assert(bounds[A1.op.dim_var[1]].extent.value==16)
+ assert(bounds[A1.op.axis[0]].extent.value==32)
+ assert(bounds[A1.op.axis[1]].extent.value==16)
def test_create_read_graph():

View file

@@ -3,11 +3,10 @@ import tvm
def test_inline():
m = tvm.Var('m')
A = tvm.placeholder((m,), name='A')
- T = tvm.compute((m,), lambda i,: A(i) + 10, name='T')
- X = T(100)
- stmt = tvm.make.Evaluate(T(10) + 11 * T(100))
+ T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
+ stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
stmt = tvm.ir_pass.Inline(
- T, T.op.dim_var, T.op.body, stmt)
+ T, [x.var for x in T.op.axis], T.op.body, stmt)
print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt))
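
Note the shape of the Inline call changed along with the rename: op.axis holds IterVar nodes, while Inline substitutes on the underlying loop variables, hence the list-comprehension projection. Roughly:

    # each IterVar wraps the Var that appears in the body expression
    loop_vars = [x.var for x in T.op.axis]
    stmt = tvm.ir_pass.Inline(T, loop_vars, T.op.body, stmt)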

View file

@@ -12,14 +12,14 @@ def test_schedule_create():
sch_T = tvm.Schedule(T.op, scope="shared")
sch_A = tvm.Schedule(AA.op, scope="global")
- xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
+ xo, xi = sch_T.split(T.op.axis[0], factor=10)
xi1, xi2 = sch_T.split(xi, factor=2)
sch_A.compute_at(sch_T, xi1)
- xo, xi = sch_A.split(AA.op.dim_var[0], factor=10)
+ xo, xi = sch_A.split(AA.op.axis[0], factor=10)
sch_T.reorder(xi2, xi1)
- assert T.op.dim_var[1] in sch_T.leaf_iter_vars
+ assert T.op.axis[1] in sch_T.leaf_iter_vars
def test_reorder():
m = tvm.Var('m')
@@ -27,7 +27,7 @@ def test_reorder():
T = tvm.compute(m, lambda i: A[i+1])
sch_T = tvm.Schedule(T.op, scope="shared")
- xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
+ xo, xi = sch_T.split(T.op.axis[0], factor=10)
xi1, xi2 = sch_T.split(xi, factor=2)
order = (xi2, xi1, xo)
assert tuple(sch_T.leaf_iter_vars) != order
@@ -40,7 +40,7 @@ def test_split():
T = tvm.compute((m,), lambda i: A[i])
sT = tvm.Schedule(T.op)
- xo, xi = sT.split(T.op.dim_var[0], factor=10)
+ xo, xi = sT.split(T.op.axis[0], factor=10)
assert tuple(sT.leaf_iter_vars) == (xo, xi)
@@ -51,7 +51,7 @@ def test_tile():
T = tvm.compute((m, n), lambda i, j: A[i, j])
sch_T = tvm.Schedule(T.op, scope="shared")
- xo, yo, xi, yi = sch_T.tile(T.op.dim_var[0], T.op.dim_var[1], x_factor=10, y_factor=5)
+ xo, yo, xi, yi = sch_T.tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
assert tuple(sch_T.leaf_iter_vars) == (xo, yo, xi, yi)
if __name__ == "__main__":

View file

@@ -10,7 +10,7 @@ def test_tensor():
print(T)
print(T.op.body)
assert(tuple(T.shape) == (m, n, l))
- assert(A.source is None)
+ assert(A.op is None)
def test_tensor_reduce():
m = tvm.Var('m')