Change name matching to SSA Variable and Function matching
This commit is contained in:
Родитель
062bb85326
Коммит
f52d0713ff
2
HalideIR
2
HalideIR
|
@ -1 +1 @@
|
||||||
Subproject commit 9070ac3697931ef5aeb8c373c23b2e8a2fec4627
|
Subproject commit 84a568ce86ca64ff4e186b78745152061499cbf4
|
|
@ -6,4 +6,4 @@ from ._ctypes._api import register_node
|
||||||
from . import expr
|
from . import expr
|
||||||
from . import stmt
|
from . import stmt
|
||||||
from . import make
|
from . import make
|
||||||
from . import domain
|
from . import collections
|
||||||
|
|
|
@ -14,3 +14,10 @@ class Array(NodeBase):
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '[' + (','.join(str(x) for x in self)) + ']'
|
return '[' + (','.join(str(x) for x in self)) + ']'
|
||||||
|
|
||||||
|
|
||||||
|
@register_node
|
||||||
|
class Range(NodeBase):
|
||||||
|
def __repr__(self):
|
||||||
|
return ('Range(min='+ str(self.min) +
|
||||||
|
', extent=' + str(self.extent) + ')')
|
|
@ -41,4 +41,27 @@ def convert(value):
|
||||||
else:
|
else:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def Range(begin, **kwargs):
|
||||||
|
"""Create a TVM Range object.
|
||||||
|
|
||||||
|
User can either call:
|
||||||
|
Range(10) to get a range in [0, 10)
|
||||||
|
|
||||||
|
or
|
||||||
|
Range(begin=1, extent=10), to get a range in [0, 11)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
begin : Expr
|
||||||
|
The beginning of the expression.
|
||||||
|
|
||||||
|
extent : optional, Expr
|
||||||
|
The extent(i.e. the length) of the range.
|
||||||
|
"""
|
||||||
|
if "extent" in kwargs:
|
||||||
|
return _function_internal._Range(begin, kwargs["extent"])
|
||||||
|
else:
|
||||||
|
return _function_internal._Range(0, begin);
|
||||||
|
|
||||||
_init_function_module("tvm")
|
_init_function_module("tvm")
|
||||||
|
|
|
@ -82,4 +82,11 @@ TVM_REGISTER_API(_ArraySize)
|
||||||
static_cast<const ArrayNode*>(sptr.get())->data.size());
|
static_cast<const ArrayNode*>(sptr.get())->data.size());
|
||||||
});
|
});
|
||||||
|
|
||||||
|
TVM_REGISTER_API(_Range)
|
||||||
|
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||||
|
*ret = Range(args.at(0), args.at(1));
|
||||||
|
})
|
||||||
|
.add_argument("min", "Expr", "beginning of the range.")
|
||||||
|
.add_argument("extent", "Expr", "extent of the range");
|
||||||
|
|
||||||
} // namespace tvm
|
} // namespace tvm
|
||||||
|
|
Загрузка…
Ссылка в новой задаче