[Relay][Prelude] Remove Peano nats from the prelude (#3045)
This commit is contained in:
Родитель
c93235d77f
Коммит
95bfd4a242
|
@ -17,7 +17,8 @@
|
|||
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
|
||||
"""Adds certain standard global functions and ADT definitions to the module."""
|
||||
from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type
|
||||
from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem
|
||||
from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const
|
||||
from .op.tensor import add, subtract, equal
|
||||
from .adt import Constructor, TypeData, Clause, Match
|
||||
from .adt import PatternConstructor, PatternVar, PatternWildcard
|
||||
|
||||
|
@ -34,6 +35,7 @@ class Prelude:
|
|||
self.cons = Constructor("cons", [a, self.l(a)], self.l)
|
||||
self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons])
|
||||
|
||||
|
||||
def define_list_hd(self):
|
||||
"""Defines a function to get the head of a list. Assume the list has at least one
|
||||
element.
|
||||
|
@ -48,6 +50,7 @@ class Prelude:
|
|||
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y)
|
||||
self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a])
|
||||
|
||||
|
||||
def define_list_tl(self):
|
||||
"""Defines a function to get the tail of a list.
|
||||
|
||||
|
@ -61,39 +64,44 @@ class Prelude:
|
|||
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), z)
|
||||
self.mod[self.tl] = Function([x], Match(x, [cons_case]), self.l(a), [a])
|
||||
|
||||
|
||||
def define_list_nth(self):
|
||||
"""Defines a function to get the nth element of a list.
|
||||
|
||||
nth(l) : list[a] -> a
|
||||
nth(l) : list[a] -> Tensor[(), int32] -> a
|
||||
"""
|
||||
self.nth = GlobalVar("nth")
|
||||
a = TypeVar("a")
|
||||
x = Var("x", self.l(a))
|
||||
n = Var("n", self.nat())
|
||||
n = Var("n", scalar_type('int32'))
|
||||
|
||||
body = If(equal(n, const(0)),
|
||||
self.hd(x),
|
||||
self.nth(self.tl(x), subtract(n, const(1))))
|
||||
|
||||
self.mod[self.nth] = Function([x, n], body, a, [a])
|
||||
|
||||
y = Var("y")
|
||||
z_case = Clause(PatternConstructor(self.z), self.hd(x))
|
||||
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y))
|
||||
self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a])
|
||||
|
||||
def define_list_update(self):
|
||||
"""Defines a function to update the nth element of a list and return the updated list.
|
||||
|
||||
update(l, i, v) : list[a] -> nat -> a -> list[a]
|
||||
update(l, i, v) : list[a] -> Tensor[(), int32] -> a -> list[a]
|
||||
"""
|
||||
self.update = GlobalVar("update")
|
||||
a = TypeVar("a")
|
||||
l = Var("l", self.l(a))
|
||||
n = Var("n", self.nat())
|
||||
n = Var("n", scalar_type('int32'))
|
||||
v = Var("v", a)
|
||||
|
||||
y = Var("y")
|
||||
body = If(equal(n, const(0)),
|
||||
self.cons(v, self.tl(l)),
|
||||
self.cons(self.hd(l),
|
||||
self.update(self.tl(l),
|
||||
subtract(n, const(1)),
|
||||
v)))
|
||||
|
||||
z_case = Clause(PatternConstructor(self.z), self.cons(v, self.tl(l)))
|
||||
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
|
||||
self.cons(self.hd(l), self.update(self.tl(l), y, v)))
|
||||
self.mod[self.update] = Function([l, n, v], body, self.l(a), [a])
|
||||
|
||||
self.mod[self.update] = Function([l, n, v], Match(n, [z_case, s_case]), self.l(a), [a])
|
||||
|
||||
def define_list_map(self):
|
||||
"""Defines a function for mapping a function over a list's
|
||||
|
@ -114,6 +122,7 @@ class Prelude:
|
|||
self.cons(f(y), self.map(f, z)))
|
||||
self.mod[self.map] = Function([f, x], Match(x, [nil_case, cons_case]), self.l(b), [a, b])
|
||||
|
||||
|
||||
def define_list_foldl(self):
|
||||
"""Defines a left-way fold over a list.
|
||||
|
||||
|
@ -136,6 +145,7 @@ class Prelude:
|
|||
self.mod[self.foldl] = Function([f, av, bv],
|
||||
Match(bv, [nil_case, cons_case]), a, [a, b])
|
||||
|
||||
|
||||
def define_list_foldr(self):
|
||||
"""Defines a right-way fold over a list.
|
||||
|
||||
|
@ -158,6 +168,7 @@ class Prelude:
|
|||
self.mod[self.foldr] = Function([f, bv, av],
|
||||
Match(av, [nil_case, cons_case]), b, [a, b])
|
||||
|
||||
|
||||
def define_list_foldr1(self):
|
||||
"""Defines a right-way fold over a nonempty list.
|
||||
|
||||
|
@ -196,6 +207,7 @@ class Prelude:
|
|||
self.foldr(updater, l2, l1),
|
||||
self.l(a), [a])
|
||||
|
||||
|
||||
def define_list_filter(self):
|
||||
"""Defines a function that filters a list.
|
||||
|
||||
|
@ -214,6 +226,7 @@ class Prelude:
|
|||
If(f(h), self.cons(h, self.filter(f, t)), self.filter(f, t)))
|
||||
self.mod[self.filter] = Function([f, l], Match(l, [nil_case, cons_case]), self.l(a), [a])
|
||||
|
||||
|
||||
def define_list_zip(self):
|
||||
"""Defines a function that combines two lists into a list of tuples of their elements.
|
||||
|
||||
|
@ -238,6 +251,7 @@ class Prelude:
|
|||
self.mod[self.zip] = Function([l1, l2], Match(l1, [nil_case, outer_cons_case]),
|
||||
self.l(TupleType([a, b])), [a, b])
|
||||
|
||||
|
||||
def define_list_rev(self):
|
||||
"""Defines a function that reverses a list.
|
||||
|
||||
|
@ -253,6 +267,7 @@ class Prelude:
|
|||
self.foldl(updater, self.nil(), l),
|
||||
self.l(a), [a])
|
||||
|
||||
|
||||
def define_list_map_accumr(self):
|
||||
"""Defines an accumulative map, which is a fold that simulataneously updates
|
||||
an accumulator value and a list of results.
|
||||
|
@ -282,6 +297,7 @@ class Prelude:
|
|||
TupleType([a, self.l(c)]),
|
||||
[a, b, c])
|
||||
|
||||
|
||||
def define_list_map_accuml(self):
|
||||
"""Defines an accumulative map, which is a fold that simulataneously updates
|
||||
an accumulator value and a list of results.
|
||||
|
@ -321,6 +337,7 @@ class Prelude:
|
|||
self.none = Constructor("none", [], self.optional)
|
||||
self.mod[self.optional] = TypeData(self.optional, [a], [self.some, self.none])
|
||||
|
||||
|
||||
def define_list_unfoldr(self):
|
||||
"""Defines a function that builds up a list starting from a seed value.
|
||||
|
||||
|
@ -343,6 +360,7 @@ class Prelude:
|
|||
self.mod[self.unfoldr] = Function([f, s], Match(f(s), [none_case, some_case]),
|
||||
self.l(b), [a, b])
|
||||
|
||||
|
||||
def define_list_unfoldl(self):
|
||||
"""Defines a function that builds up a list starting from a seed value.
|
||||
|
||||
|
@ -362,52 +380,29 @@ class Prelude:
|
|||
self.rev(self.unfoldr(f, s)),
|
||||
self.l(b), [a, b])
|
||||
|
||||
def define_nat_adt(self):
|
||||
"""Defines a Peano (unary) natural number ADT.
|
||||
Zero is represented by z(). s(n) adds 1 to a nat n."""
|
||||
self.nat = GlobalTypeVar("nat")
|
||||
self.z = Constructor("z", [], self.nat)
|
||||
self.s = Constructor("s", [self.nat()], self.nat)
|
||||
self.mod[self.nat] = TypeData(self.nat, [], [self.z, self.s])
|
||||
|
||||
def define_nat_double(self):
|
||||
"""Defines a function that doubles a nat."""
|
||||
self.double = GlobalVar("double")
|
||||
x = Var("x", self.nat())
|
||||
y = Var("y")
|
||||
z_case = Clause(PatternConstructor(self.z), self.z())
|
||||
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
|
||||
self.s(self.s(self.double(y))))
|
||||
self.mod[self.double] = Function([x], Match(x, [z_case, s_case]))
|
||||
|
||||
def define_nat_add(self):
|
||||
"""Defines a function that adds two nats."""
|
||||
self.add = GlobalVar("add")
|
||||
x = Var("x", self.nat())
|
||||
y = Var("y", self.nat())
|
||||
a = Var("a")
|
||||
z_case = Clause(PatternConstructor(self.z), y)
|
||||
s_case = Clause(PatternConstructor(self.s, [PatternVar(a)]),
|
||||
self.s(self.add(a, y)))
|
||||
self.mod[self.add] = Function([x, y], Match(x, [z_case, s_case]))
|
||||
|
||||
def define_list_sum(self):
|
||||
"""Defines a function that computes the sum of a list of nats."""
|
||||
"""Defines a function that computes the sum of a list of integer scalars."""
|
||||
self.sum = GlobalVar("sum")
|
||||
a = Var("a", self.l(self.nat()))
|
||||
self.mod[self.sum] = Function([a], self.foldl(self.add, self.z(), a))
|
||||
a = Var("a", self.l(scalar_type('int32')))
|
||||
x = Var('x')
|
||||
y = Var('y')
|
||||
addf = Function([x, y], add(x, y))
|
||||
self.mod[self.sum] = Function([a], self.foldl(addf, const(0), a))
|
||||
|
||||
|
||||
def define_list_length(self):
|
||||
"""Defines a function that returns the length of a list as a nat"""
|
||||
"""Defines a function that returns the length of a list"""
|
||||
self.length = GlobalVar("length")
|
||||
a = TypeVar("a")
|
||||
x = Var("x", self.l(a))
|
||||
y = Var("y")
|
||||
nil_case = Clause(PatternConstructor(self.nil), self.z())
|
||||
nil_case = Clause(PatternConstructor(self.nil), const(0))
|
||||
cons_case = Clause(PatternConstructor(self.cons, [PatternWildcard(), PatternVar(y)]),
|
||||
self.s(self.length(y)))
|
||||
add(const(1), self.length(y)))
|
||||
self.mod[self.length] = Function([x],
|
||||
Match(x, [nil_case, cons_case]), None, [a])
|
||||
Match(x, [nil_case, cons_case]), scalar_type('int32'), [a])
|
||||
|
||||
|
||||
def define_tree_adt(self):
|
||||
"""Defines a tree ADT. A tree can contain any type.
|
||||
|
@ -420,6 +415,7 @@ class Prelude:
|
|||
self.rose = Constructor("rose", [a, self.l(self.tree(a))], self.tree)
|
||||
self.mod[self.tree] = TypeData(self.tree, [a], [self.rose])
|
||||
|
||||
|
||||
def define_tree_map(self):
|
||||
"""Defines a function that maps over a tree. The function
|
||||
is applied to each subtree's contents.
|
||||
|
@ -439,23 +435,24 @@ class Prelude:
|
|||
self.mod[self.tmap] = Function([f, t],
|
||||
Match(t, [rose_case]), self.tree(b), [a, b])
|
||||
|
||||
def define_tree_size(self):
|
||||
"""Defines a function that computes the size of a tree as a nat.
|
||||
|
||||
Signature: fn<a>(t : tree[a]) -> nat
|
||||
def define_tree_size(self):
|
||||
"""Defines a function that computes the size of a tree.
|
||||
|
||||
Signature: fn<a>(t : tree[a]) -> Tensor[(), int32]
|
||||
"""
|
||||
self.size = GlobalVar("size")
|
||||
a = TypeVar("a")
|
||||
t = Var("t", self.tree(a))
|
||||
x = Var("x", self.tree(a))
|
||||
z = Var("z")
|
||||
rose_case = Clause(PatternConstructor(self.rose, [PatternWildcard(), PatternVar(z)]),
|
||||
self.s(self.sum(self.map(Function([x], self.size(x)), z))))
|
||||
add(const(1), self.sum(self.map(self.size, z))))
|
||||
self.mod[self.size] = Function([t],
|
||||
Match(t, [rose_case]), self.nat(), [a])
|
||||
Match(t, [rose_case]), scalar_type('int32'), [a])
|
||||
|
||||
|
||||
def define_id(self):
|
||||
"""Defines a function that return it's argument.
|
||||
"""Defines a function that return its argument.
|
||||
|
||||
Signature: fn<a>(x : a) -> a
|
||||
"""
|
||||
|
@ -466,7 +463,7 @@ class Prelude:
|
|||
|
||||
|
||||
def define_compose(self):
|
||||
"""Defines a function that compose two function.
|
||||
"""Defines a function that composes two function.
|
||||
|
||||
Signature: fn<a, b, c>(f : fn(b) -> c, g : fn(a) -> b) -> fn(a) -> c
|
||||
"""
|
||||
|
@ -484,24 +481,26 @@ class Prelude:
|
|||
|
||||
|
||||
def define_iterate(self):
|
||||
"""Define a function that take a number n, a function f,
|
||||
and return a closure that apply f n time on it's argument.
|
||||
"""Defines a function that take a number n and a function f;
|
||||
returns a closure that takes an argument and applies f
|
||||
n times to its argument.
|
||||
|
||||
Signature: fn<a>(n : nat, f : fn(a) -> a) -> fn(a) -> a
|
||||
Signature: fn<a>(f : fn(a) -> a, n : Tensor[(), int32]) -> fn(a) -> a
|
||||
"""
|
||||
self.iterate = GlobalVar("iterate")
|
||||
a = TypeVar("a")
|
||||
f = Var("f", FuncType([a], a))
|
||||
x = Var("x", self.nat())
|
||||
y = Var("y", self.nat())
|
||||
z_case = Clause(PatternConstructor(self.z), self.id)
|
||||
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
|
||||
self.compose(f, self.iterate(f, y)))
|
||||
x = Var("x", scalar_type('int32'))
|
||||
body = If(equal(x, const(0)),
|
||||
self.id,
|
||||
self.compose(f,
|
||||
self.iterate(f, subtract(x, const(1)))))
|
||||
self.mod[self.iterate] = Function([f, x],
|
||||
Match(x, [z_case, s_case]),
|
||||
body,
|
||||
FuncType([a], a),
|
||||
[a])
|
||||
|
||||
|
||||
def __init__(self, mod):
|
||||
self.mod = mod
|
||||
self.define_list_adt()
|
||||
|
@ -522,9 +521,6 @@ class Prelude:
|
|||
self.define_list_unfoldr()
|
||||
self.define_list_unfoldl()
|
||||
|
||||
self.define_nat_adt()
|
||||
self.define_nat_double()
|
||||
self.define_nat_add()
|
||||
self.define_list_length()
|
||||
self.define_list_nth()
|
||||
self.define_list_update()
|
||||
|
|
|
@ -30,3 +30,4 @@ from . import densenet
|
|||
|
||||
from .config import ctx_list
|
||||
from .init import create_workload
|
||||
from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
|
||||
|
|
|
@ -0,0 +1,184 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Defines a unary natural number (Peano natural number) abstract
|
||||
data type for Relay and provides some utility functions for it.
|
||||
Nats are useful for testing purposes, as they make it easy to write
|
||||
test cases for recursion and pattern matching."""
|
||||
|
||||
from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar
|
||||
from tvm.relay.backend.interpreter import ConstructorValue
|
||||
from tvm.relay.expr import Var, Function, GlobalVar
|
||||
from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType
|
||||
|
||||
def define_nat_adt(prelude):
|
||||
"""Defines a Peano (unary) natural number ADT.
|
||||
Zero is represented by z(). s(n) adds 1 to a nat n.
|
||||
Adds the fields nat, z, and s to the preluide, representing
|
||||
(respectively) the nat ADT and the z and s constructors.
|
||||
"""
|
||||
prelude.nat = GlobalTypeVar("nat")
|
||||
prelude.z = Constructor("z", [], prelude.nat)
|
||||
prelude.s = Constructor("s", [prelude.nat()], prelude.nat)
|
||||
prelude.mod[prelude.nat] = TypeData(prelude.nat, [], [prelude.z, prelude.s])
|
||||
|
||||
|
||||
def define_nat_double(prelude):
|
||||
"""Defines a function that doubles a nat. Adds a field called
|
||||
'double' to the prelude, giving the GlobalVar pointing to
|
||||
the function.
|
||||
"""
|
||||
prelude.double = GlobalVar("double")
|
||||
x = Var("x", prelude.nat())
|
||||
y = Var("y")
|
||||
z_case = Clause(PatternConstructor(prelude.z), prelude.z())
|
||||
s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
|
||||
prelude.s(prelude.s(prelude.double(y))))
|
||||
prelude.mod[prelude.double] = Function([x], Match(x, [z_case, s_case]))
|
||||
|
||||
|
||||
def define_nat_add(prelude):
|
||||
"""Defines a function that adds two nats and adds a field to the
|
||||
prelude 'add' giving the GlobalVar pointing to that function.
|
||||
"""
|
||||
prelude.add = GlobalVar("add")
|
||||
x = Var("x", prelude.nat())
|
||||
y = Var("y", prelude.nat())
|
||||
a = Var("a")
|
||||
z_case = Clause(PatternConstructor(prelude.z), y)
|
||||
s_case = Clause(PatternConstructor(prelude.s, [PatternVar(a)]),
|
||||
prelude.s(prelude.add(a, y)))
|
||||
prelude.mod[prelude.add] = Function([x, y], Match(x, [z_case, s_case]))
|
||||
|
||||
|
||||
# versions of prelude functions that use nats instead of scalars
|
||||
|
||||
def define_nat_nth(prelude):
|
||||
"""Defines a function to get the nth eleemnt of a list using
|
||||
a nat to index into the list.
|
||||
|
||||
nat_nth(l, n): fun<a>(list[a], nat) -> a
|
||||
"""
|
||||
prelude.nat_nth = GlobalVar("nat_nth")
|
||||
a = TypeVar("a")
|
||||
x = Var("x", prelude.l(a))
|
||||
n = Var("n", prelude.nat())
|
||||
y = Var("y")
|
||||
|
||||
z_case = Clause(PatternConstructor(prelude.z), prelude.hd(x))
|
||||
s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
|
||||
prelude.nat_nth(prelude.tl(x), y))
|
||||
|
||||
prelude.mod[prelude.nat_nth] = Function([x, n],
|
||||
Match(n, [z_case, s_case]),
|
||||
a, [a])
|
||||
|
||||
|
||||
def define_nat_update(prelude):
|
||||
"""Defines a function to update the nth element of a list and return the updated list.
|
||||
|
||||
nat_update(l, i, v) : fun<a>(list[a], nat, a) -> list[a]
|
||||
"""
|
||||
prelude.nat_update = GlobalVar("nat_update")
|
||||
a = TypeVar("a")
|
||||
# pylint: disable=invalid-name
|
||||
l = Var("l", prelude.l(a))
|
||||
n = Var("n", prelude.nat())
|
||||
v = Var("v", a)
|
||||
y = Var("y")
|
||||
|
||||
z_case = Clause(PatternConstructor(prelude.z),
|
||||
prelude.cons(v, prelude.tl(l)))
|
||||
s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
|
||||
prelude.cons(
|
||||
prelude.hd(l),
|
||||
prelude.nat_update(prelude.tl(l), y, v)))
|
||||
|
||||
prelude.mod[prelude.nat_update] = Function([l, n, v],
|
||||
Match(n, [z_case, s_case]),
|
||||
prelude.l(a), [a])
|
||||
|
||||
|
||||
def define_nat_iterate(prelude):
|
||||
"""Defines a function that takes a number n and a function f;
|
||||
returns a closure that takes an argument and applies f
|
||||
n times to its argument.
|
||||
|
||||
Signature: fn<a>(fn(a) -> a, nat) -> fn(a) -> a
|
||||
"""
|
||||
prelude.nat_iterate = GlobalVar("nat_iterate")
|
||||
a = TypeVar("a")
|
||||
f = Var("f", FuncType([a], a))
|
||||
x = Var("x", prelude.nat())
|
||||
y = Var("y", prelude.nat())
|
||||
|
||||
z_case = Clause(PatternConstructor(prelude.z), prelude.id)
|
||||
s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
|
||||
prelude.compose(f, prelude.nat_iterate(f, y)))
|
||||
|
||||
prelude.mod[prelude.nat_iterate] = Function([f, x],
|
||||
Match(x, [z_case, s_case]),
|
||||
FuncType([a], a),
|
||||
[a])
|
||||
|
||||
|
||||
def add_nat_definitions(prelude):
|
||||
"""Given a Relay prelude, adds a Peano nat ADT, as well as functions
|
||||
for adding nats and doubling nats. It also adds versions of
|
||||
update, nth, and iterate that take nats instead of scalars (the
|
||||
names are prefixed with 'nat_')."""
|
||||
define_nat_adt(prelude)
|
||||
define_nat_double(prelude)
|
||||
define_nat_add(prelude)
|
||||
define_nat_nth(prelude)
|
||||
define_nat_update(prelude)
|
||||
define_nat_iterate(prelude)
|
||||
|
||||
|
||||
# helper functions for working with nats
|
||||
|
||||
|
||||
def count(n):
|
||||
"""Takes a ConstructorValue corresponding to a nat ADT
|
||||
and converts it into a Python integer. This is an example of
|
||||
using an ADT value in Python.
|
||||
"""
|
||||
assert isinstance(n, ConstructorValue)
|
||||
if n.constructor.name_hint == 'z':
|
||||
return 0
|
||||
assert n.constructor.name_hint == 's'
|
||||
return 1 + count(n.fields[0])
|
||||
|
||||
|
||||
def make_nat_value(prelude, n):
|
||||
"""The inverse of count(): Given a non-negative Python integer,
|
||||
constructs a ConstructorValue representing that value as a nat.
|
||||
"""
|
||||
if n == 0:
|
||||
return ConstructorValue(prelude.z, [], [])
|
||||
return ConstructorValue(prelude.s, [make_nat_value(prelude, n - 1)], [])
|
||||
|
||||
|
||||
def make_nat_expr(prelude, n):
|
||||
"""Given a non-negative Python integer, constructs a Python
|
||||
expression representing that integer's value as a nat.
|
||||
"""
|
||||
assert n >= 0
|
||||
ret = prelude.z()
|
||||
while n > 0:
|
||||
ret = prelude.s(ret)
|
||||
n = n - 1
|
||||
return ret
|
|
@ -14,15 +14,19 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import numpy as np
|
||||
import tvm
|
||||
from tvm import relay
|
||||
from tvm.relay.ir_pass import infer_type
|
||||
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
|
||||
from tvm.relay import testing, create_executor
|
||||
from tvm.relay.prelude import Prelude
|
||||
from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
|
||||
|
||||
mod = relay.Module()
|
||||
p = Prelude(mod)
|
||||
add_nat_definitions(p)
|
||||
|
||||
ctx = tvm.context("llvm", 0)
|
||||
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
|
||||
|
||||
|
@ -67,15 +71,6 @@ size = p.size
|
|||
compose = p.compose
|
||||
iterate = p.iterate
|
||||
|
||||
# this is an example of using the adt value in python side
|
||||
def count(n):
|
||||
assert isinstance(n, ConstructorValue)
|
||||
if n.constructor.name_hint == 's':
|
||||
return 1 + count(n.fields[0])
|
||||
else:
|
||||
assert n.constructor.name_hint == 'z'
|
||||
return 0
|
||||
|
||||
# this is an example of creating the adt value in python side
|
||||
def make_nat(n):
|
||||
if n != 0:
|
||||
|
@ -83,7 +78,7 @@ def make_nat(n):
|
|||
else:
|
||||
return ConstructorValue(z, [], [])
|
||||
|
||||
def build_nat(n):
|
||||
def make_nat_expr(n):
|
||||
assert n >= 0
|
||||
ret = z()
|
||||
while n > 0:
|
||||
|
@ -115,8 +110,14 @@ def tree_to_dict(t):
|
|||
ret['children'].append(l)
|
||||
return ret
|
||||
|
||||
|
||||
# turns a scalar-valued relay tensor value into a python number
|
||||
def get_scalar(tv):
|
||||
return tv.asnumpy().item()
|
||||
|
||||
|
||||
def test_nat_value():
|
||||
assert count(make_nat(10)) == 10
|
||||
assert count(make_nat_value(p, 10)) == 10
|
||||
assert count(intrp.evaluate(s(s(z())))) == 2
|
||||
|
||||
|
||||
|
@ -145,7 +146,7 @@ def test_hd_tl():
|
|||
expected = list(range(10))
|
||||
l = nil()
|
||||
for i in reversed(expected):
|
||||
l = cons(build_nat(i), l)
|
||||
l = cons(make_nat_expr(i), l)
|
||||
|
||||
got = []
|
||||
for i in range(len(expected)):
|
||||
|
@ -158,36 +159,35 @@ def test_nth():
|
|||
expected = list(range(10))
|
||||
l = nil()
|
||||
for i in reversed(expected):
|
||||
l = cons(build_nat(i), l)
|
||||
l = cons(relay.const(i), l)
|
||||
|
||||
got = []
|
||||
for i in range(len(expected)):
|
||||
got.append(count(intrp.evaluate(nth(l, build_nat(i)))))
|
||||
item = intrp.evaluate(nth(l, relay.const(i)))
|
||||
assert get_scalar(item) == i
|
||||
|
||||
assert got == expected
|
||||
|
||||
def test_update():
|
||||
expected = list(range(10))
|
||||
l = nil()
|
||||
# create zero initialized list
|
||||
for i in range(len(expected)):
|
||||
l = cons(build_nat(0), l)
|
||||
l = cons(make_nat_expr(0), l)
|
||||
|
||||
# set value
|
||||
for i, v in enumerate(expected):
|
||||
l = update(l, build_nat(i), build_nat(v))
|
||||
l = update(l, relay.const(i), make_nat_expr(v))
|
||||
|
||||
got = []
|
||||
for i in range(len(expected)):
|
||||
got.append(count(intrp.evaluate(nth(l, build_nat(i)))))
|
||||
got.append(count(intrp.evaluate(nth(l, relay.const(i)))))
|
||||
|
||||
assert got == expected
|
||||
|
||||
def test_length():
|
||||
a = relay.TypeVar("a")
|
||||
assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a])
|
||||
assert mod[length].checked_type == relay.FuncType([l(a)], relay.scalar_type('int32'), [a])
|
||||
res = intrp.evaluate(length(cons(z(), cons(z(), cons(z(), nil())))))
|
||||
assert count(res) == 3
|
||||
assert get_scalar(res) == 3
|
||||
|
||||
|
||||
def test_map():
|
||||
|
@ -216,9 +216,9 @@ def test_foldl():
|
|||
y = relay.Var("y")
|
||||
rev_dup = relay.Function([y, x], cons(x, cons(x, y)))
|
||||
res = intrp.evaluate(foldl(rev_dup, nil(),
|
||||
cons(build_nat(1),
|
||||
cons(build_nat(2),
|
||||
cons(build_nat(3), nil())))))
|
||||
cons(make_nat_expr(1),
|
||||
cons(make_nat_expr(2),
|
||||
cons(make_nat_expr(3), nil())))))
|
||||
reversed = to_list(res)
|
||||
assert len(reversed) == 6
|
||||
assert count(reversed[0]) == 3 and count(reversed[1]) == 3
|
||||
|
@ -237,9 +237,9 @@ def test_foldr():
|
|||
y = relay.Var("y")
|
||||
identity = relay.Function([x, y], cons(x, y))
|
||||
res = intrp.evaluate(foldr(identity, nil(),
|
||||
cons(build_nat(1),
|
||||
cons(build_nat(2),
|
||||
cons(build_nat(3), nil())))))
|
||||
cons(make_nat_expr(1),
|
||||
cons(make_nat_expr(2),
|
||||
cons(make_nat_expr(3), nil())))))
|
||||
same = to_list(res)
|
||||
assert len(same) == 3
|
||||
assert count(same[0]) == 1 and count(same[1]) == 2 and count(same[2]) == 3
|
||||
|
@ -255,25 +255,25 @@ def test_foldr1():
|
|||
y = relay.Var("y")
|
||||
f = relay.Function([x, y], add(x, y))
|
||||
res = intrp.evaluate(foldr1(f,
|
||||
cons(build_nat(1),
|
||||
cons(build_nat(2),
|
||||
cons(build_nat(3), nil())))))
|
||||
cons(make_nat_expr(1),
|
||||
cons(make_nat_expr(2),
|
||||
cons(make_nat_expr(3), nil())))))
|
||||
|
||||
assert count(res) == 6
|
||||
|
||||
|
||||
def test_sum():
|
||||
assert mod[sum].checked_type == relay.FuncType([l(nat())], nat())
|
||||
res = intrp.evaluate(sum(cons(build_nat(1), cons(build_nat(2), nil()))))
|
||||
assert count(res) == 3
|
||||
assert mod[sum].checked_type == relay.FuncType([l(relay.scalar_type('int32'))], relay.scalar_type('int32'))
|
||||
res = intrp.evaluate(sum(cons(relay.const(1), cons(relay.const(2), nil()))))
|
||||
assert get_scalar(res) == 3
|
||||
|
||||
|
||||
def test_concat():
|
||||
a = relay.TypeVar("a")
|
||||
assert mod[concat].checked_type == relay.FuncType([l(a), l(a)], l(a), [a])
|
||||
|
||||
l1 = cons(build_nat(1), cons(build_nat(2), nil()))
|
||||
l2 = cons(build_nat(3), cons(build_nat(4), nil()))
|
||||
l1 = cons(make_nat_expr(1), cons(make_nat_expr(2), nil()))
|
||||
l2 = cons(make_nat_expr(3), cons(make_nat_expr(4), nil()))
|
||||
res = intrp.evaluate(concat(l1, l2))
|
||||
|
||||
catted = to_list(res)
|
||||
|
@ -305,12 +305,12 @@ def test_filter():
|
|||
]))
|
||||
res = intrp.evaluate(
|
||||
filter(greater_than_one,
|
||||
cons(build_nat(1),
|
||||
cons(build_nat(1),
|
||||
cons(build_nat(3),
|
||||
cons(build_nat(1),
|
||||
cons(build_nat(5),
|
||||
cons(build_nat(1),
|
||||
cons(make_nat_expr(1),
|
||||
cons(make_nat_expr(1),
|
||||
cons(make_nat_expr(3),
|
||||
cons(make_nat_expr(1),
|
||||
cons(make_nat_expr(5),
|
||||
cons(make_nat_expr(1),
|
||||
nil()))))))))
|
||||
filtered = to_list(res)
|
||||
assert len(filtered) == 2
|
||||
|
@ -325,7 +325,7 @@ def test_zip():
|
|||
l(relay.TupleType([a, b])), [a, b])
|
||||
assert mod[zip].checked_type == expected_type
|
||||
|
||||
l1 = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil())))
|
||||
l1 = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
|
||||
l2 = cons(nil(),
|
||||
cons(cons(nil(), nil()),
|
||||
cons(cons(nil(), cons(nil(), nil())),
|
||||
|
@ -342,7 +342,7 @@ def test_zip():
|
|||
assert len(to_list(zipped[2][1])) == 2
|
||||
|
||||
# test truncation
|
||||
l3 = cons(build_nat(4), cons(build_nat(5), nil()))
|
||||
l3 = cons(make_nat_expr(4), cons(make_nat_expr(5), nil()))
|
||||
shorter_res = intrp.evaluate(zip(l3, l2))
|
||||
truncated = to_list(shorter_res)
|
||||
assert len(truncated) == 2
|
||||
|
@ -363,9 +363,9 @@ def test_rev():
|
|||
a = relay.TypeVar("a")
|
||||
assert mod[rev].checked_type == relay.FuncType([l(a)], l(a), [a])
|
||||
|
||||
res = intrp.evaluate(rev(cons(build_nat(1),
|
||||
cons(build_nat(2),
|
||||
cons(build_nat(3), nil())))))
|
||||
res = intrp.evaluate(rev(cons(make_nat_expr(1),
|
||||
cons(make_nat_expr(2),
|
||||
cons(make_nat_expr(3), nil())))))
|
||||
reversed = to_list(res)
|
||||
|
||||
assert len(reversed) == 3
|
||||
|
@ -392,7 +392,7 @@ def test_unfoldr():
|
|||
relay.Clause(relay.PatternConstructor(z, []), none())
|
||||
]))
|
||||
|
||||
res = intrp.evaluate(unfoldr(count_down, build_nat(3)))
|
||||
res = intrp.evaluate(unfoldr(count_down, make_nat_expr(3)))
|
||||
unfolded = to_list(res)
|
||||
|
||||
assert len(unfolded) == 3
|
||||
|
@ -419,7 +419,7 @@ def test_unfoldl():
|
|||
relay.Clause(relay.PatternConstructor(z, []), none())
|
||||
]))
|
||||
|
||||
res = intrp.evaluate(unfoldl(count_down, build_nat(3)))
|
||||
res = intrp.evaluate(unfoldl(count_down, make_nat_expr(3)))
|
||||
unfolded = to_list(res)
|
||||
|
||||
assert len(unfolded) == 3
|
||||
|
@ -444,7 +444,7 @@ def test_map_accumr():
|
|||
relay.Tuple([add(x, acc),
|
||||
add(x, acc)]))
|
||||
|
||||
vals = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil())))
|
||||
vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
|
||||
res = intrp.evaluate(map_accumr(add_acc_to_each, z(), vals))
|
||||
|
||||
sum = count(res[0])
|
||||
|
@ -472,7 +472,7 @@ def test_map_accuml():
|
|||
add_to_acc = relay.Function([acc, x],
|
||||
relay.Tuple([add(x, acc), x]))
|
||||
|
||||
vals = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil())))
|
||||
vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
|
||||
res = intrp.evaluate(map_accuml(add_to_acc, z(), vals))
|
||||
|
||||
sum = count(res[0])
|
||||
|
@ -497,8 +497,8 @@ def test_optional_matching():
|
|||
]))
|
||||
|
||||
res = intrp.evaluate(foldr(condense, nil(), cons(
|
||||
some(build_nat(3)),
|
||||
cons(none(), cons(some(build_nat(1)), nil())))))
|
||||
some(make_nat_expr(3)),
|
||||
cons(none(), cons(some(make_nat_expr(1)), nil())))))
|
||||
|
||||
reduced = to_list(res)
|
||||
assert len(reduced) == 2
|
||||
|
@ -532,7 +532,7 @@ def test_tmap():
|
|||
def test_size():
|
||||
a = relay.TypeVar("a")
|
||||
lhs = mod[size].checked_type
|
||||
rhs = relay.FuncType([tree(a)], nat(), [a])
|
||||
rhs = relay.FuncType([tree(a)], relay.scalar_type('int32'), [a])
|
||||
assert lhs == rhs
|
||||
|
||||
root = rose(z(), cons(rose(z(), nil()),
|
||||
|
@ -540,7 +540,7 @@ def test_size():
|
|||
nil())))
|
||||
t = rose(z(), cons(root, cons(root, cons(root, nil()))))
|
||||
res = intrp.evaluate(size(t))
|
||||
assert count(res) == 10
|
||||
assert get_scalar(res) == 10
|
||||
|
||||
|
||||
def test_wildcard_match_solo():
|
||||
|
@ -601,10 +601,10 @@ def test_nested_matches():
|
|||
inner_match)
|
||||
]), l(a), [a])
|
||||
|
||||
first_list = cons(build_nat(1), cons(build_nat(2),
|
||||
cons(build_nat(3), nil())))
|
||||
second_list = cons(build_nat(4), cons(build_nat(5),
|
||||
cons(build_nat(6), nil())))
|
||||
first_list = cons(make_nat_expr(1), cons(make_nat_expr(2),
|
||||
cons(make_nat_expr(3), nil())))
|
||||
second_list = cons(make_nat_expr(4), cons(make_nat_expr(5),
|
||||
cons(make_nat_expr(6), nil())))
|
||||
final_list = cons(first_list, cons(second_list, nil()))
|
||||
|
||||
res = intrp.evaluate(flatten(final_list))
|
||||
|
@ -660,6 +660,7 @@ def test_nested_pattern_match():
|
|||
|
||||
assert count(res) == 2
|
||||
|
||||
|
||||
def test_compose():
|
||||
n = relay.Var('n')
|
||||
inc = relay.Function([n], s(n))
|
||||
|
@ -667,11 +668,13 @@ def test_compose():
|
|||
res = intrp.evaluate(relay.Call(compose(inc, double), [s(s(z()))]))
|
||||
assert count(res) == 5
|
||||
|
||||
|
||||
def test_iterate():
|
||||
expr = relay.Call(iterate(double, build_nat(2)), [build_nat(3)])
|
||||
expr = relay.Call(iterate(double, relay.const(2)), [make_nat_expr(3)])
|
||||
res = intrp.evaluate(relay.Function([], expr)())
|
||||
assert count(res) == 12
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_nat_constructor()
|
||||
test_double()
|
||||
|
|
|
@ -53,10 +53,12 @@ def test_adt():
|
|||
mod = relay.Module()
|
||||
p = Prelude(mod)
|
||||
x = relay.Var("x")
|
||||
s_case = relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]), x)
|
||||
some_case = relay.Clause(relay.PatternConstructor(p.some,
|
||||
[relay.PatternVar(x)]),
|
||||
x)
|
||||
default_case = relay.Clause(relay.PatternVar(x), x)
|
||||
m0 = relay.Match(p.z(), [default_case])
|
||||
m1 = relay.Match(p.z(), [s_case, default_case])
|
||||
m0 = relay.Match(p.none(), [default_case])
|
||||
m1 = relay.Match(p.none(), [some_case, default_case])
|
||||
assert well_formed(m0)
|
||||
assert not well_formed(m1)
|
||||
|
||||
|
|
|
@ -521,7 +521,7 @@ def test_match_alpha_equal():
|
|||
relay.PatternVar(a)]),
|
||||
p.cons(z, a))
|
||||
|
||||
data = p.cons(p.z(), p.cons(p.z(), p.nil()))
|
||||
data = p.cons(relay.const(1), p.cons(relay.const(2), p.nil()))
|
||||
|
||||
match = relay.Match(data, [nil_case, cons_case])
|
||||
equivalent = relay.Match(data, [nil_case, equivalent_cons])
|
||||
|
@ -547,8 +547,8 @@ def test_match_alpha_equal():
|
|||
relay.Clause(relay.PatternWildcard(), p.nil())
|
||||
])
|
||||
wrong_constructors = relay.Match(data, [
|
||||
relay.Clause(relay.PatternConstructor(p.z), p.nil()),
|
||||
relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]),
|
||||
relay.Clause(relay.PatternConstructor(p.none), p.nil()),
|
||||
relay.Clause(relay.PatternConstructor(p.some, [relay.PatternVar(x)]),
|
||||
p.cons(x, p.nil()))
|
||||
])
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ from tvm import relay
|
|||
from tvm.relay.ir_pass import free_vars, free_type_vars, gradient
|
||||
from tvm.relay import create_executor
|
||||
from tvm.relay.prelude import Prelude
|
||||
from tvm.relay.testing import add_nat_definitions, make_nat_expr
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -174,13 +175,14 @@ def test_tuple():
|
|||
def test_pow():
|
||||
mod = relay.Module()
|
||||
p = Prelude(mod)
|
||||
add_nat_definitions(p)
|
||||
shape = (10, 10)
|
||||
dtype = 'float32'
|
||||
t = relay.TensorType(shape, dtype)
|
||||
x = relay.var("x", t)
|
||||
double = relay.Function([x], x + x)
|
||||
i = relay.var("i", t)
|
||||
func = relay.Function([i], relay.Call(p.iterate(double, p.s(p.s(p.s(p.z())))), [i]))
|
||||
func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
|
||||
back_func = relay.ir_pass.infer_type(gradient(func, mod=mod), mod=mod)
|
||||
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
|
||||
i_nd = rand(dtype, *shape)
|
||||
|
|
|
@ -21,6 +21,7 @@ from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type
|
|||
from tvm.relay import op, create_executor
|
||||
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
|
||||
from tvm.relay.prelude import Prelude
|
||||
from tvm.relay.testing import add_nat_definitions, count
|
||||
|
||||
|
||||
def check_eval(expr, expected_result, mod=None, rtol=1e-07):
|
||||
|
@ -130,19 +131,10 @@ def test_ref():
|
|||
check_eval(to_a_normal_form(body), 3)
|
||||
|
||||
|
||||
# this is an example of using the adt value in python side
|
||||
def count(n):
|
||||
assert isinstance(n, ConstructorValue)
|
||||
if n.constructor.name_hint == 's':
|
||||
return 1 + count(n.fields[0])
|
||||
else:
|
||||
assert n.constructor.name_hint == 'z'
|
||||
return 0
|
||||
|
||||
|
||||
def test_add():
|
||||
def test_nat_add():
|
||||
mod = relay.Module()
|
||||
p = Prelude(mod)
|
||||
add_nat_definitions(p)
|
||||
nat = p.nat
|
||||
add = p.add
|
||||
s = p.s
|
||||
|
@ -183,4 +175,5 @@ if __name__ == '__main__':
|
|||
test_ref()
|
||||
test_add()
|
||||
test_let()
|
||||
test_nat_add()
|
||||
test_function()
|
||||
|
|
Загрузка…
Ссылка в новой задаче