diff --git a/ipc/glue/ProtocolUtils.cpp b/ipc/glue/ProtocolUtils.cpp index b9cc5ce73a72..1a022048fc95 100644 --- a/ipc/glue/ProtocolUtils.cpp +++ b/ipc/glue/ProtocolUtils.cpp @@ -17,6 +17,7 @@ #include "mozilla/ipc/MessageChannel.h" #include "mozilla/ipc/Transport.h" #include "mozilla/StaticMutex.h" +#include "mozilla/Unused.h" #include "nsPrintfCString.h" #if defined(MOZ_SANDBOX) && defined(XP_WIN) @@ -511,7 +512,8 @@ IToplevelProtocol::IToplevelProtocol(ProtocolId aProtoId, Side aSide) : IProtocol(aSide), mProtocolId(aProtoId), mOtherPid(mozilla::ipc::kInvalidProcessId), - mLastRouteId(mSide == ParentSide : 1 : 0) + mLastRouteId(aSide == ParentSide ? 1 : 0), + mLastShmemId(aSide == ParentSide ? 1 : 0) { } @@ -586,7 +588,7 @@ IToplevelProtocol::IsOnCxxStack() const int32_t IToplevelProtocol::Register(IProtocol* aRouted) { - int32_t id = mSide == ParentSide ? ++mLastRouteId : --mLastRouteId; + int32_t id = GetSide() == ParentSide ? ++mLastRouteId : --mLastRouteId; mActorMap.AddWithID(aRouted, id); return id; } @@ -611,5 +613,108 @@ IToplevelProtocol::Unregister(int32_t aId) return mActorMap.Remove(aId); } +Shmem::SharedMemory* +IToplevelProtocol::CreateSharedMemory(size_t aSize, + Shmem::SharedMemory::SharedMemoryType aType, + bool aUnsafe, + Shmem::id_t* aId) +{ + RefPtr segment( + Shmem::Alloc(Shmem::IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead(), aSize, aType, aUnsafe)); + if (!segment) { + return nullptr; + } + int32_t id = GetSide() == ParentSide ? ++mLastShmemId : --mLastShmemId; + Shmem shmem( + Shmem::IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead(), + segment.get(), + id); + Message* descriptor = shmem.ShareTo( + Shmem::IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead(), OtherPid(), MSG_ROUTING_CONTROL); + if (!descriptor) { + return nullptr; + } + Unused << GetIPCChannel()->Send(descriptor); + + *aId = shmem.Id(Shmem::IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead()); + Shmem::SharedMemory* rawSegment = segment.get(); + mShmemMap.AddWithID(segment.forget().take(), *aId); + return rawSegment; +} + +Shmem::SharedMemory* +IToplevelProtocol::LookupSharedMemory(Shmem::id_t aId) +{ + return mShmemMap.Lookup(aId); +} + +bool +IToplevelProtocol::IsTrackingSharedMemory(Shmem::SharedMemory* segment) +{ + return mShmemMap.HasData(segment); +} + +bool +IToplevelProtocol::DestroySharedMemory(Shmem& shmem) +{ + Shmem::id_t aId = shmem.Id(Shmem::IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead()); + Shmem::SharedMemory* segment = LookupSharedMemory(aId); + if (!segment) { + return false; + } + + Message* descriptor = shmem.UnshareFrom( + Shmem::IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead(), OtherPid(), MSG_ROUTING_CONTROL); + + mShmemMap.Remove(aId); + Shmem::Dealloc(Shmem::IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead(), segment); + + if (!GetIPCChannel()->CanSend()) { + delete descriptor; + return true; + } + + return descriptor && GetIPCChannel()->Send(descriptor); +} + +void +IToplevelProtocol::DeallocShmems() +{ + for (IDMap::const_iterator cit = mShmemMap.begin(); cit != mShmemMap.end(); ++cit) { + Shmem::Dealloc(Shmem::IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead(), cit->second); + } + mShmemMap.Clear(); +} + +bool +IToplevelProtocol::ShmemCreated(const Message& aMsg) +{ + Shmem::id_t id; + RefPtr rawmem(Shmem::OpenExisting(Shmem::IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead(), aMsg, &id, true)); + if (!rawmem) { + return false; + } + mShmemMap.AddWithID(rawmem.forget().take(), id); + return true; +} + +bool +IToplevelProtocol::ShmemDestroyed(const Message& aMsg) +{ + Shmem::id_t id; + PickleIterator iter = PickleIterator(aMsg); + if (!IPC::ReadParam(&aMsg, &iter, &id)) { + return false; + } + aMsg.EndRead(iter); + + Shmem::SharedMemory* rawmem = LookupSharedMemory(id); + if (rawmem) { + mShmemMap.Remove(id); + Shmem::Dealloc(Shmem::IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead(), rawmem); + } + return true; +} + } // namespace ipc } // namespace mozilla diff --git a/ipc/glue/ProtocolUtils.h b/ipc/glue/ProtocolUtils.h index 66b8f0ea3827..e84dd8a5f533 100644 --- a/ipc/glue/ProtocolUtils.h +++ b/ipc/glue/ProtocolUtils.h @@ -261,6 +261,17 @@ public: virtual IProtocol* Lookup(int32_t); virtual void Unregister(int32_t); + virtual Shmem::SharedMemory* CreateSharedMemory( + size_t, SharedMemory::SharedMemoryType, bool, int32_t*); + virtual Shmem::SharedMemory* LookupSharedMemory(int32_t); + virtual bool IsTrackingSharedMemory(Shmem::SharedMemory*); + virtual bool DestroySharedMemory(Shmem&); + + void DeallocShmems(); + + bool ShmemCreated(const Message& aMsg); + bool ShmemDestroyed(const Message& aMsg); + virtual bool ShouldContinueFromReplyTimeout() { return false; } @@ -330,6 +341,8 @@ private: base::ProcessId mOtherPid; IDMap mActorMap; int32_t mLastRouteId; + IDMap mShmemMap; + Shmem::id_t mLastShmemId; }; class IShmemAllocator diff --git a/ipc/ipdl/ipdl/lower.py b/ipc/ipdl/ipdl/lower.py index b96b92812e44..eebe5233740b 100644 --- a/ipc/ipdl/ipdl/lower.py +++ b/ipc/ipdl/ipdl/lower.py @@ -1218,41 +1218,6 @@ class Protocol(ipdl.ast.Protocol): return _cxxManagedContainerType(Type(_actorName(actortype.name(), side)), const=const, ref=ref) - # shmem stuff - def shmemMapType(self): - assert self.decl.type.isToplevel() - return Type('IDMap', T=_rawShmemType()) - - def shmemIteratorType(self): - assert self.decl.type.isToplevel() - # XXX breaks abstractions - return Type('IDMap::const_iterator') - - def shmemMapVar(self): - assert self.decl.type.isToplevel() - return ExprVar('mShmemMap') - - def lastShmemIdVar(self): - assert self.decl.type.isToplevel() - return ExprVar('mLastShmemId') - - def shmemIdInit(self, side): - assert self.decl.type.isToplevel() - # use the same scheme for shmem IDs as actor IDs - if side is 'parent': return _FREED_ACTOR_ID - elif side is 'child': return _NULL_ACTOR_ID - else: assert 0 - - def nextShmemIdExpr(self, side): - assert self.decl.type.isToplevel() - if side is 'parent': op = '++' - elif side is 'child': op = '--' - return ExprPrefixUnop(self.lastShmemIdVar(), op) - - def removeShmemId(self, idexpr): - return ExprCall(ExprSelect(self.shmemMapVar(), '.', 'Remove'), - args=[ idexpr ]) - # XXX this is sucky, fix def usesShmem(self): return _usesShmem(self) @@ -2907,8 +2872,6 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor): ExprMemberInit(p.channelVar(), [ ExprCall(ExprVar('ALLOW_THIS_IN_INITIALIZER_LIST'), [ ExprVar.THIS ]) ]), - ExprMemberInit(p.lastShmemIdVar(), - [ p.shmemIdInit(self.side) ]), ExprMemberInit(p.stateVar(), [ p.startState() ]) ] @@ -3284,28 +3247,6 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor): self.cls.addstmts([ deallocsubtree, Whitespace.NL ]) if ptype.isToplevel(): - ## DeallocShmem(): - # for (cit = map.begin(); cit != map.end(); ++cit) - # Dealloc(cit->second) - # map.Clear() - deallocshmem = MethodDefn(MethodDecl(deallocshmemvar.name)) - - citvar = ExprVar('cit') - begin = ExprCall(ExprSelect(p.shmemMapVar(), '.', 'begin')) - end = ExprCall(ExprSelect(p.shmemMapVar(), '.', 'end')) - shmem = ExprSelect(citvar, '->', 'second') - foreachdealloc = StmtFor( - Param(p.shmemIteratorType(), citvar.name, begin), - ExprBinary(citvar, '!=', end), - ExprPrefixUnop(citvar, '++')) - foreachdealloc.addstmt(StmtExpr(_shmemDealloc(shmem))) - - deallocshmem.addstmts([ - foreachdealloc, - StmtExpr(ExprCall(ExprSelect(p.shmemMapVar(), '.', 'Clear'))) - ]) - self.cls.addstmts([ deallocshmem, Whitespace.NL ]) - deallocself = MethodDefn(MethodDecl(deallocselfvar.name, virtual=1)) self.cls.addstmts([ deallocself, Whitespace.NL ]) @@ -3318,11 +3259,6 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor): self.cls.addstmts([ StmtDecl(Decl(_actorIdType(), p.idVar().name)) ]) - if p.decl.type.isToplevel(): - self.cls.addstmts([ - StmtDecl(Decl(p.shmemMapType(), p.shmemMapVar().name)), - StmtDecl(Decl(_shmemIdType(), p.lastShmemIdVar().name)) - ]) self.cls.addstmt(StmtDecl(Decl(Type('State'), p.stateVar().name))) @@ -3350,30 +3286,6 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor): methods = [] if p.decl.type.isToplevel(): - createshmem = MethodDefn(MethodDecl( - p.createSharedMemory().name, - ret=_rawShmemType(ptr=1), - params=[ Decl(Type.SIZE, sizevar.name), - Decl(_shmemTypeType(), typevar.name), - Decl(Type.BOOL, unsafevar.name), - Decl(_shmemIdType(ptr=1), idvar.name) ], - virtual=1)) - lookupshmem = MethodDefn(MethodDecl( - p.lookupSharedMemory().name, - ret=_rawShmemType(ptr=1), - params=[ Decl(_shmemIdType(), idvar.name) ], - virtual=1)) - destroyshmem = MethodDefn(MethodDecl( - p.destroySharedMemory().name, - ret=Type.BOOL, - params=[ Decl(_shmemType(ref=1), shmemvar.name) ], - virtual=1)) - istracking = MethodDefn(MethodDecl( - p.isTrackingSharedMemory().name, - ret=Type.BOOL, - params=[ Decl(_rawShmemType(ptr=1), rawvar.name) ], - virtual=1)) - getchannel = MethodDefn(MethodDecl( p.getChannelMethod().name, ret=Type('MessageChannel', ptr=1), @@ -3386,127 +3298,12 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor): virtual=1, const=1)) getchannelconst.addstmt(StmtReturn(ExprAddrOf(p.channelVar()))) - methods += [ createshmem, - lookupshmem, - istracking, - destroyshmem, - getchannel, + methods += [ getchannel, getchannelconst ] if p.decl.type.isToplevel(): tmpvar = ExprVar('tmp') - # SharedMemory* CreateSharedMemory(size_t aSize, Type aType, bool aUnsafe, id_t* aId): - # RefPtr segment(Shmem::Alloc(aSize, aType, aUnsafe)); - # if (!segment) - # return nullptr; - # Shmem shmem(segment.get(), [nextshmemid]); - # Message descriptor = shmem.ShareTo(subprocess, mId, descriptor); - # if (!descriptor) - # return nullptr; - # mChannel.Send(descriptor); - # *aId = shmem.Id(); - # SharedMemory* rawSegment = segment.get(); - # mShmemMap.Add(segment.forget().take(), *aId); - # return rawSegment; - createshmem.addstmt(StmtDecl( - Decl(_refptr(_rawShmemType()), rawvar.name), - initargs=[ _shmemAlloc(sizevar, typevar, unsafevar) ])) - failif = StmtIf(ExprNot(rawvar)) - failif.addifstmt(StmtReturn(ExprLiteral.NULL)) - createshmem.addstmt(failif) - - descriptorvar = ExprVar('descriptor') - createshmem.addstmts([ - StmtDecl( - Decl(_shmemType(), shmemvar.name), - initargs=[ _shmemBackstagePass(), - _refptrGet(rawvar), - p.nextShmemIdExpr(self.side) ]), - StmtDecl(Decl(Type('Message', ptr=1), descriptorvar.name), - init=_shmemShareTo(shmemvar, - p.callOtherPid(), - p.routingId())) - ]) - failif = StmtIf(ExprNot(descriptorvar)) - failif.addifstmt(StmtReturn(ExprLiteral.NULL)) - createshmem.addstmt(failif) - - failif = StmtIf(ExprNot(ExprCall( - ExprSelect(p.callGetChannel(), '->', 'Send'), - args=[ descriptorvar ]))) - createshmem.addstmt(failif) - - rawsegmentvar = ExprVar('rawSegment') - createshmem.addstmts([ - StmtExpr(ExprAssn(ExprDeref(idvar), _shmemId(shmemvar))), - StmtDecl(Decl(_rawShmemType(ptr=1), rawsegmentvar.name), - init=_refptrGet(rawvar)), - StmtExpr(ExprCall( - ExprSelect(p.shmemMapVar(), '.', 'AddWithID'), - args=[ _refptrTake(_refptrForget(rawvar)), ExprDeref(idvar) ])), - StmtReturn(rawsegmentvar) - ]) - - # SharedMemory* Lookup(id) - lookupshmem.addstmt(StmtReturn(ExprCall( - ExprSelect(p.shmemMapVar(), '.', 'Lookup'), - args=[ idvar ]))) - - # bool IsTrackingSharedMemory(mem) - istracking.addstmt(StmtReturn(ExprCall( - ExprSelect(p.shmemMapVar(), '.', 'HasData'), - args=[ rawvar ]))) - - # bool DestroySharedMemory(shmem): - # id = shmem.Id() - # SharedMemory* rawmem = Lookup(id) - # if (!rawmem) - # return false; - # Message descriptor = UnShare(subprocess, mId, descriptor) - # mShmemMap.Remove(id) - # Shmem::Dealloc(rawmem) - # if (!mChannel.CanSend()) { - # delete descriptor; - # return true; - # } - # return descriptor && Send(descriptor) - destroyshmem.addstmts([ - StmtDecl(Decl(_shmemIdType(), idvar.name), - init=_shmemId(shmemvar)), - StmtDecl(Decl(_rawShmemType(ptr=1), rawvar.name), - init=_lookupShmem(idvar)) - ]) - - failif = StmtIf(ExprNot(rawvar)) - failif.addifstmt(StmtReturn.FALSE) - cansend = ExprCall(ExprSelect(p.channelVar(), '.', 'CanSend'), []) - returnif = StmtIf(ExprNot(cansend)) - returnif.addifstmts([ - StmtExpr(ExprDelete(descriptorvar)), - StmtReturn.TRUE]) - destroyshmem.addstmts([ - failif, - Whitespace.NL, - StmtDecl(Decl(Type('Message', ptr=1), descriptorvar.name), - init=_shmemUnshareFrom( - shmemvar, - p.callOtherPid(), - p.routingId())), - Whitespace.NL, - StmtExpr(p.removeShmemId(idvar)), - StmtExpr(_shmemDealloc(rawvar)), - Whitespace.NL, - returnif, - Whitespace.NL, - StmtReturn(ExprBinary( - descriptorvar, '&&', - ExprCall( - ExprSelect(p.channelVar(), p.channelSel(), 'Send'), - args=[ descriptorvar ]))) - ]) - - # "private" message that passes shmem mappings from one process # to the other if p.subtreeUsesShmem(): @@ -3593,25 +3390,12 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor): case = StmtBlock() - rawvar = ExprVar('rawmem') - idvar = ExprVar('id') + ifstmt = StmtIf(ExprNot(ExprCall(ExprVar('ShmemCreated'), args=[self.msgvar]))) case.addstmts([ - StmtDecl(Decl(_shmemIdType(), idvar.name)), - StmtDecl(Decl(_refptr(_rawShmemType()), rawvar.name), - initargs=[ _shmemOpenExisting(self.msgvar, - ExprAddrOf(idvar)) ]) - ]) - failif = StmtIf(ExprNot(rawvar)) - failif.addifstmt(StmtReturn(_Result.PayloadError)) - - case.addstmts([ - failif, - StmtExpr(ExprCall( - ExprSelect(p.shmemMapVar(), '.', 'AddWithID'), - args=[ _refptrTake(_refptrForget(rawvar)), idvar ])), - Whitespace.NL, + ifstmt, StmtReturn(_Result.Processed) ]) + ifstmt.addifstmt(StmtReturn(_Result.PayloadError)) return case @@ -3621,42 +3405,12 @@ class _GenerateProtocolActorCode(ipdl.ast.Visitor): case = StmtBlock() - rawvar = ExprVar('rawmem') - idvar = ExprVar('id') - itervar = ExprVar('iter') + ifstmt = StmtIf(ExprNot(ExprCall(ExprVar('ShmemDestroyed'), args=[self.msgvar]))) case.addstmts([ - StmtDecl(Decl(_shmemIdType(), idvar.name)), - StmtDecl(Decl(_iterType(ptr=0), itervar.name), init=ExprCall(ExprVar('PickleIterator'), - args=[ self.msgvar ])) - ]) - - failif = StmtIf(ExprNot( - ExprCall(ExprVar('IPC::ReadParam'), - args=[ ExprAddrOf(self.msgvar), ExprAddrOf(itervar), - ExprAddrOf(idvar) ]))) - failif.addifstmt(StmtReturn(_Result.PayloadError)) - - case.addstmts([ - failif, - StmtExpr(ExprCall(ExprSelect(self.msgvar, '.', 'EndRead'), - args=[ itervar ])), - Whitespace.NL, - StmtDecl(Decl(_rawShmemType(ptr=1), rawvar.name), - init=ExprCall(p.lookupSharedMemory(), args=[ idvar ])) - ]) - - # Here we don't return an error if we failed to look the shmem up. This - # is because we don't have a way to know if it is because we failed to - # map the shmem or if the id is wrong. In the latter case it would be - # better to catch the error but the former case is legit... - lookupif = StmtIf(rawvar) - lookupif.addifstmt(StmtExpr(p.removeShmemId(idvar))) - lookupif.addifstmt(StmtExpr(_shmemDealloc(rawvar))) - - case.addstmts([ - lookupif, + ifstmt, StmtReturn(_Result.Processed) ]) + ifstmt.addifstmt(StmtReturn(_Result.PayloadError)) return case