diff --git a/netwerk/base/public/nsITransport.idl b/netwerk/base/public/nsITransport.idl index af5a4900f59..bc4e4121990 100644 --- a/netwerk/base/public/nsITransport.idl +++ b/netwerk/base/public/nsITransport.idl @@ -48,7 +48,7 @@ interface nsITransport : nsISupports * Open an input stream on this transport. * * @param offset - read starting at this offset - * @param count - read this many bytes + * @param count - read this many bytes (pass ULONG_MAX if unlimited) * @param flags - optional transport specific flags */ nsIInputStream openInputStream(in unsigned long offset, @@ -59,7 +59,7 @@ interface nsITransport : nsISupports * Open an output stream on this transport. * * @param offset - write starting at this offset - * @param count - write no more than this many bytes + * @param count - write no more than this many bytes (pass ULONG_MAX if unlimited) * @param flags - optional transport specific flags */ nsIOutputStream openOutputStream(in unsigned long offset, @@ -72,7 +72,7 @@ interface nsITransport : nsISupports * @param listener - notify this listener when data is available * @param ctxt - opaque parameter passed to listener methods * @param offset - read starting at this offset - * @param count - read this many bytes + * @param count - read this many bytes (pass ULONG_MAX if unlimited) * @param flags - optional transport specific flags */ nsIRequest asyncRead(in nsIStreamListener listener, @@ -87,7 +87,7 @@ interface nsITransport : nsISupports * @param provider - notify this provider when data can be written * @param ctxt - opaque parameter passed to provider methods * @param offset - write starting at this offset - * @param count - write this many bytes + * @param count - write no more than this many bytes (pass ULONG_MAX if unlimited) * @param flags - optional transport specific flags */ nsIRequest asyncWrite(in nsIStreamProvider provider, @@ -95,6 +95,19 @@ interface nsITransport : nsISupports in unsigned long offset, in unsigned long count, in unsigned long flags); + + /** + * Callbacks from asyncRead and asyncWrite may be proxied from a + * background thread (if one exists) to the thread which initiated + * the request. This is the expected behavior of such a nsITransport + * implementation. A caller of asyncRead or asyncWrite can explicitly + * ask the transport to not proxy the callback. The caller must then + * be prepared to handle callbacks on any thread. + */ + const unsigned long DONT_PROXY_STREAM_OBSERVER = 1 << 0; + const unsigned long DONT_PROXY_STREAM_LISTENER = 1 << 1; + const unsigned long DONT_PROXY_STREAM_PROVIDER = 1 << 1; + }; [scriptable, uuid(d7abf5a4-ce72-482a-9217-a219a905c019)] diff --git a/netwerk/base/public/nsNetUtil.h b/netwerk/base/public/nsNetUtil.h index d9c837b2c50..cc3eb3c8523 100644 --- a/netwerk/base/public/nsNetUtil.h +++ b/netwerk/base/public/nsNetUtil.h @@ -375,8 +375,7 @@ NS_NewStreamObserverProxy(nsIStreamObserver **aResult, rv = proxy->Init(aObserver, aEventQ); if (NS_FAILED(rv)) return rv; - NS_ADDREF(*aResult = proxy); - return NS_OK; + return CallQueryInterface(proxy, aResult); } inline nsresult @@ -570,6 +569,18 @@ NS_AsyncWriteFromStream(nsIRequest **aRequest, aObserver); if (NS_FAILED(rv)) return rv; + // + // We can safely allow the transport impl to bypass proxying the provider + // since we are using a simple stream provider. + // + // A simple stream provider masks the OnDataWritable from consumers. + // Moreover, it makes an assumption about the underlying nsIInputStream + // implementation: namely, that it is thread-safe and blocking. + // + // So, let's always make this optimization. + // + aFlags |= nsITransport::DONT_PROXY_STREAM_PROVIDER; + return aTransport->AsyncWrite(provider, aContext, aOffset, aCount, diff --git a/netwerk/base/src/nsSocketTransport.cpp b/netwerk/base/src/nsSocketTransport.cpp index 79324fcedfd..2cc77d0aa13 100644 --- a/netwerk/base/src/nsSocketTransport.cpp +++ b/netwerk/base/src/nsSocketTransport.cpp @@ -1363,20 +1363,38 @@ nsSocketTransport::AsyncRead(nsIStreamListener* aListener, rv = NS_ERROR_IN_PROGRESS; nsCOMPtr listener; + nsCOMPtr observer; - if (NS_SUCCEEDED(rv)) - rv = NS_NewStreamListenerProxy(getter_AddRefs(listener), - aListener, nsnull, - mBufferSegmentSize, - mBufferMaxSize); + if (NS_SUCCEEDED(rv)) { + // Proxy the stream listener and observer methods by default. + if (!(aFlags & nsITransport::DONT_PROXY_STREAM_OBSERVER)) { + // Cannot proxy the listener unless the observer part is also proxied. + if (!(aFlags & nsITransport::DONT_PROXY_STREAM_PROVIDER)) { + rv = NS_NewStreamListenerProxy(getter_AddRefs(listener), + aListener, nsnull, + mBufferSegmentSize, + mBufferMaxSize); + observer = do_QueryInterface(listener); + } + else { + rv = NS_NewStreamObserverProxy(getter_AddRefs(observer), aListener); + listener = aListener; + } + } + else { + listener = aListener; + observer = do_QueryInterface(aListener); + } + } if (NS_SUCCEEDED(rv)) { NS_NEWXPCOM(mReadRequest, nsSocketReadRequest); if (mReadRequest) { NS_ADDREF(mReadRequest); mReadRequest->SetTransport(this); + mReadRequest->SetObserver(observer); + mReadRequest->SetContext(aContext); mReadRequest->SetListener(listener); - mReadRequest->SetListenerContext(aContext); //mReadRequest->SetTransferCount(aCount); } else @@ -1421,20 +1439,38 @@ nsSocketTransport::AsyncWrite(nsIStreamProvider* aProvider, rv = NS_ERROR_IN_PROGRESS; nsCOMPtr provider; + nsCOMPtr observer; - if (NS_SUCCEEDED(rv)) - rv = NS_NewStreamProviderProxy(getter_AddRefs(provider), - aProvider, nsnull, - mBufferSegmentSize, - mBufferMaxSize); + if (NS_SUCCEEDED(rv)) { + // Proxy the stream provider and observer methods by default. + if (!(aFlags & nsITransport::DONT_PROXY_STREAM_OBSERVER)) { + // Cannot proxy the provider unless the observer part is also proxied. + if (!(aFlags & nsITransport::DONT_PROXY_STREAM_PROVIDER)) { + rv = NS_NewStreamProviderProxy(getter_AddRefs(provider), + aProvider, nsnull, + mBufferSegmentSize, + mBufferMaxSize); + observer = do_QueryInterface(provider); + } + else { + rv = NS_NewStreamObserverProxy(getter_AddRefs(observer), aProvider); + provider = aProvider; + } + } + else { + provider = aProvider; + observer = do_QueryInterface(aProvider); + } + } if (NS_SUCCEEDED(rv)) { NS_NEWXPCOM(mWriteRequest, nsSocketWriteRequest); if (mWriteRequest) { NS_ADDREF(mWriteRequest); mWriteRequest->SetTransport(this); + mWriteRequest->SetObserver(observer); + mWriteRequest->SetContext(aContext); mWriteRequest->SetProvider(provider); - mWriteRequest->SetProviderContext(aContext); //mWriteRequest->SetTransferCount(aCount); } else @@ -1689,7 +1725,7 @@ nsSocketTransport::OnStatus(nsresult message) req = mWriteRequest; NS_ENSURE_TRUE(req, NS_ERROR_NOT_INITIALIZED); - return OnStatus(req, req->GetContext(), message); + return OnStatus(req, req->Context(), message); } // @@ -1799,8 +1835,8 @@ tryRead: } /* if (mTransport && total) { - mTransport->OnStatus(this, mListenerContext, NS_NET_STATUS_RECEIVING_FROM); - mTransport->OnProgress(this, mListenerContext, mOffset); + mTransport->OnStatus(this, mContext, NS_NET_STATUS_RECEIVING_FROM); + mTransport->OnProgress(this, mContext, mOffset); } */ *aBytesRead = (PRUint32) total; @@ -1887,8 +1923,8 @@ tryWrite: } /* if (mTransport && total) { - mTransport->OnStatus(this, mListenerContext, NS_NET_STATUS_SENDING_TO); - mTransport->OnProgress(this, mListenerContext, mOffset); + mTransport->OnStatus(this, mContext, NS_NET_STATUS_SENDING_TO); + mTransport->OnProgress(this, mContext, mOffset); } */ *aBytesWritten = (PRUint32) total; @@ -2042,6 +2078,18 @@ nsSocketOS::nsSocketOS() NS_INIT_ISUPPORTS(); } +NS_METHOD +nsSocketOS::WriteFromSegments(nsIInputStream *input, + void *closure, + const char *fromSegment, + PRUint32 offset, + PRUint32 count, + PRUint32 *countRead) +{ + nsSocketOS *self = NS_REINTERPRET_CAST(nsSocketOS *, closure); + return self->Write(fromSegment, count, countRead); +} + NS_IMETHODIMP nsSocketOS::Close() { @@ -2085,8 +2133,7 @@ nsSocketOS::Write(const char *aBuf, PRUint32 aCount, PRUint32 *aBytesWritten) NS_IMETHODIMP nsSocketOS::WriteFrom(nsIInputStream *aIS, PRUint32 aCount, PRUint32 *aBytesWritten) { - NS_NOTREACHED("nsSocketOS::WriteFrom"); - return NS_ERROR_NOT_IMPLEMENTED; + return aIS->ReadSegments(WriteFromSegments, this, aCount, aBytesWritten); } NS_IMETHODIMP @@ -2162,24 +2209,24 @@ nsSocketRequest::SetTransport(nsSocketTransport *aTransport) } nsresult -nsSocketRequest::OnStart(nsIStreamObserver *aObserver, nsISupports *aContext) +nsSocketRequest::OnStart() { - if (aObserver) { - aObserver->OnStartRequest(this, aContext); + if (mObserver && !mStartFired) { + mObserver->OnStartRequest(this, mContext); mStartFired = PR_TRUE; } return NS_OK; } nsresult -nsSocketRequest::OnStop(nsIStreamObserver *aObserver, nsISupports *aContext) +nsSocketRequest::OnStop() { - if (aObserver) { + if (mObserver) { if (!mStartFired) { - aObserver->OnStartRequest(this, aContext); + mObserver->OnStartRequest(this, mContext); mStartFired = PR_TRUE; } - aObserver->OnStopRequest(this, aContext, mStatus, nsnull); + mObserver->OnStopRequest(this, mContext, mStatus, nsnull); } return NS_OK; } @@ -2323,7 +2370,7 @@ nsSocketReadRequest::OnRead() this, offset, amount)); rv = mListener->OnDataAvailable(this, - mListenerContext, + mContext, mInputStream, offset, amount); @@ -2358,8 +2405,8 @@ nsSocketReadRequest::OnRead() } if (mTransport && total) { - mTransport->OnStatus(this, mListenerContext, NS_NET_STATUS_RECEIVING_FROM); - mTransport->OnProgress(this, mListenerContext, offset); + mTransport->OnStatus(this, mContext, NS_NET_STATUS_RECEIVING_FROM); + mTransport->OnProgress(this, mContext, offset); } } return rv; @@ -2410,7 +2457,7 @@ nsSocketWriteRequest::OnWrite() PRUint32 offset = mOutputStream->GetOffset(); rv = mProvider->OnDataWritable(this, - mProviderContext, + mContext, mOutputStream, offset, MAX_IO_TRANSFER_SIZE); @@ -2445,8 +2492,8 @@ nsSocketWriteRequest::OnWrite() } if (mTransport && total) { - mTransport->OnStatus(this, mProviderContext, NS_NET_STATUS_SENDING_TO); - mTransport->OnProgress(this, mProviderContext, offset); + mTransport->OnStatus(this, mContext, NS_NET_STATUS_SENDING_TO); + mTransport->OnProgress(this, mContext, offset); } } return rv; diff --git a/netwerk/base/src/nsSocketTransport.h b/netwerk/base/src/nsSocketTransport.h index 5c749127607..b02c85b29fe 100644 --- a/netwerk/base/src/nsSocketTransport.h +++ b/netwerk/base/src/nsSocketTransport.h @@ -353,6 +353,9 @@ public: PRErrorCode GetError() { return mError; } private: + static NS_METHOD WriteFromSegments(nsIInputStream *, void *, const char *, + PRUint32, PRUint32, PRUint32 *); + PRUint32 mOffset; PRFileDesc *mSock; PRErrorCode mError; @@ -376,22 +379,23 @@ public: PRBool IsCanceled() { return mCanceled; } void SetTransport(nsSocketTransport *); + void SetObserver(nsIStreamObserver *obs) { mObserver = obs; } + void SetContext(nsISupports *ctx) { mContext = ctx; } void SetStatus(nsresult status) { mStatus = status; } - - virtual nsISupports *GetContext() = 0; - virtual nsresult OnStart() = 0; - virtual nsresult OnStop() = 0; + nsISupports *Context() { return mContext; } + + nsresult OnStart(); + nsresult OnStop(); protected: - nsresult OnStart(nsIStreamObserver *, nsISupports *); - nsresult OnStop(nsIStreamObserver *, nsISupports *); - - nsSocketTransport *mTransport; - nsresult mStatus; - PRIntn mSuspendCount; - PRPackedBool mCanceled; - PRPackedBool mStartFired; + nsSocketTransport *mTransport; + nsCOMPtr mObserver; + nsCOMPtr mContext; + nsresult mStatus; + PRIntn mSuspendCount; + PRPackedBool mCanceled; + PRPackedBool mStartFired; }; /** @@ -405,18 +409,12 @@ public: void SetSocket(PRFileDesc *); void SetListener(nsIStreamListener *l) { mListener = l; } - void SetListenerContext(nsISupports *c) { mListenerContext = c; } - nsISupports *GetContext() { return mListenerContext; } - - nsresult OnStart() { return nsSocketRequest::OnStart(mListener, mListenerContext); } - nsresult OnStop() { return nsSocketRequest::OnStop(mListener, mListenerContext); } nsresult OnRead(); private: nsSocketIS *mInputStream; nsCOMPtr mListener; - nsCOMPtr mListenerContext; }; /** @@ -430,18 +428,12 @@ public: void SetSocket(PRFileDesc *); void SetProvider(nsIStreamProvider *p) { mProvider = p; } - void SetProviderContext(nsISupports *c) { mProviderContext = c; } - nsISupports *GetContext() { return mProviderContext; } - - nsresult OnStart() { return nsSocketRequest::OnStart(mProvider, mProviderContext); } - nsresult OnStop() { return nsSocketRequest::OnStop(mProvider, mProviderContext); } nsresult OnWrite(); private: nsSocketOS *mOutputStream; nsCOMPtr mProvider; - nsCOMPtr mProviderContext; }; #endif /* nsSocketTransport_h___ */ diff --git a/netwerk/protocol/ftp/src/nsFtpControlConnection.cpp b/netwerk/protocol/ftp/src/nsFtpControlConnection.cpp index ea2a3103aea..557e00b43d3 100644 --- a/netwerk/protocol/ftp/src/nsFtpControlConnection.cpp +++ b/netwerk/protocol/ftp/src/nsFtpControlConnection.cpp @@ -64,7 +64,7 @@ public: if (NS_FAILED(rv)) return rv; if (avail == 0) { - NS_STATIC_CAST(nsFtpControlConnection*, aContext)->mSuspendedWrite = PR_TRUE; + NS_STATIC_CAST(nsFtpControlConnection*, aContext)->mSuspendedWrite = 1; return NS_BASE_STREAM_WOULD_BLOCK; } PRUint32 bytesWritten; @@ -156,7 +156,9 @@ nsFtpControlConnection::Connect() rv = mCPipe->AsyncWrite(provider, NS_STATIC_CAST(nsISupports*, this), - 0, 0, 0, + 0, -1, + nsITransport::DONT_PROXY_STREAM_PROVIDER | + nsITransport::DONT_PROXY_STREAM_OBSERVER, getter_AddRefs(mWriteRequest)); if (NS_FAILED(rv)) return rv; @@ -193,10 +195,9 @@ nsFtpControlConnection::Write(nsCString& command) PRUint32 cnt; nsresult rv = mOutStream->Write(command.get(), len, &cnt); if (NS_SUCCEEDED(rv) && len==cnt) { - if (mSuspendedWrite) { - mSuspendedWrite = PR_FALSE; + PRInt32 writeWasSuspended = PR_AtomicSet(&mSuspendedWrite, 0); + if (writeWasSuspended) mWriteRequest->Resume(); - } return NS_OK; } diff --git a/netwerk/protocol/ftp/src/nsFtpControlConnection.h b/netwerk/protocol/ftp/src/nsFtpControlConnection.h index 9ef86f458e9..ca06f974120 100644 --- a/netwerk/protocol/ftp/src/nsFtpControlConnection.h +++ b/netwerk/protocol/ftp/src/nsFtpControlConnection.h @@ -57,7 +57,7 @@ public: nsCAutoString mCwd; // what dir are we in PRBool mList; // are we sending LIST or NLST nsAutoString mPassword; - PRBool mSuspendedWrite; + PRInt32 mSuspendedWrite; private: PRLock* mLock; // protects mListener.