Change name matching to SSA Variable and Function matching

This commit is contained in:
tqchen 2016-10-27 22:17:44 -07:00
Родитель 062bb85326
Коммит f52d0713ff
5 изменённых файлов: 39 добавлений и 2 удалений

@ -1 +1 @@
Subproject commit 9070ac3697931ef5aeb8c373c23b2e8a2fec4627
Subproject commit 84a568ce86ca64ff4e186b78745152061499cbf4

Просмотреть файл

@ -6,4 +6,4 @@ from ._ctypes._api import register_node
from . import expr
from . import stmt
from . import make
from . import domain
from . import collections

Просмотреть файл

@ -14,3 +14,10 @@ class Array(NodeBase):
def __repr__(self):
return '[' + (','.join(str(x) for x in self)) + ']'
@register_node
class Range(NodeBase):
def __repr__(self):
return ('Range(min='+ str(self.min) +
', extent=' + str(self.extent) + ')')

Просмотреть файл

@ -41,4 +41,27 @@ def convert(value):
else:
return value
def Range(begin, **kwargs):
"""Create a TVM Range object.
User can either call:
Range(10) to get a range in [0, 10)
or
Range(begin=1, extent=10), to get a range in [0, 11)
Parameters
----------
begin : Expr
The beginning of the expression.
extent : optional, Expr
The extent(i.e. the length) of the range.
"""
if "extent" in kwargs:
return _function_internal._Range(begin, kwargs["extent"])
else:
return _function_internal._Range(0, begin);
_init_function_module("tvm")

Просмотреть файл

@ -82,4 +82,11 @@ TVM_REGISTER_API(_ArraySize)
static_cast<const ArrayNode*>(sptr.get())->data.size());
});
TVM_REGISTER_API(_Range)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Range(args.at(0), args.at(1));
})
.add_argument("min", "Expr", "beginning of the range.")
.add_argument("extent", "Expr", "extent of the range");
} // namespace tvm