201 строка
5.5 KiB
Python
201 строка
5.5 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
|
|
|
|
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
|
|
"""The type nodes of the Relay language."""
|
|
from enum import IntEnum
|
|
from .base import NodeBase, register_relay_node
|
|
from . import _make
|
|
|
|
|
|
class Type(NodeBase):
    """Common base class of every Relay type node."""

    def __eq__(self, other):
        """Check structural equivalence (alpha equivalence) with `other`.

        Parameters
        ----------
        other : object
            The value to compare against; it is forwarded unchanged to
            ``_make._type_alpha_eq``.

        Returns
        -------
        result : bool
            The truth value of ``_make._type_alpha_eq(self, other)``.
        """
        alpha_equal = _make._type_alpha_eq(self, other)
        return bool(alpha_equal)

    def __ne__(self, other):
        """Return the negation of ``self.__eq__(other)``."""
        is_equal = self.__eq__(other)
        return not is_equal

    def same_as(self, other):
        """Compare two Relay types by referential equality.

        This bypasses the structural ``__eq__`` defined on this class and
        uses the comparison inherited from the parent class instead.
        """
        return super(Type, self).__eq__(other)
|
|
|
|
|
|
@register_relay_node
class TensorType(Type):
    """A concrete tensor type in Relay; see tvm/relay/type.h for details.

    This is the type assigned to tensors whose dtype and shape are known,
    for example a tensor of `float32` with shape `(5, 5)`.
    """

    def __init__(self, shape, dtype):
        """Construct a tensor type.

        Parameters
        ----------
        shape : list of tvm.Expr
            The shape of the tensor.

        dtype : str
            The data type of the tensor.

        Returns
        -------
        tensor_type : TensorType
            The tensor type.
        """
        make_tensor_type = _make.TensorType
        self.__init_handle_by_constructor__(make_tensor_type, shape, dtype)
|
|
|
|
|
|
class Kind(IntEnum):
    """The kind of a type parameter.

    A kind represents a variable shape, base type, type, or dimension,
    and controls what a type parameter is allowed to be instantiated
    with. For example, ones of kind BaseType can only be `float32`,
    `int32`, and so on.
    """

    # Integer values are passed across to the constructors in `_make`,
    # so they are spelled out explicitly rather than auto-numbered.
    ShapeVar = 0
    Shape = 1
    BaseType = 2
    Type = 3
|
|
|
|
|
|
@register_relay_node
class TypeParam(Type):
    """A type parameter for generic types in Relay.

    See tvm/relay/type.h for more details.

    A type parameter is a placeholder for a type that will be filled in
    later on, which lets the user write functions that are generic over
    types.
    """

    def __init__(self, var, kind):
        """Construct a TypeParam.

        Parameters
        ----------
        var : tvm.expr.Var
            The tvm.Var which backs the type parameter.

        kind : Kind
            The kind of the type parameter.

        Returns
        -------
        type_param : TypeParam
            The type parameter.
        """
        make_type_param = _make.TypeParam
        self.__init_handle_by_constructor__(make_type_param, var, kind)
|
|
|
|
|
|
@register_relay_node
class TypeConstraint(Type):
    """Abstract base class for type constraints.

    It adds no members of its own; the class body consists solely of
    this docstring.
    """
|
|
|
|
|
|
@register_relay_node
class TupleType(Type):
    """A tuple type in Relay; see tvm/relay/type.h for more details.

    It lists the type of each field in the tuple.
    """

    def __init__(self, fields):
        """Construct a tuple type.

        Parameters
        ----------
        fields : list of tvm.Type
            The type of each field of the tuple.

        Returns
        -------
        tuple_type : TupleType
            The tuple type.
        """
        make_tuple_type = _make.TupleType
        self.__init_handle_by_constructor__(make_tuple_type, fields)
|
|
|
|
|
|
@register_relay_node
class FuncType(Type):
    """A function type in Relay; see tvm/relay/type.h for more details.

    This is the type assigned to functions in Relay. It consists of a
    list of type parameters which enable the definition of generic
    functions, a set of type constraints which we omit for the time
    being, a sequence of argument types, and a return type.

    We informally write them as:
    `forall (type_params), (arg_types) -> ret_type where type_constraints`
    """

    def __init__(self,
                 arg_types,
                 ret_type,
                 type_params=None,
                 type_constraints=None):
        """Construct a function type.

        Parameters
        ----------
        arg_types : list of Type
            The types of the arguments.

        ret_type : Type
            The return type.

        type_params : list of TypeParam, optional
            The type parameters. Omitting it (or passing None) is the
            same as passing an empty list.

        type_constraints : list of TypeConstraint, optional
            The type constraints. Omitting it (or passing None) is the
            same as passing an empty list.

        Returns
        -------
        func_type : FuncType
            The function type.
        """
        # None is the sentinel rather than a literal [] default so that
        # no list object is shared between calls.
        if type_params is None:
            type_params = []
        if type_constraints is None:
            type_constraints = []
        self.__init_handle_by_constructor__(
            _make.FuncType, arg_types, ret_type, type_params, type_constraints)
|
|
|
|
|
|
@register_relay_node
class IncompleteType(Type):
    """An incomplete type."""

    def __init__(self, kind=Kind.Type):
        """Construct an incomplete type.

        Parameters
        ----------
        kind : Kind, optional
            The kind forwarded to ``_make.IncompleteType``; defaults to
            ``Kind.Type``.
        """
        make_incomplete_type = _make.IncompleteType
        self.__init_handle_by_constructor__(make_incomplete_type, kind)
|
|
|
|
@register_relay_node
class TypeRelation(TypeConstraint):
    """Type relation in Relay.

    Parameters
    ----------
    func : EnvFunc
        User defined relation function.

    args : list of types
        List of types passed to the func.

    num_inputs : int
        Number of input arguments in args;
        this acts as a hint for type inference.

    attrs : Attrs
        The attributes attached to the relation information.
    """

    def __init__(self, func, args, num_inputs, attrs):
        """Build the node by forwarding all arguments to ``_make.TypeRelation``."""
        make_type_relation = _make.TypeRelation
        self.__init_handle_by_constructor__(
            make_type_relation, func, args, num_inputs, attrs)
|