diff --git a/ipc/ipdl/ipdl/ast.py b/ipc/ipdl/ipdl/ast.py index fbcd3865b1f..37c6f2d5442 100644 --- a/ipc/ipdl/ipdl/ast.py +++ b/ipc/ipdl/ipdl/ast.py @@ -42,8 +42,8 @@ class Visitor: cxxInc.accept(self) for protoInc in tu.protocolIncludes: protoInc.accept(self) - for union in tu.unions: - union.accept(self) + for su in tu.structsAndUnions: + su.accept(self) for using in tu.using: using.accept(self) tu.protocol.accept(self) @@ -56,6 +56,13 @@ class Visitor: # and pass-specific handling pass + def visitStructDecl(self, struct): + for f in struct.fields: + f.accept(self) + + def visitStructField(self, field): + field.type.accept(self) + def visitUnionDecl(self, union): for t in union.components: t.accept(self) @@ -171,12 +178,13 @@ class TranslationUnit(Node): self.cxxIncludes = [ ] self.protocolIncludes = [ ] self.using = [ ] - self.unions = [ ] + self.structsAndUnions = [ ] self.protocol = None def addCxxInclude(self, cxxInclude): self.cxxIncludes.append(cxxInclude) def addProtocolInclude(self, pInc): self.protocolIncludes.append(pInc) - def addUnionDecl(self, union): self.unions.append(union) + def addStructDecl(self, struct): self.structsAndUnions.append(struct) + def addUnionDecl(self, union): self.structsAndUnions.append(union) def addUsingStmt(self, using): self.using.append(using) def setProtocol(self, protocol): self.protocol = protocol @@ -267,6 +275,17 @@ class Protocol(NamespacedNode): self.transitionStmts = [ ] self.startStates = [ ] +class StructField(Node): + def __init__(self, loc, type, name): + Node.__init__(self, loc) + self.type = type + self.name = name + +class StructDecl(NamespacedNode): + def __init__(self, loc, name, fields): + NamespacedNode.__init__(self, loc, name) + self.fields = fields + class UnionDecl(NamespacedNode): def __init__(self, loc, name, components): NamespacedNode.__init__(self, loc, name) diff --git a/ipc/ipdl/ipdl/lower.py b/ipc/ipdl/ipdl/lower.py index 4e617291501..97438095ce0 100644 --- a/ipc/ipdl/ipdl/lower.py +++ b/ipc/ipdl/ipdl/lower.py @@ -314,6 +314,11 @@ def _ifLogging(stmts): iflogging.addifstmts(stmts) return iflogging +# We need the ASTs of structs and unions to generate pickling code for +# them, but the pickling codegen only has their type info. This map +# allows the pickling code to get these ASTs given the type info. +_typeToAST = { } # [ Type -> Node ] + # XXX we need to remove these and install proper error handling def _printErrorMessage(msg): if isinstance(msg, str): @@ -397,7 +402,7 @@ class _DestroyReason: class _ConvertToCxxType(TypeVisitor): def __init__(self, side): self.side = side - + def visitBuiltinCxxType(self, t): return Type(t.name()) @@ -407,6 +412,9 @@ class _ConvertToCxxType(TypeVisitor): def visitActorType(self, a): return Type(_actorName(a.protocol.name(), self.side), ptr=1) + def visitStructType(self, s): + return Type(s.name()) + def visitUnionType(self, u): return Type(u.name()) @@ -520,10 +528,53 @@ necessarily a C++ reference.""" ##-------------------------------------------------- -class UnionDecl(ipdl.ast.UnionDecl): +class HasFQName: def fqClassName(self): return self.decl.type.fullname() + +class StructDecl(ipdl.ast.StructDecl, HasFQName): + @staticmethod + def upgrade(structDecl): + assert isinstance(structDecl, ipdl.ast.StructDecl) + structDecl.__class__ = StructDecl + return structDecl + +class _StructField(_HybridDecl): + def __init__(self, ipdltype, name, sd, side=None, other=None): + special = _hasVisibleActor(ipdltype) + fname = name + if special: + fname += side.title() + + _HybridDecl.__init__(self, ipdltype, fname) + self.side = side + self.special = special + if special: + if other is not None: + self.other = other + else: + self.other = _StructField(ipdltype, name, sd, _otherSide(side), self) + self.sd = sd + + # @override the following methods to pass |self.side| instead of + # forcing the caller to remember which side we're declared to + # represent. + def bareType(self, side=None): + return _HybridDecl.bareType(self, self.side) + def refType(self, side=None): + return _HybridDecl.refType(self, self.side) + def constRefType(self, side=None): + return _HybridDecl.constRefType(self, self.side) + def ptrToType(self, side=None): + return _HybridDecl.ptrToType(self, self.side) + def constPtrToType(self, side=None): + return _HybridDecl.constPtrToType(self, self.side) + def inType(self, side=None): + return _HybridDecl.inType(self, self.side) + + +class UnionDecl(ipdl.ast.UnionDecl, HasFQName): def callType(self, var=None): func = ExprVar('type') if var is not None: @@ -1080,9 +1131,29 @@ with some new IPDL/C++ nodes that are tuned for C++ codegen.""" self.typedefs.append(Typedef(Type(using.decl.fullname), using.decl.shortname)) + def visitStructDecl(self, sd): + sd.decl.special = 0 + newfields = [ ] + for f in sd.fields: + ftype = f.decl.type + if _hasVisibleActor(ftype): + sd.decl.special = 1 + # if ftype has a visible actor, we need both + # |ActorParent| and |ActorChild| fields + newfields.append(_StructField(ftype, f.name, sd, side='parent')) + newfields.append(_StructField(ftype, f.name, sd, side='child')) + else: + newfields.append(_StructField(ftype, f.name, sd)) + sd.fields = newfields + StructDecl.upgrade(sd) + _typeToAST[sd.decl.type] = sd + + if sd.decl.fullname is not None: + self.typedefs.append(Typedef(Type(sd.fqClassName()), sd.name)) + + def visitUnionDecl(self, ud): ud.decl.special = 0 - ud.decl.type._ud = ud # sucky newcomponents = [ ] for ctype in ud.decl.type.components: if _hasVisibleActor(ctype): @@ -1095,6 +1166,7 @@ with some new IPDL/C++ nodes that are tuned for C++ codegen.""" newcomponents.append(_UnionMember(ctype, ud)) ud.components = newcomponents UnionDecl.upgrade(ud) + _typeToAST[ud.decl.type] = ud if ud.decl.fullname is not None: self.typedefs.append(Typedef(Type(ud.fqClassName()), ud.name)) @@ -1147,9 +1219,11 @@ child actors.''' def visitCxxInclude(self, inc): self.file.addthing(CppDirective('include', '"'+ inc.file +'"')) + def visitStructDecl(self, sd): + self.file.addthings(_generateCxxStruct(sd)) + def visitUnionDecl(self, ud): - self.file.addthings( - _generateCxxUnionStuff(ud)) + self.file.addthings(_generateCxxUnionStuff(ud)) def visitProtocol(self, p): @@ -1274,6 +1348,116 @@ def _generateMessageClass(clsname, msgid, typedefs, prettyName): ##-------------------------------------------------- +class _ComputeTypeDeps(TypeVisitor): + '''Pass that gathers the C++ types that a particular IPDL type +(recursively) depends on. There are two kinds of dependencies: (i) +types that need forward declaration; (ii) types that need a |using| +stmt. Some types generate both kinds.''' + + def __init__(self): + self.usingTypedefs = [ ] + self.forwardDeclStmts = [ ] + self.seen = set() + + def maybeTypedef(self, fqname, name): + if fqname != name: + self.usingTypedefs.append(Typedef(Type(fqname), name)) + + def visitBuiltinCxxType(self, t): + if t in self.seen: return + self.seen.add(t) + self.maybeTypedef(t.fullname(), t.name()) + + def visitImportedCxxType(self, t): + if t in self.seen: return + self.seen.add(t) + self.maybeTypedef(t.fullname(), t.name()) + + def visitActorType(self, t): + if t in self.seen: return + self.seen.add(t) + + fqname, name = t.fullname(), t.name() + + self.maybeTypedef(_actorName(fqname, 'Parent'), + _actorName(name, 'Parent')) + self.maybeTypedef(_actorName(fqname, 'Child'), + _actorName(name, 'Child')) + + self.forwardDeclStmts.extend([ + _makeForwardDecl(t.protocol, 'parent'), Whitespace.NL, + _makeForwardDecl(t.protocol, 'child'), Whitespace.NL + ]) + + def visitStructOrUnionType(self, su, defaultVisit): + if su in self.seen: return + self.seen.add(su) + self.maybeTypedef(su.fullname(), su.name()) + + return defaultVisit(self, su) + + def visitStructType(self, t): + return self.visitStructOrUnionType(t, TypeVisitor.visitStructType) + + def visitUnionType(self, t): + return self.visitStructOrUnionType(t, TypeVisitor.visitUnionType) + + def visitArrayType(self, t): + return TypeVisitor.visitArrayType(self, t) + + def visitVoidType(self, v): assert 0 + def visitMessageType(self, v): assert 0 + def visitProtocolType(self, v): assert 0 + def visitStateType(self, v): assert 0 + + +def _generateCxxStruct(sd): + ''' ''' + # compute all the typedefs and forward decls we need to make + gettypedeps = _ComputeTypeDeps() + for f in sd.fields: + f.ipdltype.accept(gettypedeps) + + usingTypedefs = gettypedeps.usingTypedefs + forwarddeclstmts = gettypedeps.forwardDeclStmts + + struct = Class(sd.name, struct=1) + struct.addstmts( + [ Label.PRIVATE ] + + usingTypedefs + + [ Whitespace.NL, + Label.PUBLIC, + ConstructorDefn(ConstructorDecl(sd.name)), + Whitespace.NL, + ConstructorDefn( + ConstructorDecl( + sd.name, + params=[ Decl(f.constRefType(), '_'+ f.name) + for f in sd.fields ]), + memberinits=[ ExprCall(f.var(), + args=[ ExprVar('_'+ f.name) ]) + for f in sd.fields ]), + Whitespace.NL, + Whitespace('// Default copy ctor, op=, and dtor are OK\n\n', + indent=1) + ] + + [ StmtDecl(Decl(f.bareType(), f.name)) for f in sd.fields ]) + + + return ( + [ + Whitespace(""" +//----------------------------------------------------------------------------- +// Definition of the IPDL type |struct %s| +// +"""% (sd.name)) + ] + + forwarddeclstmts + + [ _putInNamespaces(struct, sd.namespaces) ]) + + +##-------------------------------------------------- + def _generateCxxUnionStuff(ud): # This Union class basically consists of a type (enum) and a # union for storage. The union can contain POD and non-POD @@ -1342,60 +1526,12 @@ def _generateCxxUnionStuff(ud): return ifdied # compute all the typedefs and forward decls we need to make - usingTypedefs = [ ] - forwarddeclstmts = [ ] - class computeTypeDeps(ipdl.type.TypeVisitor): - def __init__(self): self.seen = set() - - def maybeTypedef(self, fqname, name): - if fqname != name: - usingTypedefs.append(Typedef(Type(fqname), name)) - - def visitBuiltinCxxType(self, t): - if t in self.seen: return - self.seen.add(t) - self.maybeTypedef(t.fullname(), t.name()) - - def visitImportedCxxType(self, t): - if t in self.seen: return - self.seen.add(t) - self.maybeTypedef(t.fullname(), t.name()) - - def visitActorType(self, t): - if t in self.seen: return - self.seen.add(t) - - fqname, name = t.fullname(), t.name() - - self.maybeTypedef(_actorName(fqname, 'Parent'), - _actorName(name, 'Parent')) - self.maybeTypedef(_actorName(fqname, 'Child'), - _actorName(name, 'Child')) - - forwarddeclstmts.extend([ - _makeForwardDecl(t.protocol, 'parent'), Whitespace.NL, - _makeForwardDecl(t.protocol, 'child'), Whitespace.NL - ]) - - def visitUnionType(self, t): - if t == ud.decl.type or t in self.seen: return - self.seen.add(t) - self.maybeTypedef(t.fullname(), t.name()) - - return ipdl.type.TypeVisitor.visitUnionType(self, t) - - def visitArrayType(self, t): - return ipdl.type.TypeVisitor.visitArrayType(self, t) - - def visitVoidType(self, v): assert 0 - def visitMessageType(self, v): assert 0 - def visitProtocolType(self, v): assert 0 - def visitStateType(self, v): assert 0 - - gettypedeps = computeTypeDeps() + gettypedeps = _ComputeTypeDeps() for c in ud.components: c.ipdltype.accept(gettypedeps) + usingTypedefs = gettypedeps.usingTypedefs + forwarddeclstmts = gettypedeps.forwardDeclStmts # the |Type| enum, used to switch on the discunion's real type cls.addstmt(Label.PUBLIC) @@ -3030,6 +3166,9 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor): class findSpecialTypes(TypeVisitor): def visitActorType(self, a): specialtypes.add(a) def visitShmemType(self, s): specialtypes.add(s) + def visitStructType(self, s): + specialtypes.add(s) + return TypeVisitor.visitStructType(self, s) def visitUnionType(self, u): specialtypes.add(u) return TypeVisitor.visitUnionType(self, u) @@ -3053,6 +3192,7 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor): if t.isActor(): self.implementActorPickling(t) elif t.isArray(): self.implementSpecialArrayPickling(t) elif t.isShmem(): self.implementShmemPickling(t) + elif t.isStruct(): self.implementStructPickling(t) elif t.isUnion(): self.implementUnionPickling(t) else: assert 0 and 'unknown special type' @@ -3268,15 +3408,46 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor): self.cls.addstmts([ write, Whitespace.NL, read, Whitespace.NL ]) + def implementStructPickling(self, structtype): + msgvar = self.msgvar + itervar = self.itervar + var = self.var + intype = _cxxConstRefType(structtype, self.side) + outtype = _cxxPtrToType(structtype, self.side) + sd = _typeToAST[structtype] + + write = MethodDefn(self.writeMethodDecl(intype, var)) + read = MethodDefn(self.readMethodDecl(outtype, var)) + + def get(sel, f): + return ExprSelect(var, sel, f.name) + + for f in sd.fields: + writefield = StmtExpr(self.write(f.ipdltype, get('.', f), msgvar)) + readfield = self.checkedRead(f.ipdltype, + ExprAddrOf(get('->', f)), + msgvar, itervar, + errfn=errfnRead) + if f.special and f.side != self.side: + writefield = Whitespace( + "// skipping actor field that's meaningless on this side\n", indent=1) + readfield = Whitespace( + "// skipping actor field that's meaningless on this side\n", indent=1) + write.addstmt(writefield) + read.addstmt(readfield) + + read.addstmt(StmtReturn(ExprLiteral.TRUE)) + + self.cls.addstmts([ write, Whitespace.NL, read, Whitespace.NL ]) + + def implementUnionPickling(self, uniontype): msgvar = self.msgvar itervar = self.itervar var = self.var intype = _cxxConstRefType(uniontype, self.side) outtype = _cxxPtrToType(uniontype, self.side) - - # sigh - ud = uniontype._ud + ud = _typeToAST[uniontype] typename = '__type' uniontdef = Typedef(_cxxBareType(uniontype, typename), typename) diff --git a/ipc/ipdl/ipdl/parser.py b/ipc/ipdl/ipdl/parser.py index ee0801e9f4b..7d5e4c95b96 100644 --- a/ipc/ipdl/ipdl/parser.py +++ b/ipc/ipdl/ipdl/parser.py @@ -165,6 +165,7 @@ reserved = set(( 'spawns', 'start', 'state', + 'struct', 'sync', 'union', 'using')) @@ -219,7 +220,9 @@ def p_TranslationUnit(p): assert 0 for thing in p[2]: - if isinstance(thing, UnionDecl): + if isinstance(thing, StructDecl): + tu.addStructDecl(thing) + elif isinstance(thing, UnionDecl): tu.addUnionDecl(thing) elif isinstance(thing, Protocol): if tu.protocol is not None: @@ -287,6 +290,7 @@ def p_NamespacedStuff(p): def p_NamespaceThing(p): """NamespaceThing : NAMESPACE ID '{' NamespacedStuff '}' + | StructDecl | UnionDecl | ProtocolDefn""" if 2 == len(p): @@ -295,7 +299,24 @@ def p_NamespaceThing(p): for thing in p[4]: thing.addOuterNamespace(Namespace(locFromTok(p, 1), p[2])) p[0] = p[4] - + +def p_StructDecl(p): + """StructDecl : STRUCT ID '{' StructFields '}' ';'""" + p[0] = StructDecl(locFromTok(p, 1), p[2], p[4]) + +def p_StructFields(p): + """StructFields : StructFields StructField ';' + | StructField ';'""" + if 3 == len(p): + p[0] = [ p[1] ] + else: + p[1].append(p[2]) + p[0] = p[1] + +def p_StructField(p): + """StructField : Type ID""" + p[0] = StructField(locFromTok(p, 1), p[1], p[2]) + def p_UnionDecl(p): """UnionDecl : UNION ID '{' ComponentTypes '}' ';'""" p[0] = UnionDecl(locFromTok(p, 1), p[2], p[4]) diff --git a/ipc/ipdl/ipdl/type.py b/ipc/ipdl/ipdl/type.py index 281bd794ecf..c407fb5e77c 100644 --- a/ipc/ipdl/ipdl/type.py +++ b/ipc/ipdl/ipdl/type.py @@ -89,6 +89,13 @@ class TypeVisitor: a.protocol.accept(self, *args) a.state.accept(self, *args) + def visitStructType(self, s, *args): + for field in s.fields: + field.accept(self, *args) + + def visitFieldType(self, f, *args): + f.type.accept(self, *args) + def visitUnionType(self, u, *args): for component in u.components: component.accept(self, *args) @@ -195,6 +202,7 @@ class IPDLType(Type): def isMessage(self): return False def isProtocol(self): return False def isActor(self): return False + def isStruct(self): return False def isUnion(self): return False def isArray(self): return False def isShmem(self): return False @@ -332,6 +340,15 @@ class ActorType(IPDLType): def fullname(self): return self.protocol.fullname() +class StructType(IPDLType): + def __init__(self, qname, fields): + self.qname = qname + self.fields = fields # [ Type ] + + def isStruct(self): return True + def name(self): return self.qname.baseid + def fullname(self): return str(self.qname) + class UnionType(IPDLType): def __init__(self, qname, components): self.qname = qname @@ -369,6 +386,10 @@ def iteractortypes(type): elif type.isArray(): for actor in iteractortypes(type.basetype): yield actor + elif type.isStruct(): + for f in type.fields: + for actor in iteractortypes(f): + yield actor elif type.isUnion(): for c in type.components: for actor in iteractortypes(c): @@ -592,9 +613,8 @@ class GatherDecls(TcheckVisitor): for using in tu.using: using.accept(self) - # declare unions - for union in tu.unions: - union.accept(self) + for su in tu.structsAndUnions: + su.accept(self) # grab symbols in the protocol itself p.accept(self) @@ -613,6 +633,38 @@ class GatherDecls(TcheckVisitor): pi.tu.accept(self) self.symtab.declare(pi.tu.protocol.decl) + def visitStructDecl(self, sd): + qname = sd.qname() + if 0 == len(qname.quals): + fullname = None + else: + fullname = str(qname) + + sd.decl = self.declare( + loc=sd.loc, + type=StructType(qname, [ ]), + shortname=sd.name, + fullname=fullname) + stype = sd.decl.type + + self.symtab.enterScope(sd) + + for f in sd.fields: + ftypedecl = self.symtab.lookup(str(f.type)) + if ftypedecl is None: + self.error(f.loc, "field `%s' of struct `%s' has unknown type `%s'", + f.name, sd.name, str(f.type)) + continue + + f.decl = self.declare( + loc=f.loc, + type=self._canonicalType(ftypedecl.type, f.type), + shortname=f.name, + fullname=None) + stype.fields.append(f.decl.type) + + self.symtab.exitScope(sd) + def visitUnionDecl(self, ud): qname = ud.qname() if 0 == len(qname.quals):