"""Tag class for TVM operators."""
import warnings
from ._ffi.base import decorate


class TagScope(object):
    """Tag scope object to set tag for operators, working as context
    manager and decorator both. See also tag_scope.

    Only one TagScope may be active at a time: entering a scope while
    another one is active raises ``ValueError``.

    Attributes
    ----------
    tag : str
        The tag name carried by this scope.
    accessed : bool
        Set to True when ``get_current`` returns this scope during the
        current activation; checked on exit to warn about unused tags.
    """

    # The currently active scope (class-level, shared by all instances),
    # or None when no scope is active.
    _current = None

    @classmethod
    def get_current(cls):
        """Return the active TagScope, marking it as used.

        Returns
        -------
        scope : TagScope or None
            The currently active scope, or None if no scope is active.
        """
        if cls._current is not None:
            cls._current.accessed = True
        return cls._current

    def __init__(self, tag):
        self._old_scope = None
        self.tag = tag
        self.accessed = False

    def __enter__(self):
        if TagScope._current is not None:
            raise ValueError("nested op_tag is not allowed for now")
        self._old_scope = TagScope._current
        # Reset per activation so that a TagScope reused as a decorator
        # evaluates the "unused tag" check on every call, not only the first.
        self.accessed = False
        TagScope._current = self
        return self

    def __exit__(self, ptype, value, trace):
        # Nesting is rejected in __enter__, so the saved scope is always None.
        assert self._old_scope is None
        # Restore the active scope *before* warning: if warnings are turned
        # into errors (e.g. ``-W error``), warnings.warn raises, and the scope
        # would otherwise stay active, making every later tag_scope fail
        # with "nested op_tag is not allowed for now".
        TagScope._current = self._old_scope
        # Only report an unused tag on a normal exit; when an exception is
        # propagating, the tag was likely never reached and the warning
        # would just be noise on top of the real error.
        if ptype is None and not self.accessed:
            warnings.warn("Tag '%s' declared via TagScope was not used." % (self.tag,))

    def __call__(self, fdecl):
        """Use the scope as a decorator on ``fdecl``.

        The wrapper below runs the wrapped function inside ``with self``.
        It assumes ``decorate`` (from ``._ffi.base``) invokes it as
        ``tagged_fdecl(fdecl, *args, **kwargs)``. TODO confirm against
        ``_ffi.base.decorate``.
        """

        def tagged_fdecl(func, *args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return decorate(fdecl, tagged_fdecl)


def tag_scope(tag):
    """The operator tag scope.

    Parameters
    ----------
    tag: str
        The tag name.

    Returns
    -------
    tag_scope: TagScope
        The tag scope object, which can be used as decorator or
        context manager.

    Example
    -------
    .. code-block:: python

        n = tvm.var('n')
        m = tvm.var('m')
        l = tvm.var('l')
        A = tvm.placeholder((n, l), name='A')
        B = tvm.placeholder((m, l), name='B')
        k = tvm.reduce_axis((0, l), name='k')

        with tvm.tag_scope(tag='matmul'):
            C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k))

        # or use tag_scope as decorator
        @tvm.tag_scope(tag="conv")
        def compute_relu(data):
            return tvm.compute(data.shape, lambda *i: tvm.select(data(*i) < 0, 0.0, data(*i)))
    """
    return TagScope(tag)