[RELAY][TypeSystem] Add support for populating type args (#1962)
This commit is contained in:
Родитель
3a1bb8c7d4
Коммит
3bfa5fc03f
|
@ -485,6 +485,36 @@ inline ValueType OpMap<ValueType>::get(const Op& op,
|
|||
return map_.get<ValueType>(op, def_value);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Check that an expression is a "primtive operator".
|
||||
*
|
||||
* Will return true if the expression is an operator which
|
||||
* matches the form of primtive operators registered directly
|
||||
* by the Relay codebase.
|
||||
*
|
||||
* That is the arguments are all type variables, and there is a single
|
||||
* type relation applied to the input and output types.
|
||||
*/
|
||||
inline bool IsPrimitiveOp(const Expr& expr) {
|
||||
const auto* op = expr.as<OpNode>();
|
||||
|
||||
if (!op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& fn_ty = op->op_type;
|
||||
if (fn_ty->type_constraints.size() != 1) return false;
|
||||
|
||||
const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
|
||||
if (rel == nullptr) return false;
|
||||
// validate if the type parameter matches up
|
||||
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
|
||||
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
#endif // TVM_RELAY_OP_H_
|
||||
|
|
|
@ -278,10 +278,7 @@ class TextPrinter :
|
|||
}
|
||||
|
||||
TextValue VisitExpr_(const CallNode* op) final {
|
||||
// TODO(tqchen, M.K.): support generic call
|
||||
// possibly through meta-data
|
||||
CHECK_EQ(op->type_args.size(), 0U)
|
||||
<< "generic call not yet supported";
|
||||
TextValue call_op = GetValue(op->op);
|
||||
std::vector<TextValue> args;
|
||||
for (Expr arg : op->args) {
|
||||
|
@ -289,7 +286,23 @@ class TextPrinter :
|
|||
}
|
||||
TextValue id = this->AllocTempVar();
|
||||
this->PrintIndent();
|
||||
stream_ << id << " = " << call_op << "(";
|
||||
|
||||
stream_ << id << " = " << call_op;
|
||||
|
||||
auto type_args = op->type_args;
|
||||
|
||||
if (!IsPrimitiveOp(op->op) && type_args.size() > 0U) {
|
||||
stream_ << "<";
|
||||
for (size_t i = 0; i < op->type_args.size(); ++i) {
|
||||
this->PrintType(type_args[i], stream_);
|
||||
if (i + 1 != type_args.size()) {
|
||||
stream_ << ", ";
|
||||
}
|
||||
}
|
||||
stream_ << ">";
|
||||
}
|
||||
|
||||
stream_ << "(";
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
stream_ << args[i];
|
||||
if (i + 1 != args.size()) {
|
||||
|
|
|
@ -61,6 +61,17 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
|
|||
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
|
||||
TupleGetItemRel);
|
||||
|
||||
struct ResolvedTypeInfo {
|
||||
explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
|
||||
: checked_type(checked_type), type_args(type_args) {}
|
||||
ResolvedTypeInfo() {}
|
||||
|
||||
Type checked_type;
|
||||
// Only allocated when the expression is a call.
|
||||
|
||||
Array<Type> type_args = Array<Type>(NodePtr<Node>(nullptr));
|
||||
};
|
||||
|
||||
//
|
||||
// The inference algorithm can roughly be devided into three stages:
|
||||
// - Populate the constraints by visiting the expression (TypeInferencer.GetType)
|
||||
|
@ -87,7 +98,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
|||
Environment env_;
|
||||
// map from expression to checked type
|
||||
// type inferencer will populate it up
|
||||
std::unordered_map<Expr, Type, NodeHash, NodeEqual> type_map_;
|
||||
std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_;
|
||||
|
||||
// The solver used by the inferencer.
|
||||
TypeSolver solver_;
|
||||
// relation function
|
||||
|
@ -111,11 +123,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
|||
// will call visit to deduce it if it is not in the type_map_
|
||||
Type GetType(const Expr &expr) {
|
||||
auto it = type_map_.find(expr);
|
||||
if (it != type_map_.end()) {
|
||||
return it->second;
|
||||
if (it != type_map_.end() && it->second.checked_type.defined()) {
|
||||
return it->second.checked_type;
|
||||
}
|
||||
Type ret = this->VisitExpr(expr);
|
||||
type_map_[expr] = ret;
|
||||
ResolvedTypeInfo& rti = type_map_[expr];
|
||||
rti.checked_type = ret;
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -176,7 +189,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
|||
}
|
||||
CHECK(!type_map_.count(op->var));
|
||||
// NOTE: no scoping is necessary because var are unique in program
|
||||
type_map_[op->var] = vtype;
|
||||
type_map_[op->var].checked_type = vtype;
|
||||
return GetType(op->body);
|
||||
}
|
||||
|
||||
|
@ -224,6 +237,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
|||
subst_map.Set(ty_param, fresh);
|
||||
ty_args->push_back(fresh);
|
||||
}
|
||||
|
||||
Type ret_type = fn_ty->ret_type;
|
||||
|
||||
// If the function type is incomplete, place a new IncompleteType
|
||||
|
@ -234,6 +248,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
|||
if (!ret_type.defined()) {
|
||||
ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
|
||||
}
|
||||
|
||||
Type inst_ty = FuncTypeNode::make(fn_ty->arg_types,
|
||||
ret_type, {},
|
||||
fn_ty->type_constraints);
|
||||
|
@ -241,49 +256,74 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
|||
return Downcast<FuncType>(inst_ty);
|
||||
}
|
||||
|
||||
void AddTypeArgs(const Expr& expr, Array<Type> type_args) {
|
||||
auto type_info = type_map_.find(expr);
|
||||
if (type_info == type_map_.end()) {
|
||||
type_map_.insert({expr, ResolvedTypeInfo(Type(), type_args)});
|
||||
} else {
|
||||
CHECK(!type_info->second.type_args.defined());
|
||||
type_info->second.type_args = type_args;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle general call node.
|
||||
Type GeneralCall(const CallNode* op, Array<Type> arg_types) {
|
||||
Type ftype = GetType(op->op);
|
||||
Type GeneralCall(const CallNode* call, Array<Type> arg_types) {
|
||||
Type ftype = GetType(call->op);
|
||||
auto* fn_ty_node = ftype.as<FuncTypeNode>();
|
||||
|
||||
CHECK(fn_ty_node != nullptr)
|
||||
<< "only expressions with function types can be called, at "
|
||||
<< op->span;
|
||||
<< call->span;
|
||||
|
||||
Array<Type> type_args;
|
||||
FuncType fn_ty = Instantiate(fn_ty_node, &type_args);
|
||||
|
||||
AddTypeArgs(GetRef<Call>(call), type_args);
|
||||
|
||||
size_t type_arity = fn_ty->arg_types.size();
|
||||
size_t number_of_args = arg_types.size();
|
||||
|
||||
if (type_arity != number_of_args) {
|
||||
if (type_arity < number_of_args) {
|
||||
LOG(FATAL) << "the function is provided too many arguments " << op->span;
|
||||
LOG(FATAL) << "the function is provided too many arguments " << call->span;
|
||||
} else {
|
||||
LOG(FATAL) << "the function is provided too few arguments" << op->span;
|
||||
LOG(FATAL) << "the function is provided too few arguments" << call->span;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
|
||||
this->Unify(fn_ty->arg_types[i], arg_types[i], op->args[i]->span);
|
||||
this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]->span);
|
||||
}
|
||||
|
||||
for (auto cs : fn_ty->type_constraints) {
|
||||
solver_.AddConstraint(cs);
|
||||
if (auto tr = cs.as<TypeRelationNode>()) {
|
||||
solver_.AddConstraint(
|
||||
TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs));
|
||||
} else {
|
||||
solver_.AddConstraint(cs);
|
||||
}
|
||||
}
|
||||
|
||||
return fn_ty->ret_type;
|
||||
}
|
||||
|
||||
Type VisitExpr_(const CallNode* op) final {
|
||||
// Fast path: well-formed primitive op
|
||||
Type VisitExpr_(const CallNode* call) final {
|
||||
Array<Type> arg_types;
|
||||
for (Expr arg : op->args) {
|
||||
for (Expr arg : call->args) {
|
||||
arg_types.push_back(GetType(arg));
|
||||
}
|
||||
if (const OpNode* opnode = op->op.as<OpNode>()) {
|
||||
|
||||
if (const OpNode* opnode = call->op.as<OpNode>()) {
|
||||
Type rtype = PrimitiveCall(opnode->op_type.as<FuncTypeNode>(),
|
||||
arg_types,
|
||||
op->attrs);
|
||||
if (rtype.defined()) return rtype;
|
||||
call->attrs);
|
||||
if (rtype.defined()) {
|
||||
AddTypeArgs(GetRef<Call>(call), arg_types);
|
||||
return rtype;
|
||||
}
|
||||
}
|
||||
return GeneralCall(op, arg_types);
|
||||
|
||||
return GeneralCall(call, arg_types);
|
||||
}
|
||||
|
||||
Type VisitExpr_(const FunctionNode* f) final {
|
||||
|
@ -312,7 +352,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
|||
|
||||
class TypeInferencer::Resolver : public ExprMutator {
|
||||
public:
|
||||
Resolver(const std::unordered_map<Expr, Type, NodeHash, NodeEqual>& tmap,
|
||||
Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap,
|
||||
TypeSolver* solver)
|
||||
: tmap_(tmap), solver_(solver) {
|
||||
}
|
||||
|
@ -362,7 +402,7 @@ class TypeInferencer::Resolver : public ExprMutator {
|
|||
Expr AttachCheckedType(const T* op) {
|
||||
auto it = tmap_.find(GetRef<Expr>(op));
|
||||
CHECK(it != tmap_.end());
|
||||
Type checked_type = solver_->Resolve(it->second);
|
||||
Type checked_type = solver_->Resolve(it->second.checked_type);
|
||||
CHECK(checked_type.as<IncompleteTypeNode>() == nullptr)
|
||||
<< "Cannot resolve type of " << GetRef<Expr>(op)
|
||||
<< " at " << op->span;
|
||||
|
@ -376,25 +416,37 @@ class TypeInferencer::Resolver : public ExprMutator {
|
|||
}
|
||||
new_e->checked_type_ = checked_type;
|
||||
}
|
||||
|
||||
if (it->second.type_args.defined()) {
|
||||
Call call = Downcast<Call>(new_e);
|
||||
const CallNode* const_call_ref = call.operator->();
|
||||
CallNode* call_ref = const_cast<CallNode*>(const_call_ref);
|
||||
call_ref->type_args = it->second.type_args;
|
||||
|
||||
for (size_t i = 0; i < call->type_args.size(); i++) {
|
||||
call_ref->type_args.Set(i, solver_->Resolve(call->type_args[i]));
|
||||
}
|
||||
}
|
||||
|
||||
return new_e;
|
||||
}
|
||||
|
||||
Type VisitType(const Type& t) final {
|
||||
Type VisitType(const Type &t) final {
|
||||
return solver_->Resolve(t);
|
||||
}
|
||||
|
||||
private:
|
||||
const std::unordered_map<Expr, Type, NodeHash, NodeEqual>& tmap_;
|
||||
const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap_;
|
||||
TypeSolver* solver_;
|
||||
};
|
||||
|
||||
|
||||
Expr TypeInferencer::Infer(Expr expr) {
|
||||
// step 0: populate the constraints
|
||||
// Step 0: Populate the constraints.
|
||||
GetType(expr);
|
||||
// step 1: solve the constraints
|
||||
// Step 1: Solve the constraints.
|
||||
solver_.Solve();
|
||||
// step 2: attach resolved types to checked_type field
|
||||
// Step 2: Attach resolved types to checked_type field.
|
||||
return Resolver(type_map_, &solver_).VisitExpr(expr);
|
||||
}
|
||||
|
||||
|
|
|
@ -91,6 +91,21 @@ def test_free_expr():
|
|||
yy = relay.ir_pass.infer_type(y)
|
||||
assert yy.checked_type == relay.scalar_type("float32")
|
||||
|
||||
def test_type_args():
|
||||
x = relay.var("x", shape=(10, 10))
|
||||
y = relay.var("y", shape=(1, 10))
|
||||
z = relay.add(x, y)
|
||||
ty_z = relay.ir_pass.infer_type(z)
|
||||
ty_args = ty_z.type_args
|
||||
assert len(ty_args) == 2
|
||||
assert ty_args[0].dtype == "float32"
|
||||
assert ty_args[1].dtype == "float32"
|
||||
sh1 = ty_args[0].shape
|
||||
sh2 = ty_args[1].shape
|
||||
assert sh1[0].value == 10
|
||||
assert sh1[1].value == 10
|
||||
assert sh2[0].value == 1
|
||||
assert sh2[1].value == 10
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_free_expr()
|
||||
|
@ -100,3 +115,5 @@ if __name__ == "__main__":
|
|||
test_decl()
|
||||
test_recursion()
|
||||
test_tuple()
|
||||
test_free_expr()
|
||||
test_type_args()
|
||||
|
|
Загрузка…
Ссылка в новой задаче