diff --git a/netwerk/protocol/websocket/nsWebSocketHandler.cpp b/netwerk/protocol/websocket/nsWebSocketHandler.cpp index 70266827951f..4504d374483a 100644 --- a/netwerk/protocol/websocket/nsWebSocketHandler.cpp +++ b/netwerk/protocol/websocket/nsWebSocketHandler.cpp @@ -266,6 +266,7 @@ class nsWSAdmissionManager { public: nsWSAdmissionManager() + : mConnectedCount(0) { MOZ_COUNT_CTOR(nsWSAdmissionManager); } @@ -328,6 +329,21 @@ public: } return PR_FALSE; } + + void IncrementConnectedCount() + { + PR_ATOMIC_INCREMENT(&mConnectedCount); + } + + void DecrementConnectedCount() + { + PR_ATOMIC_DECREMENT(&mConnectedCount); + } + + PRInt32 ConnectedCount() + { + return mConnectedCount; + } private: nsTArray mData; @@ -339,6 +355,10 @@ private: return i; return -1; } + + // ConnectedCount might be decremented from the main or the socket + // thread, so manage it with atomic counters + PRInt32 mConnectedCount; }; // similar to nsDeflateConverter except for the mandatory FLUSH calls @@ -486,6 +506,7 @@ nsWebSocketHandler::nsWebSocketHandler() : mAllowCompression(1), mAutoFollowRedirects(0), mReleaseOnTransmit(0), + mTCPClosed(0), mMaxMessageSize(16000000), mStopOnClose(NS_OK), mCloseCode(kCloseAbnormal), @@ -1286,6 +1307,36 @@ nsWebSocketHandler::EnsureHdrOut(PRUint32 size) mHdrOut = mDynamicOutput; } +void +nsWebSocketHandler::CleanupConnection() +{ + LOG(("WebSocketHandler::CleanupConnection() %p", this)); + + if (mLingeringCloseTimer) { + mLingeringCloseTimer->Cancel(); + mLingeringCloseTimer = nsnull; + } + + if (mSocketIn) { + if (sWebSocketAdmissions) + sWebSocketAdmissions->DecrementConnectedCount(); + mSocketIn->AsyncWait(nsnull, 0, 0, nsnull); + mSocketIn = nsnull; + } + + if (mSocketOut) { + mSocketOut->AsyncWait(nsnull, 0, 0, nsnull); + mSocketOut = nsnull; + } + + if (mTransport) { + mTransport->SetSecurityCallbacks(nsnull); + mTransport->SetEventSink(nsnull, nsnull); + mTransport->Close(NS_BASE_STREAM_CLOSED); + mTransport = nsnull; + } +} + void nsWebSocketHandler::StopSession(nsresult reason) { @@ -1312,7 +1363,7 @@ nsWebSocketHandler::StopSession(nsresult reason) mPingTimer = nsnull; } - if (mSocketIn) { + if (mSocketIn && !mTCPClosed) { // drain, within reason, this socket. if we leave any data // unconsumed (including the tcp fin) a RST will be generated // The right thing to do here is shutdown(SHUT_WR) and then wait @@ -1327,22 +1378,39 @@ nsWebSocketHandler::StopSession(nsresult reason) do { total += count; rv = mSocketIn->Read(buffer, 512, &count); + if (rv != NS_BASE_STREAM_WOULD_BLOCK && + (NS_FAILED(rv) || count == 0)) + mTCPClosed = PR_TRUE; } while (NS_SUCCEEDED(rv) && count > 0 && total < 32000); - - mSocketIn->AsyncWait(nsnull, 0, 0, nsnull); - mSocketIn = nsnull; } - if (mSocketOut) { - mSocketOut->AsyncWait(nsnull, 0, 0, nsnull); - mSocketOut = nsnull; - } + if (!mTCPClosed && mTransport && sWebSocketAdmissions && + sWebSocketAdmissions->ConnectedCount() < kLingeringCloseThreshold) { - if (mTransport) { - mTransport->SetSecurityCallbacks(nsnull); - mTransport->SetEventSink(nsnull, nsnull); - mTransport->Close(NS_BASE_STREAM_CLOSED); - mTransport = nsnull; + // 7.1.1 says that the client SHOULD wait for the server to close + // the TCP connection. This is so we can reuse port numbers before + // 2 MSL expires, which is not really as much of a concern for us + // as the amount of state that might be accrued by keeping this + // handler object around waiting for the server. We handle the SHOULD + // by waiting a short time in the common case, but not waiting in + // the case of high concurrency. + // + // Normally this will be taken care of in AbortSession() after mTCPClosed + // is set when the server close arrives without waiting for the timeout to + // expire. + + LOG(("nsWebSocketHandler::StopSession - Wait for Server TCP close")); + + nsresult rv; + mLingeringCloseTimer = do_CreateInstance("@mozilla.org/timer;1", &rv); + if (NS_SUCCEEDED(rv)) + mLingeringCloseTimer->InitWithCallback(this, kLingeringCloseTimeout, + nsITimer::TYPE_ONE_SHOT); + else + CleanupConnection(); + } + else { + CleanupConnection(); } if (mDNSRequest) { @@ -1381,6 +1449,17 @@ nsWebSocketHandler::AbortSession(nsresult reason) !(mRecvdHttpOnStartRequest && mRecvdHttpUpgradeTransport), "wrong thread"); + // When we are failing we need to close the TCP connection immediately + // as per 7.1.1 + mTCPClosed = PR_TRUE; + + if (mLingeringCloseTimer) { + NS_ABORT_IF_FALSE(mStopped, "Lingering without Stop"); + LOG(("Cleanup Connection based on TCP Close")); + CleanupConnection(); + return; + } + if (mStopped) return; mStopped = 1; @@ -1786,6 +1865,10 @@ nsWebSocketHandler::Notify(nsITimer *timer) AbortSession(NS_ERROR_NET_TIMEOUT); } } + else if (timer == mLingeringCloseTimer) { + LOG(("nsWebSocketHandler:: Lingering Close Timer")); + CleanupConnection(); + } else { NS_ABORT_IF_FALSE(0, "Unknown Timer"); } @@ -2100,10 +2183,13 @@ nsWebSocketHandler::OnTransportAvailable(nsISocketTransport *aTransport, NS_ABORT_IF_FALSE(NS_IsMainThread(), "not main thread"); NS_ABORT_IF_FALSE(!mRecvdHttpUpgradeTransport, "OTA duplicated"); + NS_ABORT_IF_FALSE(aSocketIn, "OTA with invalid socketIn"); mTransport = aTransport; mSocketIn = aSocketIn; mSocketOut = aSocketOut; + if (sWebSocketAdmissions) + sWebSocketAdmissions->IncrementConnectedCount(); nsresult rv; rv = mTransport->SetEventSink(nsnull, nsnull); @@ -2312,9 +2398,6 @@ nsWebSocketHandler::OnInputStreamReady(nsIAsyncInputStream *aStream) NS_ABORT_IF_FALSE(PR_GetCurrentThread() == gSocketThread, "not socket thread"); - if (mStopped) - return NS_ERROR_UNEXPECTED; - nsRefPtr deleteProtector1(mInflateReader); nsRefPtr deleteProtector2(mInflateStream); @@ -2334,15 +2417,23 @@ nsWebSocketHandler::OnInputStreamReady(nsIAsyncInputStream *aStream) } if (NS_FAILED(rv)) { + mTCPClosed = PR_TRUE; AbortSession(rv); return rv; } if (count == 0) { + mTCPClosed = PR_TRUE; AbortSession(NS_BASE_STREAM_CLOSED); return NS_OK; } + if (mStopped) { + NS_ABORT_IF_FALSE(mLingeringCloseTimer, + "OnInputReady after stop without linger"); + continue; + } + if (mInflateReader) { mInflateStream->ShareData(buffer, count); rv = mInflateReader->OnDataAvailable(nsnull, mSocketIn, diff --git a/netwerk/protocol/websocket/nsWebSocketHandler.h b/netwerk/protocol/websocket/nsWebSocketHandler.h index 0f489a7dd691..382dd0c05077 100644 --- a/netwerk/protocol/websocket/nsWebSocketHandler.h +++ b/netwerk/protocol/websocket/nsWebSocketHandler.h @@ -145,6 +145,7 @@ private: void StopSession(nsresult reason); void AbortSession(nsresult reason); void ReleaseSession(); + void CleanupConnection(); void EnsureHdrOut(PRUint32 size); void ApplyMask(PRUint32 mask, PRUint8 *data, PRUint64 len); @@ -226,6 +227,10 @@ private: PRUint32 mPingTimeout; /* milliseconds */ PRUint32 mPingResponseTimeout; /* milliseconds */ + nsCOMPtr mLingeringCloseTimer; + const static PRInt32 kLingeringCloseTimeout = 1000; + const static PRInt32 kLingeringCloseThreshold = 50; + PRUint32 mRecvdHttpOnStartRequest : 1; PRUint32 mRecvdHttpUpgradeTransport : 1; PRUint32 mRequestedClose : 1; @@ -237,6 +242,7 @@ private: PRUint32 mAllowCompression : 1; PRUint32 mAutoFollowRedirects : 1; PRUint32 mReleaseOnTransmit : 1; + PRUint32 mTCPClosed : 1; PRInt32 mMaxMessageSize; nsresult mStopOnClose;