Support export ADT value in Python (#3299)
* Support export ADT value in Python * Cache original functions * Cleanup * Cleanup
This commit is contained in:
Родитель
b67afcd6b9
Коммит
713fc73bda
|
@ -182,17 +182,22 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
|
||||||
class ConstructorValue;
|
class ConstructorValue;
|
||||||
|
|
||||||
struct ConstructorValueNode : ValueNode {
|
struct ConstructorValueNode : ValueNode {
|
||||||
Constructor constructor;
|
int tag;
|
||||||
|
|
||||||
tvm::Array<Value> fields;
|
tvm::Array<Value> fields;
|
||||||
|
|
||||||
|
/*! \brief Optional field tracking ADT constructor. */
|
||||||
|
Constructor constructor;
|
||||||
|
|
||||||
void VisitAttrs(tvm::AttrVisitor* v) final {
|
void VisitAttrs(tvm::AttrVisitor* v) final {
|
||||||
v->Visit("constructor", &constructor);
|
v->Visit("tag", &tag);
|
||||||
v->Visit("fields", &fields);
|
v->Visit("fields", &fields);
|
||||||
|
v->Visit("constructor", &constructor);
|
||||||
}
|
}
|
||||||
|
|
||||||
TVM_DLL static ConstructorValue make(Constructor constructor,
|
TVM_DLL static ConstructorValue make(int tag,
|
||||||
tvm::Array<Value> fields);
|
tvm::Array<Value> fields,
|
||||||
|
Constructor construtor = {});
|
||||||
|
|
||||||
static constexpr const char* _type_key = "relay.ConstructorValue";
|
static constexpr const char* _type_key = "relay.ConstructorValue";
|
||||||
TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode);
|
TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode);
|
||||||
|
|
|
@ -73,9 +73,9 @@ class Closure(Value):
|
||||||
|
|
||||||
@register_relay_node
|
@register_relay_node
|
||||||
class ConstructorValue(Value):
|
class ConstructorValue(Value):
|
||||||
def __init__(self, constructor, fields, types):
|
def __init__(self, tag, fields, constructor, types):
|
||||||
self.__init_handle_by_constructor__(
|
self.__init_handle_by_constructor__(
|
||||||
_make.ConstructorValue, constructor, fields, types)
|
_make.ConstructorValue, tag, fields, constructor, types)
|
||||||
|
|
||||||
|
|
||||||
@register_relay_node
|
@register_relay_node
|
||||||
|
|
|
@ -97,7 +97,6 @@ def _eval_vm(mod, ctx, *args):
|
||||||
args: List[tvm.NDArray, np.ndarray]
|
args: List[tvm.NDArray, np.ndarray]
|
||||||
The arguments to evaluate.
|
The arguments to evaluate.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
mod = optimize(mod)
|
mod = optimize(mod)
|
||||||
args = list(args)
|
args = list(args)
|
||||||
assert isinstance(args, list)
|
assert isinstance(args, list)
|
||||||
|
|
|
@ -491,7 +491,6 @@ class Prelude:
|
||||||
def __init__(self, mod):
|
def __init__(self, mod):
|
||||||
self.mod = mod
|
self.mod = mod
|
||||||
self.load_prelude()
|
self.load_prelude()
|
||||||
|
|
||||||
self.define_list_adt()
|
self.define_list_adt()
|
||||||
self.define_list_hd()
|
self.define_list_hd()
|
||||||
self.define_list_tl()
|
self.define_list_tl()
|
||||||
|
|
|
@ -151,16 +151,16 @@ def add_nat_definitions(prelude):
|
||||||
# helper functions for working with nats
|
# helper functions for working with nats
|
||||||
|
|
||||||
|
|
||||||
def count(n):
|
def count(prelude, n):
|
||||||
"""Takes a ConstructorValue corresponding to a nat ADT
|
"""Takes a ConstructorValue corresponding to a nat ADT
|
||||||
and converts it into a Python integer. This is an example of
|
and converts it into a Python integer. This is an example of
|
||||||
using an ADT value in Python.
|
using an ADT value in Python.
|
||||||
"""
|
"""
|
||||||
assert isinstance(n, ConstructorValue)
|
assert isinstance(n, ConstructorValue)
|
||||||
if n.constructor.name_hint == 'z':
|
if n.tag == prelude.z.tag:
|
||||||
return 0
|
return 0
|
||||||
assert n.constructor.name_hint == 's'
|
assert n.tag == prelude.s.tag
|
||||||
return 1 + count(n.fields[0])
|
return 1 + count(prelude, n.fields[0])
|
||||||
|
|
||||||
|
|
||||||
def make_nat_value(prelude, n):
|
def make_nat_value(prelude, n):
|
||||||
|
@ -168,8 +168,8 @@ def make_nat_value(prelude, n):
|
||||||
constructs a ConstructorValue representing that value as a nat.
|
constructs a ConstructorValue representing that value as a nat.
|
||||||
"""
|
"""
|
||||||
if n == 0:
|
if n == 0:
|
||||||
return ConstructorValue(prelude.z, [], [])
|
return ConstructorValue(prelude.z.tag, [], None, [])
|
||||||
return ConstructorValue(prelude.s, [make_nat_value(prelude, n - 1)], [])
|
return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None, [])
|
||||||
|
|
||||||
|
|
||||||
def make_nat_expr(prelude, n):
|
def make_nat_expr(prelude, n):
|
||||||
|
|
|
@ -103,11 +103,13 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||||
p->stream << "RefValueNode(" << node->value << ")";
|
p->stream << "RefValueNode(" << node->value << ")";
|
||||||
});
|
});
|
||||||
|
|
||||||
ConstructorValue ConstructorValueNode::make(Constructor constructor,
|
ConstructorValue ConstructorValueNode::make(int tag,
|
||||||
tvm::Array<Value> fields) {
|
tvm::Array<Value> fields,
|
||||||
|
Constructor constructor) {
|
||||||
NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
|
NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
|
||||||
n->constructor = constructor;
|
n->tag = tag;
|
||||||
n->fields = fields;
|
n->fields = fields;
|
||||||
|
n->constructor = constructor;
|
||||||
return ConstructorValue(n);
|
return ConstructorValue(n);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,7 +119,7 @@ TVM_REGISTER_API("relay._make.ConstructorValue")
|
||||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||||
.set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
|
.set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
|
||||||
tvm::IRPrinter* p) {
|
tvm::IRPrinter* p) {
|
||||||
p->stream << "ConstructorValueNode(" << node->constructor
|
p->stream << "ConstructorValueNode(" << node->tag << ","
|
||||||
<< node->fields << ")";
|
<< node->fields << ")";
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -448,7 +450,7 @@ class Interpreter :
|
||||||
"fusing and lowering";
|
"fusing and lowering";
|
||||||
}
|
}
|
||||||
if (auto con = call->op.as<ConstructorNode>()) {
|
if (auto con = call->op.as<ConstructorNode>()) {
|
||||||
return ConstructorValueNode::make(GetRef<Constructor>(con), args);
|
return ConstructorValueNode::make(con->tag, args, GetRef<Constructor>(con));
|
||||||
}
|
}
|
||||||
// Now we just evaluate and expect to find a closure.
|
// Now we just evaluate and expect to find a closure.
|
||||||
Value fn_val = Eval(call->op);
|
Value fn_val = Eval(call->op);
|
||||||
|
@ -544,9 +546,8 @@ class Interpreter :
|
||||||
const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
|
const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
|
||||||
CHECK(cvn) << "need to be a constructor for match";
|
CHECK(cvn) << "need to be a constructor for match";
|
||||||
CHECK_NE(op->constructor->tag, -1);
|
CHECK_NE(op->constructor->tag, -1);
|
||||||
CHECK_NE(cvn->constructor->tag, -1);
|
CHECK_NE(cvn->tag, -1);
|
||||||
if (op->constructor->tag == cvn->constructor->tag) {
|
if (op->constructor->tag == cvn->tag) {
|
||||||
// todo(M.K.): should use ptr equality but it is broken
|
|
||||||
CHECK_EQ(op->patterns.size(), cvn->fields.size());
|
CHECK_EQ(op->patterns.size(), cvn->fields.size());
|
||||||
for (size_t i = 0; i < op->patterns.size(); ++i) {
|
for (size_t i = 0; i < op->patterns.size(); ++i) {
|
||||||
if (!VisitPattern(op->patterns[i], cvn->fields[i])) {
|
if (!VisitPattern(op->patterns[i], cvn->fields[i])) {
|
||||||
|
|
|
@ -80,6 +80,8 @@ struct VMCompilerContext {
|
||||||
ConstTensorShapeMap const_tensor_shape_map;
|
ConstTensorShapeMap const_tensor_shape_map;
|
||||||
// List of lowered functions
|
// List of lowered functions
|
||||||
std::vector<LoweredFunc> lowered_funcs;
|
std::vector<LoweredFunc> lowered_funcs;
|
||||||
|
// The functions that have been lowered.
|
||||||
|
std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Compute the constant pool, i.e a mapping from Constant node to constant index.
|
// Compute the constant pool, i.e a mapping from Constant node to constant index.
|
||||||
|
@ -184,9 +186,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
|
||||||
size_t registers_num;
|
size_t registers_num;
|
||||||
CompileEngine engine;
|
CompileEngine engine;
|
||||||
|
|
||||||
/*! \brief The functions that have been lowered. */
|
|
||||||
std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
|
|
||||||
|
|
||||||
/*! \brief Global shared meta data */
|
/*! \brief Global shared meta data */
|
||||||
VMCompilerContext* context;
|
VMCompilerContext* context;
|
||||||
|
|
||||||
|
@ -260,7 +259,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
|
||||||
|
|
||||||
void VisitExpr_(const MatchNode* match_node) {
|
void VisitExpr_(const MatchNode* match_node) {
|
||||||
auto match = GetRef<Match>(match_node);
|
auto match = GetRef<Match>(match_node);
|
||||||
LOG(FATAL) << "translation of match nodes to the VM is"
|
LOG(FATAL) << "translation of match nodes to the VM is "
|
||||||
<< "currently unsupported" << std::endl;
|
<< "currently unsupported" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -280,7 +279,8 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
|
||||||
}
|
}
|
||||||
|
|
||||||
void VisitExpr_(const GlobalVarNode* gvar) {
|
void VisitExpr_(const GlobalVarNode* gvar) {
|
||||||
LOG(FATAL) << "Global variables should only appear in the call position";
|
// TODO(wweic): Support Load GlobalVar into a register
|
||||||
|
LOG(FATAL) << "Loading GlobalVar into register is not yet supported";
|
||||||
}
|
}
|
||||||
|
|
||||||
void VisitExpr_(const IfNode* if_node) {
|
void VisitExpr_(const IfNode* if_node) {
|
||||||
|
@ -405,12 +405,12 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
|
||||||
// TODO(jroesch): support lowered funcs for multiple targets
|
// TODO(jroesch): support lowered funcs for multiple targets
|
||||||
CHECK_EQ(cfunc->funcs.size(), 1);
|
CHECK_EQ(cfunc->funcs.size(), 1);
|
||||||
auto op_index = -1;
|
auto op_index = -1;
|
||||||
if (seen_funcs.find(cfunc->funcs[0]) == seen_funcs.end()) {
|
if (this->context->seen_funcs.find(cfunc->funcs[0]) == this->context->seen_funcs.end()) {
|
||||||
op_index = this->context->lowered_funcs.size();
|
op_index = this->context->lowered_funcs.size();
|
||||||
this->context->lowered_funcs.push_back(cfunc->funcs[0]);
|
this->context->lowered_funcs.push_back(cfunc->funcs[0]);
|
||||||
seen_funcs[cfunc->funcs[0]] = op_index;
|
this->context->seen_funcs[cfunc->funcs[0]] = op_index;
|
||||||
} else {
|
} else {
|
||||||
op_index = seen_funcs[cfunc->funcs[0]];
|
op_index = this->context->seen_funcs[cfunc->funcs[0]];
|
||||||
}
|
}
|
||||||
|
|
||||||
Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
|
Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
|
||||||
|
@ -429,7 +429,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
|
||||||
std::vector<Index> args_registers;
|
std::vector<Index> args_registers;
|
||||||
|
|
||||||
for (auto arg : call_node->args) {
|
for (auto arg : call_node->args) {
|
||||||
CHECK(arg.as<VarNode>()) << "found: " << AsText(arg, false) << std::endl << arg;
|
|
||||||
this->VisitExpr(arg);
|
this->VisitExpr(arg);
|
||||||
args_registers.push_back(last_register);
|
args_registers.push_back(last_register);
|
||||||
}
|
}
|
||||||
|
@ -449,18 +448,14 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
|
||||||
auto func = this->context->module->Lookup(global);
|
auto func = this->context->module->Lookup(global);
|
||||||
if (IsClosure(func)) {
|
if (IsClosure(func)) {
|
||||||
auto arity = func->params.size();
|
auto arity = func->params.size();
|
||||||
std::vector<Index> free_var_registers;
|
Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
|
||||||
for (size_t i = 0; i < arity; ++i) {
|
|
||||||
free_var_registers.push_back(var_register_map.at(func->params[i]));
|
|
||||||
}
|
|
||||||
Emit(Instruction::AllocClosure(it->second, arity, free_var_registers, NewRegister()));
|
|
||||||
} else {
|
} else {
|
||||||
Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
|
Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
|
||||||
}
|
}
|
||||||
} else if (auto constructor_node = op.as<ConstructorNode>()) {
|
} else if (auto constructor_node = op.as<ConstructorNode>()) {
|
||||||
auto constructor = GetRef<Constructor>(constructor_node);
|
auto constructor = GetRef<Constructor>(constructor_node);
|
||||||
auto tag = GetConstructorTag(constructor);
|
Emit(Instruction::AllocDatatype(constructor->tag, call_node->args.size(), args_registers,
|
||||||
Emit(Instruction::AllocDatatype(tag, call_node->args.size(), args_registers, NewRegister()));
|
NewRegister()));
|
||||||
} else if (auto var_node = op.as<VarNode>()) {
|
} else if (auto var_node = op.as<VarNode>()) {
|
||||||
VisitExpr(GetRef<Var>(var_node));
|
VisitExpr(GetRef<Var>(var_node));
|
||||||
Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister()));
|
Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister()));
|
||||||
|
@ -469,18 +464,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t GetConstructorTag(tvm::relay::Constructor constructor) {
|
|
||||||
auto it = this->context->tag_map.find(constructor);
|
|
||||||
if (it != this->context->tag_map.end()) {
|
|
||||||
return it->second;
|
|
||||||
} else {
|
|
||||||
auto tag = this->context->tag_map.size();
|
|
||||||
this->context->tag_map[constructor] = tag;
|
|
||||||
this->context->tag_index_map[tag] = constructor;
|
|
||||||
return tag;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void VisitExpr_(const FunctionNode* func_node) {
|
void VisitExpr_(const FunctionNode* func_node) {
|
||||||
if (!func_node->IsPrimitive()) {
|
if (!func_node->IsPrimitive()) {
|
||||||
LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
|
LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
|
||||||
|
@ -549,7 +532,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
|
||||||
}
|
}
|
||||||
|
|
||||||
VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) {
|
VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) {
|
||||||
DLOG(INFO) << "CompileFunc: " << std::endl << AsText(func, false) << std::endl;
|
DLOG(INFO) << "CompileFunc: " << var << std::endl << AsText(func, false) << std::endl;
|
||||||
size_t params = func->params.size();
|
size_t params = func->params.size();
|
||||||
VMCompiler compiler(context);
|
VMCompiler compiler(context);
|
||||||
compiler.Compile(func);
|
compiler.Compile(func);
|
||||||
|
|
|
@ -63,24 +63,21 @@ Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
Value VMToValue(const relay::Module& module, const relay::Type& type, Object obj) {
|
Value VMToValue(const relay::Module& module, Object obj) {
|
||||||
CHECK(module.defined() && type.defined());
|
CHECK(module.defined());
|
||||||
switch (obj->tag) {
|
switch (obj->tag) {
|
||||||
case ObjectTag::kTensor: {
|
case ObjectTag::kTensor: {
|
||||||
CHECK(type.as<TensorTypeNode>()) << "VM internal error: return value must be a tensor";
|
|
||||||
return TensorValueNode::make(ToNDArray(obj));
|
return TensorValueNode::make(ToNDArray(obj));
|
||||||
}
|
}
|
||||||
case ObjectTag::kDatatype: {
|
case ObjectTag::kDatatype: {
|
||||||
// const auto* tuple_type
|
const auto& data_type = obj.AsDatatype();
|
||||||
// const auto& data_type = obj.AsDatatype();
|
|
||||||
|
|
||||||
// tvm::Array<Value> fields;
|
tvm::Array<Value> fields;
|
||||||
// for (size_t i = 0; i < data_type->fields.size(); ++i) {
|
for (size_t i = 0; i < data_type->fields.size(); ++i) {
|
||||||
// fields.push_back(VMToValue(tag_index_map, data_type->fields[i]));
|
fields.push_back(VMToValue(module, data_type->fields[i]));
|
||||||
// }
|
}
|
||||||
|
|
||||||
// return ConstructorValueNode::make(tag_index_map.at(data_type->tag), fields);
|
return ConstructorValueNode::make(data_type->tag, fields);
|
||||||
LOG(FATAL) << "fix me";
|
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "unsupported return value of type: " << obj->tag;
|
LOG(FATAL) << "unsupported return value of type: " << obj->tag;
|
||||||
|
@ -141,8 +138,6 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue
|
||||||
LOG(FATAL) << "expected function or module";
|
LOG(FATAL) << "expected function or module";
|
||||||
}
|
}
|
||||||
|
|
||||||
auto return_type = module->Lookup(module->entry_func)->ret_type;
|
|
||||||
|
|
||||||
std::vector<Object> vm_args;
|
std::vector<Object> vm_args;
|
||||||
for (auto i = 3; i < args.size(); i++) {
|
for (auto i = 3; i < args.size(); i++) {
|
||||||
Object obj = args[i];
|
Object obj = args[i];
|
||||||
|
@ -151,7 +146,7 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue
|
||||||
|
|
||||||
auto result = EvaluateModule(module, {ctx}, vm_args);
|
auto result = EvaluateModule(module, {ctx}, vm_args);
|
||||||
DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
|
DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
|
||||||
*ret = VMToValue(module, return_type, result);
|
*ret = VMToValue(module, result);
|
||||||
});
|
});
|
||||||
|
|
||||||
} // namespace vm
|
} // namespace vm
|
||||||
|
|
|
@ -316,7 +316,8 @@ Module FunctionPassNode::operator()(const Module& mod,
|
||||||
Module updated_mod = mod;
|
Module updated_mod = mod;
|
||||||
// Execute the pass function and return a new module.
|
// Execute the pass function and return a new module.
|
||||||
std::vector<std::pair<GlobalVar, Function> > updates;
|
std::vector<std::pair<GlobalVar, Function> > updates;
|
||||||
for (const auto& it : mod->functions) {
|
auto original = mod->functions;
|
||||||
|
for (const auto& it : original) {
|
||||||
auto updated_func = SkipFunction(it.second)
|
auto updated_func = SkipFunction(it.second)
|
||||||
? it.second
|
? it.second
|
||||||
: pass_func(it.second, updated_mod, pass_ctx);
|
: pass_func(it.second, updated_mod, pass_ctx);
|
||||||
|
|
|
@ -21,12 +21,15 @@ from tvm.relay.ir_pass import infer_type
|
||||||
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
|
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
|
||||||
from tvm.relay import testing, create_executor
|
from tvm.relay import testing, create_executor
|
||||||
from tvm.relay.prelude import Prelude
|
from tvm.relay.prelude import Prelude
|
||||||
from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
|
from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr
|
||||||
|
|
||||||
mod = relay.Module()
|
mod = relay.Module()
|
||||||
p = Prelude(mod)
|
p = Prelude(mod)
|
||||||
add_nat_definitions(p)
|
add_nat_definitions(p)
|
||||||
|
|
||||||
|
def count(e):
|
||||||
|
return count_(p, e)
|
||||||
|
|
||||||
ctx = tvm.context("llvm", 0)
|
ctx = tvm.context("llvm", 0)
|
||||||
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
|
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
|
||||||
|
|
||||||
|
@ -91,18 +94,18 @@ def to_list(l):
|
||||||
val = l
|
val = l
|
||||||
ret = []
|
ret = []
|
||||||
while True:
|
while True:
|
||||||
if val.constructor.name_hint == 'cons':
|
if val.tag == p.cons.tag:
|
||||||
ret.append(val.fields[0])
|
ret.append(val.fields[0])
|
||||||
val = val.fields[1]
|
val = val.fields[1]
|
||||||
else:
|
else:
|
||||||
assert val.constructor.name_hint == 'nil'
|
assert val.tag == p.nil.tag
|
||||||
break
|
break
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def tree_to_dict(t):
|
def tree_to_dict(t):
|
||||||
assert isinstance(t, ConstructorValue)
|
assert isinstance(t, ConstructorValue)
|
||||||
ret = {}
|
ret = {}
|
||||||
assert t.constructor.name_hint == 'rose'
|
assert t.tag == p.rose.tag
|
||||||
ret['member'] = t.fields[0]
|
ret['member'] = t.fields[0]
|
||||||
ret['children'] = []
|
ret['children'] = []
|
||||||
for subtree in to_list(t.fields[1]):
|
for subtree in to_list(t.fields[1]):
|
||||||
|
|
|
@ -183,11 +183,11 @@ def test_function_taking_adt_ref_tuple():
|
||||||
prelude = relay.prelude.Prelude(mod)
|
prelude = relay.prelude.Prelude(mod)
|
||||||
intrp = create_executor("debug", mod)
|
intrp = create_executor("debug", mod)
|
||||||
|
|
||||||
nil_value = ConstructorValue(prelude.nil, [], [])
|
nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil, [])
|
||||||
cons_value = ConstructorValue(prelude.cons, [
|
cons_value = ConstructorValue(prelude.cons.tag, [
|
||||||
TensorValue(np.random.rand(1, 10).astype('float32')),
|
TensorValue(np.random.rand(1, 10).astype('float32')),
|
||||||
nil_value
|
nil_value
|
||||||
], [relay.TensorType((1, 10), 'float32')])
|
], prelude.cons, [relay.TensorType((1, 10), 'float32')])
|
||||||
|
|
||||||
ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32')))
|
ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32')))
|
||||||
tuple_value = TupleValue(*[
|
tuple_value = TupleValue(*[
|
||||||
|
@ -197,16 +197,16 @@ def test_function_taking_adt_ref_tuple():
|
||||||
id_func = intrp.evaluate(prelude.id)
|
id_func = intrp.evaluate(prelude.id)
|
||||||
|
|
||||||
res_nil = id_func(nil_value)
|
res_nil = id_func(nil_value)
|
||||||
assert res_nil.constructor == nil_value.constructor
|
assert res_nil.tag == nil_value.tag
|
||||||
assert len(res_nil.fields) == 0
|
assert len(res_nil.fields) == 0
|
||||||
|
|
||||||
res_cons = id_func(cons_value)
|
res_cons = id_func(cons_value)
|
||||||
assert res_cons.constructor == cons_value.constructor
|
assert res_cons.tag == cons_value.tag
|
||||||
assert len(res_cons.fields) == len(cons_value.fields)
|
assert len(res_cons.fields) == len(cons_value.fields)
|
||||||
tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(),
|
tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(),
|
||||||
cons_value.fields[0].asnumpy())
|
cons_value.fields[0].asnumpy())
|
||||||
assert isinstance(res_cons.fields[1], ConstructorValue)
|
assert isinstance(res_cons.fields[1], ConstructorValue)
|
||||||
assert res_cons.fields[1].constructor == prelude.nil
|
assert res_cons.fields[1].tag == prelude.nil.tag
|
||||||
assert len(res_cons.fields[1].fields) == 0
|
assert len(res_cons.fields[1].fields) == 0
|
||||||
|
|
||||||
res_ref = id_func(ref_value)
|
res_ref = id_func(ref_value)
|
||||||
|
|
|
@ -142,8 +142,8 @@ def test_nat_add():
|
||||||
ctx = tvm.context("llvm", 0)
|
ctx = tvm.context("llvm", 0)
|
||||||
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
|
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
|
||||||
assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
|
assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
|
||||||
assert count(intrp.evaluate(add(s(z()), s(z())))) == 2
|
assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
|
||||||
assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
|
assert count(p, intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
|
||||||
assert "let" in mod[add].astext()
|
assert "let" in mod[add].astext()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -185,9 +185,7 @@ def test_tuple_second():
|
||||||
result = veval(f, (i_data, j_data))
|
result = veval(f, (i_data, j_data))
|
||||||
tvm.testing.assert_allclose(result.asnumpy(), j_data)
|
tvm.testing.assert_allclose(result.asnumpy(), j_data)
|
||||||
|
|
||||||
@nottest
|
|
||||||
def test_list_constructor():
|
def test_list_constructor():
|
||||||
# TODO(wweic): implement pattern match to support this test
|
|
||||||
def to_list(o):
|
def to_list(o):
|
||||||
if isinstance(o, tvm.relay.backend.interpreter.TensorValue):
|
if isinstance(o, tvm.relay.backend.interpreter.TensorValue):
|
||||||
return [o.data.asnumpy().tolist()]
|
return [o.data.asnumpy().tolist()]
|
||||||
|
@ -204,6 +202,11 @@ def test_list_constructor():
|
||||||
cons = p.cons
|
cons = p.cons
|
||||||
l = p.l
|
l = p.l
|
||||||
|
|
||||||
|
# remove all functions to not have pattern match to pass vm compilation
|
||||||
|
# TODO(wweic): remove the hack and implement pattern match
|
||||||
|
for v, _ in mod.functions.items():
|
||||||
|
mod[v] = relay.const(0)
|
||||||
|
|
||||||
one2 = cons(relay.const(1), nil())
|
one2 = cons(relay.const(1), nil())
|
||||||
one3 = cons(relay.const(2), one2)
|
one3 = cons(relay.const(2), one2)
|
||||||
one4 = cons(relay.const(3), one3)
|
one4 = cons(relay.const(3), one3)
|
||||||
|
@ -213,7 +216,6 @@ def test_list_constructor():
|
||||||
|
|
||||||
result = veval(mod)()
|
result = veval(mod)()
|
||||||
obj = to_list(result)
|
obj = to_list(result)
|
||||||
import pdb; pdb.set_trace()
|
|
||||||
tvm.testing.assert_allclose(obj, np.array([3,2,1]))
|
tvm.testing.assert_allclose(obj, np.array([3,2,1]))
|
||||||
|
|
||||||
def test_let_tensor():
|
def test_let_tensor():
|
||||||
|
|
Загрузка…
Ссылка в новой задаче