gecko-dev/dom/flyweb/HttpServer.cpp

1323 строки
36 KiB
C++

/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=8 sts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "mozilla/dom/HttpServer.h"
#include "nsISocketTransport.h"
#include "nsWhitespaceTokenizer.h"
#include "nsNetUtil.h"
#include "nsIStreamTransportService.h"
#include "nsIAsyncStreamCopier2.h"
#include "nsIPipe.h"
#include "nsIOService.h"
#include "nsIHttpChannelInternal.h"
#include "Base64.h"
#include "WebSocketChannel.h"
#include "nsCharSeparatedTokenizer.h"
#include "nsIX509Cert.h"
// Class ID constant for the stream transport service.
// NOTE(review): nothing in the visible part of this file references this
// constant (StreamCopier looks the service up by contract ID instead) —
// confirm whether it is still needed.
static NS_DEFINE_CID(kStreamTransportServiceCID, NS_STREAMTRANSPORTSERVICE_CID);
namespace mozilla {
namespace dom {
// Log module for this file, registered under the name "HttpServer".
static LazyLogModule gHttpServerLog("HttpServer");
// Logging helpers. Each one is #undef'd first so that a definition leaking in
// from another file cannot clash with the one below.
// LOG_I: Debug level.
#undef LOG_I
#define LOG_I(...) MOZ_LOG(gHttpServerLog, mozilla::LogLevel::Debug, (__VA_ARGS__))
// LOG_V: Verbose level.
#undef LOG_V
#define LOG_V(...) MOZ_LOG(gHttpServerLog, mozilla::LogLevel::Verbose, (__VA_ARGS__))
// LOG_E: Error level.
#undef LOG_E
#define LOG_E(...) MOZ_LOG(gHttpServerLog, mozilla::LogLevel::Error, (__VA_ARGS__))
// HttpServer listens on a server socket (nsIServerSocketListener) and receives
// the result of the local certificate lookup (nsILocalCertGetCallback).
NS_IMPL_ISUPPORTS(HttpServer,
nsIServerSocketListener,
nsILocalCertGetCallback)
// Constructs an idle server. mPort and mHttps are value-initialized; the real
// configuration is supplied later through Init().
// aMainThread is stored and later used as the target for promise callbacks.
HttpServer::HttpServer(AbstractThread* aMainThread)
: mPort()
, mHttps()
, mAbstractMainThread(aMainThread)
{
}
// Nothing to release explicitly; members clean up after themselves.
HttpServer::~HttpServer()
{
}
/**
 * Configure the server and start bringing it up.
 *
 * For HTTPS a certificate named "flyweb" is requested from the local cert
 * service; the server socket is started from HandleCert() once the
 * certificate (or an error) is delivered.  For plain HTTP HandleCert() is
 * invoked directly with no certificate.
 *
 * The outcome is always reported to the listener through NotifyStarted(),
 * which dispatches asynchronously.
 *
 * @param aPort      Port to listen on.
 * @param aHttps     Whether to serve over TLS.
 * @param aListener  Receiver of server events.
 */
void
HttpServer::Init(int32_t aPort, bool aHttps, HttpServerListener* aListener)
{
  mPort = aPort;
  mHttps = aHttps;
  mListener = aListener;

  if (!mHttps) {
    // Make sure to always have an async step before notifying callbacks;
    // HandleCert() reports through NotifyStarted(), which dispatches a
    // runnable.
    HandleCert(nullptr, NS_OK);
    return;
  }

  nsresult rv;
  nsCOMPtr<nsILocalCertService> lcs =
    do_CreateInstance("@mozilla.org/security/local-cert-service;1", &rv);
  if (NS_FAILED(rv) || !lcs) {
    // The service could not be instantiated. Report the failure instead of
    // dereferencing a null pointer.
    NotifyStarted(NS_FAILED(rv) ? rv : NS_ERROR_FAILURE);
    return;
  }

  rv = lcs->GetOrCreateCert(NS_LITERAL_CSTRING("flyweb"), this);
  if (NS_FAILED(rv)) {
    NotifyStarted(rv);
  }
}
/**
 * nsILocalCertGetCallback entry point; also called directly by Init() for
 * plain HTTP.
 *
 * When aResult is a success code the server socket is started with aCert
 * (null means no TLS).  If anything failed and a socket was already created,
 * it is closed and dropped.  The final status is passed to NotifyStarted().
 *
 * @return Always NS_OK; failures are reported through the listener.
 */
NS_IMETHODIMP
HttpServer::HandleCert(nsIX509Cert* aCert, nsresult aResult)
{
  const nsresult status =
    NS_SUCCEEDED(aResult) ? StartServerSocket(aCert) : aResult;

  if (mServerSocket && NS_FAILED(status)) {
    // Startup failed part-way through; do not leave a half-configured socket.
    mServerSocket->Close();
    mServerSocket = nullptr;
  }

  NotifyStarted(status);
  return NS_OK;
}
/**
 * Asynchronously tell the listener how server startup went.
 *
 * A runnable is dispatched to the current thread; it holds its own reference
 * to the listener and calls OnServerStarted(aStatus) on it.
 */
void
HttpServer::NotifyStarted(nsresult aStatus)
{
  RefPtr<HttpServerListener> target = mListener;
  nsCOMPtr<nsIRunnable> task =
    NS_NewRunnableFunction([target, aStatus] () {
      target->OnServerStarted(aStatus);
    });
  NS_DispatchToCurrentThread(task);
}
/**
 * Create the listening socket and start accepting connections.
 *
 * A TLS server socket is created when aCert is non-null, a plain one
 * otherwise.  After listening starts, mPort is refreshed from the socket.
 *
 * @param aCert  Server certificate, or null for plain HTTP.
 * @return NS_OK on success, otherwise the first failing step's error code.
 *         mServerSocket may be left set on failure; HandleCert() closes it.
 */
nsresult
HttpServer::StartServerSocket(nsIX509Cert* aCert)
{
  const char* contractId = aCert ? "@mozilla.org/network/tls-server-socket;1"
                                 : "@mozilla.org/network/server-socket;1";

  nsresult rv;
  mServerSocket = do_CreateInstance(contractId, &rv);
  NS_ENSURE_SUCCESS(rv, rv);

  rv = mServerSocket->Init(mPort, false, -1);
  NS_ENSURE_SUCCESS(rv, rv);

  if (aCert) {
    nsCOMPtr<nsITLSServerSocket> tls = do_QueryInterface(mServerSocket);

    rv = tls->SetServerCert(aCert);
    NS_ENSURE_SUCCESS(rv, rv);

    rv = tls->SetSessionTickets(false);
    NS_ENSURE_SUCCESS(rv, rv);

    mCert = aCert;
  }

  rv = mServerSocket->AsyncListen(this);
  NS_ENSURE_SUCCESS(rv, rv);

  // Read the port back from the socket now that it is listening.
  rv = mServerSocket->GetPort(&mPort);
  NS_ENSURE_SUCCESS(rv, rv);

  LOG_I("HttpServer::StartServerSocket(%p)", this);
  return NS_OK;
}
// nsIServerSocketListener: a client connected.
// Wraps the transport in a Connection and adds it to mConnections. The
// Connection constructor reports its result through rv; on failure that
// error is returned and the connection is not added to the list.
NS_IMETHODIMP
HttpServer::OnSocketAccepted(nsIServerSocket* aServ,
nsISocketTransport* aTransport)
{
MOZ_ASSERT(SameCOMIdentity(aServ, mServerSocket));
nsresult rv;
RefPtr<Connection> conn = new Connection(aTransport, this, rv);
NS_ENSURE_SUCCESS(rv, rv);
LOG_I("HttpServer::OnSocketAccepted(%p) - Socket %p", this, conn.get());
mConnections.AppendElement(conn.forget());
return NS_OK;
}
// nsIServerSocketListener: the server socket stopped listening.
// Logs the status and shuts the whole server down via Close(). mServerSocket
// may already be null here (the assertion allows for that).
NS_IMETHODIMP
HttpServer::OnStopListening(nsIServerSocket* aServ,
nsresult aStatus)
{
MOZ_ASSERT(aServ == mServerSocket || !mServerSocket);
LOG_I("HttpServer::OnStopListening(%p) - status 0x%" PRIx32,
this, static_cast<uint32_t>(aStatus));
Close();
return NS_OK;
}
/**
 * Deliver aResponse for aRequest.
 *
 * Each connection is asked in turn whether the request belongs to it; the
 * search stops at the first one that accepts the response.  It is a
 * programming error (debug assertion) if no connection knows the request.
 */
void
HttpServer::SendResponse(InternalRequest* aRequest, InternalResponse* aResponse)
{
  bool delivered = false;
  for (Connection* conn : mConnections) {
    delivered = conn->TryHandleResponse(aRequest, aResponse);
    if (delivered) {
      break;
    }
  }

  if (!delivered) {
    MOZ_ASSERT(false, "Unknown request");
  }
}
/**
 * Accept a pending WebSocket upgrade request.
 *
 * Finds the connection waiting on aConnectRequest, lets it build the
 * handshake response and removes it from mConnections.  If the handshake
 * failed the connection is closed first.
 *
 * @param aConnectRequest  The upgrade request previously handed to the
 *                         listener.
 * @param aProtocol        Optional sub-protocol to select.
 * @param aRv              Receives any failure.
 * @return The transport provider for the accepted socket, or null on
 *         failure.  Throws NS_ERROR_UNEXPECTED if no connection knows the
 *         request.
 */
already_AddRefed<nsITransportProvider>
HttpServer::AcceptWebSocket(InternalRequest* aConnectRequest,
                            const Optional<nsAString>& aProtocol,
                            ErrorResult& aRv)
{
  for (Connection* conn : mConnections) {
    if (conn->HasPendingWebSocketRequest(aConnectRequest)) {
      nsCOMPtr<nsITransportProvider> provider =
        conn->HandleAcceptWebSocket(aProtocol, aRv);

      if (aRv.Failed()) {
        conn->Close();
      }

      // This connection is now owned by the websocket, or we just closed it.
      // Returning right away means the modified array is not iterated again.
      mConnections.RemoveElement(conn);
      return provider.forget();
    }
  }

  MOZ_ASSERT(false, "Unknown request");
  aRv.Throw(NS_ERROR_UNEXPECTED);
  return nullptr;
}
/**
 * Answer a pending WebSocket upgrade request with an ordinary HTTP response
 * instead of accepting it.
 *
 * The first connection waiting on aConnectRequest handles the response.  It
 * is a programming error (debug assertion) if no connection knows the
 * request.
 */
void
HttpServer::SendWebSocketResponse(InternalRequest* aConnectRequest,
                                  InternalResponse* aResponse)
{
  for (Connection* conn : mConnections) {
    if (!conn->HasPendingWebSocketRequest(aConnectRequest)) {
      continue;
    }
    conn->HandleWebSocketResponse(aResponse);
    return;
  }

  MOZ_ASSERT(false, "Unknown request");
}
/**
 * Shut the server down.
 *
 * In order: closes and drops the listening socket, tells the listener (at
 * most once, because the reference is taken out of mListener) that the
 * server closed, then closes and forgets every connection.
 */
void
HttpServer::Close()
{
  if (mServerSocket) {
    mServerSocket->Close();
    mServerSocket = nullptr;
  }

  // Detach the listener before calling it so a second Close() does not notify
  // it again.
  RefPtr<HttpServerListener> listener = mListener.forget();
  if (listener) {
    listener->OnServerClose();
  }

  for (Connection* conn : mConnections) {
    conn->Close();
  }
  mConnections.Clear();
}
/**
 * Write the SHA-256 fingerprint of the server certificate into aKey.
 *
 * When no certificate is set, aKey receives the empty string.  The UTF-16
 * fingerprint is narrowed with a lossy UTF-16 to ASCII copy.
 */
void
HttpServer::GetCertKey(nsACString& aKey)
{
  nsAutoString fingerprint;
  if (mCert) {
    mCert->GetSha256Fingerprint(fingerprint);
  }
  LossyCopyUTF16toASCII(fingerprint, aKey);
}
// TransportProvider hands an accepted WebSocket's transport and streams to an
// nsIHttpUpgradeListener once both sides (transport and listener) are set.
NS_IMPL_ISUPPORTS(HttpServer::TransportProvider,
nsITransportProvider)
// Nothing to release explicitly.
HttpServer::TransportProvider::~TransportProvider()
{
}
// nsITransportProvider: register the listener that will receive the
// transport. May only be called once with a non-null listener (asserted).
// If the transport is already available the listener is notified via
// MaybeNotify().
NS_IMETHODIMP
HttpServer::TransportProvider::SetListener(nsIHttpUpgradeListener* aListener)
{
MOZ_ASSERT(!mListener);
MOZ_ASSERT(aListener);
mListener = aListener;
MaybeNotify();
return NS_OK;
}
// nsITransportProvider: not supported by this implementation. Calling it
// crashes deliberately; the statements after MOZ_CRASH are never reached.
NS_IMETHODIMP
HttpServer::TransportProvider::GetIPCChild(PTransportProviderChild** aChild)
{
MOZ_CRASH("Don't call this in parent process");
*aChild = nullptr;
return NS_OK;
}
// Store the socket transport and its streams once the connection hands them
// over. May only be called once, with all three arguments non-null
// (asserted). If a listener is already registered it is notified via
// MaybeNotify().
void
HttpServer::TransportProvider::SetTransport(nsISocketTransport* aTransport,
nsIAsyncInputStream* aInput,
nsIAsyncOutputStream* aOutput)
{
MOZ_ASSERT(!mTransport);
MOZ_ASSERT(aTransport && aInput && aOutput);
mTransport = aTransport;
mInput = aInput;
mOutput = aOutput;
MaybeNotify();
}
void
HttpServer::TransportProvider::MaybeNotify()
{
if (mTransport && mListener) {
RefPtr<TransportProvider> self = this;
nsCOMPtr<nsIRunnable> event = NS_NewRunnableFunction([self, this] ()
{
DebugOnly<nsresult> rv = mListener->OnTransportAvailable(mTransport,
mInput, mOutput);
MOZ_ASSERT(NS_SUCCEEDED(rv));
});
NS_DispatchToCurrentThread(event);
}
}
// A Connection is the callback target for both its input and output streams.
NS_IMPL_ISUPPORTS(HttpServer::Connection,
nsIInputStreamCallback,
nsIOutputStreamCallback)
// Set up a connection for a freshly accepted transport.
//
// aTransport: the accepted socket transport.
// aServer:    the owning server.
// rv:         out-parameter receiving the result. If opening either stream
//             fails, construction returns early with the error in rv and the
//             object must not be used.
//
// Parsing starts in eRequestLine state. For HTTPS the connection first
// registers as security observer and waits for OnHandshakeDone(); for plain
// HTTP it starts waiting for input immediately.
HttpServer::Connection::Connection(nsISocketTransport* aTransport,
HttpServer* aServer,
nsresult& rv)
: mServer(aServer)
, mTransport(aTransport)
, mState(eRequestLine)
, mPendingReqVersion()
, mRemainingBodySize()
, mCloseAfterRequest(false)
{
nsCOMPtr<nsIInputStream> input;
rv = mTransport->OpenInputStream(0, 0, 0, getter_AddRefs(input));
NS_ENSURE_SUCCESS_VOID(rv);
mInput = do_QueryInterface(input);
nsCOMPtr<nsIOutputStream> output;
rv = mTransport->OpenOutputStream(0, 0, 0, getter_AddRefs(output));
NS_ENSURE_SUCCESS_VOID(rv);
mOutput = do_QueryInterface(output);
if (mServer->mHttps) {
// Reading starts from OnHandshakeDone() once TLS is established.
SetSecurityObserver(true);
} else {
mInput->AsyncWait(this, 0, 0, NS_GetCurrentThread());
}
}
// Called when the TLS handshake for this connection has completed.
// Unregisters the security observer and starts waiting for request data.
// Neither aServer nor aStatus is inspected (see the XXX below).
NS_IMETHODIMP
HttpServer::Connection::OnHandshakeDone(nsITLSServerSocket* aServer,
nsITLSClientStatus* aStatus)
{
LOG_I("HttpServer::Connection::OnHandshakeDone(%p)", this);
// XXX Verify connection security
SetSecurityObserver(false);
mInput->AsyncWait(this, 0, 0, NS_GetCurrentThread());
return NS_OK;
}
// Register (aListen == true) or unregister (aListen == false) this connection
// as the security observer on the transport's TLS connection info.
// NOTE(review): tlsConnInfo is only checked by a debug assertion before it is
// dereferenced; a failed QueryInterface would be a null dereference in
// release builds.
void
HttpServer::Connection::SetSecurityObserver(bool aListen)
{
LOG_I("HttpServer::Connection::SetSecurityObserver(%p) - %s", this,
aListen ? "On" : "Off");
nsCOMPtr<nsISupports> secInfo;
mTransport->GetSecurityInfo(getter_AddRefs(secInfo));
nsCOMPtr<nsITLSServerConnectionInfo> tlsConnInfo =
do_QueryInterface(secInfo);
MOZ_ASSERT(tlsConnInfo);
tlsConnInfo->SetSecurityObserver(aListen ? this : nullptr);
}
// Nothing to release explicitly; the destructor does not call Close().
HttpServer::Connection::~Connection()
{
}
/**
 * nsIInputStreamCallback: the socket has data, or was closed.
 *
 * Does nothing if the input stream was already released or parsing is
 * paused.  If Available() fails the connection is treated as closed and
 * removed from the server's list.  Otherwise all readable data is pushed
 * through ReadSegmentsFunc() and the connection waits for more input.
 */
NS_IMETHODIMP
HttpServer::Connection::OnInputStreamReady(nsIAsyncInputStream* aStream)
{
  MOZ_ASSERT(!mInput || aStream == mInput);

  LOG_I("HttpServer::Connection::OnInputStreamReady(%p)", this);

  if (!mInput) {
    return NS_OK;
  }
  if (mState == ePause) {
    return NS_OK;
  }

  uint64_t pendingBytes;
  nsresult rv = mInput->Available(&pendingBytes);
  if (NS_FAILED(rv)) {
    LOG_I("HttpServer::Connection::OnInputStreamReady(%p) - Connection closed", this);

    mServer->mConnections.RemoveElement(this);
    // Connection closed. Handle errors here.
    return NS_OK;
  }

  uint32_t consumed;
  rv = mInput->ReadSegments(ReadSegmentsFunc, this, UINT32_MAX, &consumed);
  NS_ENSURE_SUCCESS(rv, rv);

  rv = mInput->AsyncWait(this, 0, 0, NS_GetCurrentThread());
  NS_ENSURE_SUCCESS(rv, rv);

  return NS_OK;
}
/**
 * ReadSegments() callback: forwards one segment of input to ConsumeInput().
 *
 * @param aClosure     The Connection that is reading.
 * @param aBuffer      Start of the segment.
 * @param aCount       Number of bytes in the segment.
 * @param aWriteCount  Receives how many bytes ConsumeInput() consumed.
 * @return Whatever ConsumeInput() returned.
 */
nsresult
HttpServer::Connection::ReadSegmentsFunc(nsIInputStream* aIn,
                                         void* aClosure,
                                         const char* aBuffer,
                                         uint32_t aToOffset,
                                         uint32_t aCount,
                                         uint32_t* aWriteCount)
{
  auto* connection = static_cast<HttpServer::Connection*>(aClosure);

  // ConsumeInput() advances the cursor past everything it used.
  const char* cursor = aBuffer;
  nsresult rv = connection->ConsumeInput(cursor, aBuffer + aCount);

  *aWriteCount = cursor - aBuffer;
  MOZ_ASSERT(*aWriteCount <= aCount);

  return rv;
}
/**
 * Find the first "\r\n" sequence in [aBuffer, aEnd).
 *
 * @return Pointer to the '\r' of the first CRLF, or nullptr if the range
 *         contains none (including when it is shorter than two bytes).
 */
static const char*
findCRLF(const char* aBuffer, const char* aEnd)
{
  const char* cursor = aBuffer;

  // A CRLF needs two bytes, so the CR can never be the last byte of the
  // range; the search window therefore always stops one byte short of aEnd.
  while (cursor + 1 < aEnd) {
    const void* hit = memchr(cursor, '\r', aEnd - cursor - 1);
    if (!hit) {
      break;
    }

    const char* cr = static_cast<const char*>(hit);
    if (cr[1] == '\n') {
      return cr;
    }

    // Lone CR; keep looking after it.
    cursor = cr + 1;
  }

  return nullptr;
}
/**
 * Feed a run of raw bytes from the socket into the request parser.
 *
 * While the connection is in eRequestLine or eHeaders state the input is
 * split on CRLF and every complete line is handed to ConsumeLine().  A
 * trailing partial line is stashed in mInputBuffer until more data arrives.
 * In eBody state up to mRemainingBodySize bytes are forwarded to
 * mCurrentRequestBody.
 *
 * @param aBuffer  In/out cursor into the data; advanced past everything that
 *                 was consumed.
 * @param aEnd     One past the last available byte.
 * @return NS_OK, or the failure code returned by ConsumeLine().
 */
nsresult
HttpServer::Connection::ConsumeInput(const char*& aBuffer,
                                     const char* aEnd)
{
  nsresult rv;
  while (mState == eRequestLine ||
         mState == eHeaders) {
    // Consume line-by-line

    // Check if the buffer boundary ended up right between the CR and LF.
    // mInputBuffer then holds the line plus the trailing CR.
    if (!mInputBuffer.IsEmpty() && mInputBuffer.Last() == '\r' &&
        aBuffer < aEnd && *aBuffer == '\n') {
      aBuffer++;
      rv = ConsumeLine(mInputBuffer.BeginReading(), mInputBuffer.Length() - 1);
      NS_ENSURE_SUCCESS(rv, rv);

      mInputBuffer.Truncate();

      // ConsumeLine() may have moved the connection out of line parsing
      // (body or pause), so re-evaluate the loop condition before looking
      // for another line.
      continue;
    }

    // Look for a CRLF
    const char* pos = findCRLF(aBuffer, aEnd);
    if (!pos) {
      // No complete line yet; keep what we have for the next call.
      mInputBuffer.Append(aBuffer, aEnd - aBuffer);
      aBuffer = aEnd;
      return NS_OK;
    }

    if (!mInputBuffer.IsEmpty()) {
      // Finish a line that started in an earlier chunk.  Only the bytes
      // before the CR are appended, so the whole buffer is the line.
      mInputBuffer.Append(aBuffer, pos - aBuffer);
      aBuffer = pos + 2;
      rv = ConsumeLine(mInputBuffer.BeginReading(), mInputBuffer.Length());
      NS_ENSURE_SUCCESS(rv, rv);

      mInputBuffer.Truncate();
    } else {
      rv = ConsumeLine(aBuffer, pos - aBuffer);
      NS_ENSURE_SUCCESS(rv, rv);

      aBuffer = pos + 2;
    }
  }

  if (mState == eBody) {
    uint32_t size = std::min(mRemainingBodySize,
                             static_cast<uint32_t>(aEnd - aBuffer));
    uint32_t written = size;

    if (mCurrentRequestBody) {
      rv = mCurrentRequestBody->Write(aBuffer, size, &written);
      // Since we've given the pipe unlimited size, we should never
      // end up needing to block.
      MOZ_ASSERT(rv != NS_BASE_STREAM_WOULD_BLOCK);
      if (NS_FAILED(rv)) {
        // The consumer went away.  Keep draining the body from the socket so
        // the next request can be parsed, but stop forwarding it.
        written = size;
        mCurrentRequestBody = nullptr;
      }
    }

    aBuffer += written;
    mRemainingBodySize -= written;
    if (!mRemainingBodySize) {
      // The body stream may already have been dropped after a failed write.
      if (mCurrentRequestBody) {
        mCurrentRequestBody->Close();
        mCurrentRequestBody = nullptr;
      }
      mState = eRequestLine;
    }
  }

  return NS_OK;
}
bool
ContainsToken(const nsCString& aList, const nsCString& aToken)
{
nsCCharSeparatedTokenizer tokens(aList, ',');
bool found = false;
while (!found && tokens.hasMoreTokens()) {
found = tokens.nextToken().Equals(aToken);
}
return found;
}
/**
 * Decide whether a fully parsed request is a WebSocket opening handshake.
 *
 * All of the following must hold:
 *  - the HTTP minor version is at least 1 (i.e. HTTP/1.1 or later),
 *  - the method is GET,
 *  - the Upgrade header is "websocket" (ASCII case-insensitive),
 *  - the Connection header lists the token "upgrade" (ASCII
 *    case-insensitive),
 *  - Sec-WebSocket-Key base64-decodes to exactly 16 bytes,
 *  - Sec-WebSocket-Version is 13.
 *
 * @param aRequest      The request to inspect.
 * @param aHttpVersion  Minor version parsed from "HTTP/1.x".
 */
static bool
IsWebSocketRequest(InternalRequest* aRequest, uint32_t aHttpVersion)
{
  if (aHttpVersion < 1) {
    return false;
  }

  nsAutoCString str;
  aRequest->GetMethod(str);
  if (!str.EqualsLiteral("GET")) {
    return false;
  }

  InternalHeaders* headers = aRequest->Headers();
  ErrorResult res;

  // RFC 6455 treats the Upgrade value as ASCII case-insensitive, so
  // "WebSocket" must be accepted as well as "websocket".
  headers->GetFirst(NS_LITERAL_CSTRING("upgrade"), str, res);
  MOZ_ASSERT(!res.Failed());
  if (!str.LowerCaseEqualsLiteral("websocket")) {
    return false;
  }

  // The "Upgrade" connection token is case-insensitive too.  Lower-case the
  // local copy of the header so the exact-match ContainsToken() can be used.
  headers->GetFirst(NS_LITERAL_CSTRING("connection"), str, res);
  MOZ_ASSERT(!res.Failed());
  ToLowerCase(str);
  if (!ContainsToken(str, NS_LITERAL_CSTRING("upgrade"))) {
    return false;
  }

  headers->GetFirst(NS_LITERAL_CSTRING("sec-websocket-key"), str, res);
  MOZ_ASSERT(!res.Failed());
  nsAutoCString binary;
  if (NS_FAILED(Base64Decode(str, binary)) || binary.Length() != 16) {
    return false;
  }

  nsresult rv;
  headers->GetFirst(NS_LITERAL_CSTRING("sec-websocket-version"), str, res);
  MOZ_ASSERT(!res.Failed());
  if (str.ToInteger(&rv) != 13 || NS_FAILED(rv)) {
    return false;
  }

  return true;
}
/**
 * Parse one line (without its CRLF) of the request head.
 *
 * Depending on mState and the line's content this
 *  - parses the request line ("METHOD /path HTTP/1.x") and creates
 *    mPendingReq,
 *  - records a header line or a folded continuation of the previous header,
 *  - or, on the empty line that ends the headers, finalizes the request and
 *    dispatches it to the listener as either a WebSocket upgrade
 *    (OnWebSocket, parsing pauses) or a normal request (OnRequest, with a
 *    pipe-backed body when content-length is non-zero).
 *
 * @param aBuffer  Start of the line; not null-terminated.
 * @param aLength  Number of bytes in the line.
 * @return NS_OK on success; NS_ERROR_UNEXPECTED (or the integer parser's
 *         error) for malformed input.
 */
nsresult
HttpServer::Connection::ConsumeLine(const char* aBuffer,
                                    size_t aLength)
{
  MOZ_ASSERT(mState == eRequestLine ||
             mState == eHeaders);

  if (MOZ_LOG_TEST(gHttpServerLog, mozilla::LogLevel::Verbose)) {
    nsCString line(aBuffer, aLength);
    LOG_V("HttpServer::Connection::ConsumeLine(%p) - \"%s\"", this, line.get());
  }

  if (mState == eRequestLine) {
    LOG_V("HttpServer::Connection::ConsumeLine(%p) - Parsing request line", this);
    // No further requests are accepted once we decided to close.
    NS_ENSURE_FALSE(mCloseAfterRequest, NS_ERROR_UNEXPECTED);

    if (aLength == 0) {
      // Ignore empty lines before the request line
      return NS_OK;
    }
    MOZ_ASSERT(!mPendingReq);

    // Process request line
    nsCWhitespaceTokenizer tokens(Substring(aBuffer, aLength));

    NS_ENSURE_TRUE(tokens.hasMoreTokens(), NS_ERROR_UNEXPECTED);
    nsDependentCSubstring method = tokens.nextToken();
    NS_ENSURE_TRUE(NS_IsValidHTTPToken(method), NS_ERROR_UNEXPECTED);
    NS_ENSURE_TRUE(tokens.hasMoreTokens(), NS_ERROR_UNEXPECTED);
    nsDependentCSubstring url = tokens.nextToken();
    // Seems like it's also allowed to pass full urls with scheme+host+port.
    // May need to support that.
    NS_ENSURE_TRUE(url.First() == '/', NS_ERROR_UNEXPECTED);
    mPendingReq = new InternalRequest(url, /* aURLFragment */ EmptyCString());
    mPendingReq->SetMethod(method);
    NS_ENSURE_TRUE(tokens.hasMoreTokens(), NS_ERROR_UNEXPECTED);
    nsDependentCSubstring version = tokens.nextToken();
    NS_ENSURE_TRUE(StringBeginsWith(version, NS_LITERAL_CSTRING("HTTP/1.")),
                   NS_ERROR_UNEXPECTED);
    nsresult rv;
    // This integer parsing is likely not strict enough.
    nsCString reqVersion;
    reqVersion = Substring(version, MOZ_ARRAY_LENGTH("HTTP/1.") - 1);
    mPendingReqVersion = reqVersion.ToInteger(&rv);
    NS_ENSURE_SUCCESS(rv, NS_ERROR_UNEXPECTED);

    // Nothing may follow the version.
    NS_ENSURE_FALSE(tokens.hasMoreTokens(), NS_ERROR_UNEXPECTED);

    LOG_V("HttpServer::Connection::ConsumeLine(%p) - Parsed request line", this);

    mState = eHeaders;

    return NS_OK;
  }

  if (aLength == 0) {
    LOG_V("HttpServer::Connection::ConsumeLine(%p) - Found end of headers", this);

    MaybeAddPendingHeader();

    ErrorResult res;
    mPendingReq->Headers()->SetGuard(HeadersGuardEnum::Immutable, res);

    // Check for WebSocket
    if (IsWebSocketRequest(mPendingReq, mPendingReqVersion)) {
      LOG_V("HttpServer::Connection::ConsumeLine(%p) - Fire OnWebSocket", this);

      // Stop parsing until the page accepts or rejects the upgrade.
      mState = ePause;
      mPendingWebSocketRequest = mPendingReq.forget();
      mPendingReqVersion = 0;

      RefPtr<HttpServerListener> listener = mServer->mListener;
      RefPtr<InternalRequest> request = mPendingWebSocketRequest;
      nsCOMPtr<nsIRunnable> event =
        NS_NewRunnableFunction([listener, request] ()
      {
        listener->OnWebSocket(request);
      });
      NS_DispatchToCurrentThread(event);

      return NS_OK;
    }

    nsAutoCString header;
    mPendingReq->Headers()->GetFirst(NS_LITERAL_CSTRING("connection"),
                                     header,
                                     res);
    MOZ_ASSERT(!res.Failed());
    // 1.0 defaults to closing connections.
    // 1.1 and higher defaults to keep-alive.
    if (ContainsToken(header, NS_LITERAL_CSTRING("close")) ||
        (mPendingReqVersion == 0 &&
         !ContainsToken(header, NS_LITERAL_CSTRING("keep-alive")))) {
      mCloseAfterRequest = true;
    }

    mPendingReq->Headers()->GetFirst(NS_LITERAL_CSTRING("content-length"),
                                     header,
                                     res);
    MOZ_ASSERT(!res.Failed());

    LOG_V("HttpServer::Connection::ConsumeLine(%p) - content-length is \"%s\"",
          this, header.get());

    if (!header.IsEmpty()) {
      nsresult rv;
      int32_t contentLength = header.ToInteger(&rv);
      NS_ENSURE_SUCCESS(rv, rv);
      // A negative length would wrap around to a huge unsigned body size and
      // make us swallow the following requests as body data.
      NS_ENSURE_TRUE(contentLength >= 0, NS_ERROR_UNEXPECTED);
      mRemainingBodySize = static_cast<uint32_t>(contentLength);
    } else {
      mRemainingBodySize = 0;
    }

    if (mRemainingBodySize) {
      LOG_V("HttpServer::Connection::ConsumeLine(%p) - Starting consume body", this);
      mState = eBody;

      // We use an unlimited buffer size here to ensure
      // that we get to the next request even if the webpage hangs on
      // to the request indefinitely without consuming the body.
      nsCOMPtr<nsIInputStream> input;
      nsCOMPtr<nsIOutputStream> output;
      nsresult rv = NS_NewPipe(getter_AddRefs(input),
                               getter_AddRefs(output),
                               0,          // Segment size
                               UINT32_MAX, // Unlimited buffer size
                               false,      // not nonBlockingInput
                               true);      // nonBlockingOutput
      NS_ENSURE_SUCCESS(rv, rv);

      mCurrentRequestBody = do_QueryInterface(output);
      mPendingReq->SetBody(input);
    } else {
      LOG_V("HttpServer::Connection::ConsumeLine(%p) - No body", this);
      mState = eRequestLine;
    }

    // Remember the request so the matching response can be sent in order.
    mPendingRequests.AppendElement(PendingRequest(mPendingReq, nullptr));

    LOG_V("HttpServer::Connection::ConsumeLine(%p) - Fire OnRequest", this);

    RefPtr<HttpServerListener> listener = mServer->mListener;
    RefPtr<InternalRequest> request = mPendingReq.forget();
    nsCOMPtr<nsIRunnable> event =
      NS_NewRunnableFunction([listener, request] ()
    {
      listener->OnRequest(request);
    });
    NS_DispatchToCurrentThread(event);

    mPendingReqVersion = 0;

    return NS_OK;
  }

  // Parse header line
  if (aBuffer[0] == ' ' || aBuffer[0] == '\t') {
    // Leading whitespace marks a continuation of the previous header.
    LOG_V("HttpServer::Connection::ConsumeLine(%p) - Add to header %s",
          this,
          mPendingHeaderName.get());

    NS_ENSURE_FALSE(mPendingHeaderName.IsEmpty(),
                    NS_ERROR_UNEXPECTED);

    // We might need to do whitespace trimming/compression here.
    mPendingHeaderValue.Append(aBuffer, aLength);
    return NS_OK;
  }

  // A new header starts; commit the one collected so far.
  MaybeAddPendingHeader();

  const char* colon = static_cast<const char*>(memchr(aBuffer, ':', aLength));
  NS_ENSURE_TRUE(colon, NS_ERROR_UNEXPECTED);

  ToLowerCase(Substring(aBuffer, colon - aBuffer), mPendingHeaderName);
  mPendingHeaderValue.Assign(colon + 1, aLength - (colon - aBuffer) - 1);

  NS_ENSURE_TRUE(NS_IsValidHTTPToken(mPendingHeaderName),
                 NS_ERROR_UNEXPECTED);

  LOG_V("HttpServer::Connection::ConsumeLine(%p) - Parsed header %s",
        this,
        mPendingHeaderName.get());

  return NS_OK;
}
/**
 * Commit the header collected in mPendingHeaderName/mPendingHeaderValue to
 * the pending request, if there is one.
 *
 * The value is trimmed of surrounding spaces and tabs.  Afterwards the
 * pending name is cleared so the header is not added twice.
 */
void
HttpServer::Connection::MaybeAddPendingHeader()
{
  if (mPendingHeaderName.IsEmpty()) {
    return;
  }

  // We might need to do more whitespace trimming/compression here.
  mPendingHeaderValue.Trim(" \t");

  ErrorResult rv;
  mPendingReq->Headers()->Append(mPendingHeaderName, mPendingHeaderValue, rv);
  // The header text comes straight from the client.  If Append() reports a
  // failure, that header is dropped; mark the error as handled so a failed
  // ErrorResult is not destroyed unhandled.
  rv.SuppressException();

  mPendingHeaderName.Truncate();
}
// Attach aResponse to the pending request aRequest, if it belongs to this
// connection, and queue every response that is now ready to go out.
//
// Responses are sent in request order: only the entry at the front of
// mPendingRequests is ever queued. A response for a later request is stored
// in its entry and sent once everything ahead of it has been answered.
//
// Returns true if aRequest was found on this connection, false otherwise.
bool
HttpServer::Connection::TryHandleResponse(InternalRequest* aRequest,
InternalResponse* aResponse)
{
bool handledResponse = false;
for (uint32_t i = 0; i < mPendingRequests.Length(); ++i) {
PendingRequest& pending = mPendingRequests[i];
if (pending.first() == aRequest) {
MOZ_ASSERT(!handledResponse);
MOZ_ASSERT(!pending.second());
pending.second() = aResponse;
if (i != 0) {
// Not at the front: an earlier request is still unanswered, so nothing
// can be sent yet.
return true;
}
handledResponse = true;
}
if (handledResponse && !pending.second()) {
// Shortcut if we've handled the response, and
// we don't have more responses to send
return true;
}
if (i == 0 && pending.second()) {
RefPtr<InternalResponse> resp = pending.second().forget();
mPendingRequests.RemoveElementAt(0);
QueueResponse(resp);
// The front entry was removed, so look at the new front next. The
// unsigned wrap-around of i is undone by the loop's ++i. `pending` is
// not used again after the removal.
--i;
}
}
return handledResponse;
}
// Build and queue the "101 Switching Protocols" response for the pending
// WebSocket request and create the TransportProvider that will later receive
// the socket.
//
// aProtocol: optional sub-protocol chosen by the page. It must be one of the
//            tokens in the request's Sec-WebSocket-Protocol header, otherwise
//            aRv is set to NS_ERROR_FAILURE and null is returned.
// aRv:       receives any failure, including a failure to compute the
//            Sec-WebSocket-Accept hash.
//
// Returns the new transport provider, or null on failure. On success the
// provider is also remembered in mWebSocketTransportProvider.
already_AddRefed<nsITransportProvider>
HttpServer::Connection::HandleAcceptWebSocket(const Optional<nsAString>& aProtocol,
ErrorResult& aRv)
{
MOZ_ASSERT(mPendingWebSocketRequest);
RefPtr<InternalResponse> response =
new InternalResponse(101, NS_LITERAL_CSTRING("Switching Protocols"));
InternalHeaders* headers = response->Headers();
headers->Set(NS_LITERAL_CSTRING("Upgrade"),
NS_LITERAL_CSTRING("websocket"),
aRv);
headers->Set(NS_LITERAL_CSTRING("Connection"),
NS_LITERAL_CSTRING("Upgrade"),
aRv);
if (aProtocol.WasPassed()) {
NS_ConvertUTF16toUTF8 protocol(aProtocol.Value());
nsAutoCString reqProtocols;
mPendingWebSocketRequest->Headers()->
GetFirst(NS_LITERAL_CSTRING("Sec-WebSocket-Protocol"), reqProtocols, aRv);
if (!ContainsToken(reqProtocols, protocol)) {
// Should throw a better error here
aRv.Throw(NS_ERROR_FAILURE);
return nullptr;
}
headers->Set(NS_LITERAL_CSTRING("Sec-WebSocket-Protocol"),
protocol, aRv);
}
// Derive Sec-WebSocket-Accept from the client's key.
nsAutoCString key, hash;
mPendingWebSocketRequest->Headers()->
GetFirst(NS_LITERAL_CSTRING("Sec-WebSocket-Key"), key, aRv);
nsresult rv = mozilla::net::CalculateWebSocketHashedSecret(key, hash);
if (NS_FAILED(rv)) {
aRv.Throw(rv);
return nullptr;
}
headers->Set(NS_LITERAL_CSTRING("Sec-WebSocket-Accept"), hash, aRv);
// Only advertise extensions if negotiation produced any.
nsAutoCString extensions, negotiatedExtensions;
mPendingWebSocketRequest->Headers()->
GetFirst(NS_LITERAL_CSTRING("Sec-WebSocket-Extensions"), extensions, aRv);
mozilla::net::ProcessServerWebSocketExtensions(extensions,
negotiatedExtensions);
if (!negotiatedExtensions.IsEmpty()) {
headers->Set(NS_LITERAL_CSTRING("Sec-WebSocket-Extensions"),
negotiatedExtensions, aRv);
}
RefPtr<TransportProvider> result = new TransportProvider();
mWebSocketTransportProvider = result;
QueueResponse(response);
return result.forget();
}
// The page answered a WebSocket upgrade request with a regular HTTP response
// instead of accepting it. Drops the pending upgrade request, resumes request
// parsing (state goes from paused back to eRequestLine and input waiting is
// re-armed) and queues aResponse for sending.
void
HttpServer::Connection::HandleWebSocketResponse(InternalResponse* aResponse)
{
MOZ_ASSERT(mPendingWebSocketRequest);
mState = eRequestLine;
mPendingWebSocketRequest = nullptr;
mInput->AsyncWait(this, 0, 0, NS_GetCurrentThread());
QueueResponse(aResponse);
}
/**
 * Serialize aResponse and append it to the connection's output queue.
 *
 * The status line and headers are written into one string buffer.  The body,
 * if any, is queued as a separate stream buffer.  Framing is chosen from the
 * body:
 *  - body with known size: content-length is set to that size,
 *  - body with unknown size: chunked transfer-encoding, content-length
 *    removed,
 *  - no body: content-length is set to 0.
 *
 * Finally OnOutputStreamReady() is called to start (or continue) writing.
 *
 * @param aResponse  The response to send; its headers are copied, not
 *                   modified.
 */
void
HttpServer::Connection::QueueResponse(InternalResponse* aResponse)
{
  bool chunked = false;

  // Work on a copy with the guard lifted so framing headers can be changed.
  RefPtr<InternalHeaders> headers = new InternalHeaders(*aResponse->Headers());
  {
    ErrorResult res;
    headers->SetGuard(HeadersGuardEnum::None, res);
  }
  nsCOMPtr<nsIInputStream> body;
  int64_t bodySize;
  aResponse->GetBody(getter_AddRefs(body), &bodySize);

  if (body && bodySize >= 0) {
    nsCString sizeStr;
    sizeStr.AppendInt(bodySize);

    LOG_V("HttpServer::Connection::QueueResponse(%p) - "
          "Setting content-length to %s",
          this, sizeStr.get());

    ErrorResult res;
    headers->Set(NS_LITERAL_CSTRING("content-length"), sizeStr, res);
  } else if (body) {
    // Use chunked transfer encoding
    LOG_V("HttpServer::Connection::QueueResponse(%p) - Chunked transfer-encoding",
          this);

    ErrorResult res;
    headers->Set(NS_LITERAL_CSTRING("transfer-encoding"),
                 NS_LITERAL_CSTRING("chunked"),
                 res);
    headers->Delete(NS_LITERAL_CSTRING("content-length"), res);
    chunked = true;

  } else {
    LOG_V("HttpServer::Connection::QueueResponse(%p) - "
          "No body - setting content-length to 0", this);

    ErrorResult res;
    headers->Set(NS_LITERAL_CSTRING("content-length"),
                 NS_LITERAL_CSTRING("0"), res);
  }

  nsCString head(NS_LITERAL_CSTRING("HTTP/1.1 "));
  head.AppendInt(aResponse->GetStatus());
  // XXX is the statustext security checked?
  head.Append(NS_LITERAL_CSTRING(" ") +
              aResponse->GetStatusText() +
              NS_LITERAL_CSTRING("\r\n"));

  AutoTArray<InternalHeaders::Entry, 16> entries;
  headers->GetEntries(entries);

  // Iterate by reference: each entry holds two strings that would otherwise
  // be copied for every header.
  for (const auto& header : entries) {
    head.Append(header.mName +
                NS_LITERAL_CSTRING(": ") +
                header.mValue +
                NS_LITERAL_CSTRING("\r\n"));
  }

  // Blank line terminating the header section.
  head.Append(NS_LITERAL_CSTRING("\r\n"));

  mOutputBuffers.AppendElement()->mString = head;
  if (body) {
    OutputBuffer* bodyBuffer = mOutputBuffers.AppendElement();
    bodyBuffer->mStream = body;
    bodyBuffer->mChunked = chunked;
  }

  OnOutputStreamReady(mOutput);
}
namespace {
// Promise through which a StreamCopier reports its outcome. It is always
// resolved with the final nsresult (success or failure); the bool reject
// type is never used by the code in this file.
typedef MozPromise<nsresult, bool, false> StreamCopyPromise;
/**
 * Copies an input stream into an async output stream on the stream transport
 * service's event target, optionally wrapping the data in HTTP chunked
 * transfer-encoding.
 *
 * Use the static Copy() to start a copy.  The returned promise is resolved
 * with the final status; it is never rejected.
 */
class StreamCopier final : public nsIOutputStreamCallback
                         , public nsIInputStreamCallback
                         , public nsIRunnable
{
public:
  /**
   * Start copying aSource into aSink.
   *
   * @param aSource   Stream to read from.
   * @param aSink     Stream to write to.
   * @param aChunked  Whether to emit chunked transfer-encoding framing.
   * @return Promise resolved with NS_OK when everything was copied, or with
   *         the failure code.  If the copy could not even be scheduled the
   *         promise is resolved right away with the scheduling error.
   */
  static RefPtr<StreamCopyPromise>
    Copy(nsIInputStream* aSource, nsIAsyncOutputStream* aSink,
         bool aChunked)
  {
    RefPtr<StreamCopier> copier = new StreamCopier(aSource, aSink, aChunked);

    RefPtr<StreamCopyPromise> p = copier->mPromise.Ensure(__func__);

    nsresult rv;
    if (copier->mTarget) {
      rv = copier->mTarget->Dispatch(copier, NS_DISPATCH_NORMAL);
    } else {
      // The stream transport service could not be obtained; report that
      // instead of dereferencing a null event target.
      rv = NS_ERROR_NOT_AVAILABLE;
    }
    if (NS_FAILED(rv)) {
      copier->mPromise.Resolve(rv, __func__);
    }

    return p;
  }

  NS_DECL_THREADSAFE_ISUPPORTS
  NS_DECL_NSIINPUTSTREAMCALLBACK
  NS_DECL_NSIOUTPUTSTREAMCALLBACK
  NS_DECL_NSIRUNNABLE

private:
  StreamCopier(nsIInputStream* aSource, nsIAsyncOutputStream* aSink,
               bool aChunked)
    : mSource(aSource)
    , mAsyncSource(do_QueryInterface(aSource))
    , mSink(aSink)
    , mTarget(do_GetService(NS_STREAMTRANSPORTSERVICE_CONTRACTID))
    , mChunkRemaining(0)
    , mChunked(aChunked)
    , mAddedFinalSeparator(false)
    , mFirstChunk(aChunked)
  {
  }
  ~StreamCopier() {}

  // WriteSegments() callback; forwards to FillOutputBuffer() on the copier
  // stored in the closure.
  static nsresult FillOutputBufferHelper(nsIOutputStream* aOutStr,
                                         void* aClosure,
                                         char* aBuffer,
                                         uint32_t aOffset,
                                         uint32_t aCount,
                                         uint32_t* aCountRead);
  // Produces the next bytes to write (chunk framing and/or source data).
  nsresult FillOutputBuffer(char* aBuffer,
                            uint32_t aCount,
                            uint32_t* aCountRead);

  nsCOMPtr<nsIInputStream> mSource;
  // mSource as an async stream; null if the source is not async.
  nsCOMPtr<nsIAsyncInputStream> mAsyncSource;
  nsCOMPtr<nsIAsyncOutputStream> mSink;
  MozPromiseHolder<StreamCopyPromise> mPromise;
  nsCOMPtr<nsIEventTarget> mTarget; // XXX we should cache this somewhere
  // Bytes of source data still to be written for the current chunk.
  uint32_t mChunkRemaining;
  // Chunk framing text that still has to be written before more data.
  nsCString mSeparator;
  bool mChunked;
  // Set once the terminating zero-length chunk has been generated.
  bool mAddedFinalSeparator;
  // True until the first chunk header has been generated.
  bool mFirstChunk;
};
// StreamCopier is the callback for both of its streams and is itself the
// runnable dispatched to do the copying.
NS_IMPL_ISUPPORTS(StreamCopier,
nsIOutputStreamCallback,
nsIInputStreamCallback,
nsIRunnable)
// Closure passed through WriteSegments() to FillOutputBufferHelper().
struct WriteState
{
// The copier whose FillOutputBuffer() should be called.
StreamCopier* copier;
// Result of the last FillOutputBuffer() call, so Run() can tell source
// errors apart from sink errors.
nsresult sourceRv;
};
// Static trampoline for WriteSegments(): unpacks the WriteState closure and
// calls the non-static FillOutputBuffer(), where member variables are easier
// to use.  The result is both recorded in the closure and returned.
nsresult
StreamCopier::FillOutputBufferHelper(nsIOutputStream* aOutStr,
                                     void* aClosure,
                                     char* aBuffer,
                                     uint32_t aOffset,
                                     uint32_t aCount,
                                     uint32_t* aCountRead)
{
  auto* state = static_cast<WriteState*>(aClosure);
  state->sourceRv =
    state->copier->FillOutputBuffer(aBuffer, aCount, aCountRead);
  return state->sourceRv;
}
// ReadSegments() callback used purely as a probe: it is only invoked when the
// stream has data, so being called at all proves the stream is not at EOF.
// It sets the bool behind aClosure to true, consumes nothing and aborts the
// read.
nsresult
CheckForEOF(nsIInputStream* aIn,
void* aClosure,
const char* aBuffer,
uint32_t aToOffset,
uint32_t aCount,
uint32_t* aWriteCount)
{
*static_cast<bool*>(aClosure) = true;
*aWriteCount = 0;
return NS_BINDING_ABORTED;
}
// Produce the next bytes to write to the sink.
//
// aBuffer:    destination provided by the sink.
// aCount:     capacity of aBuffer.
// aCountRead: receives the number of bytes placed in aBuffer.
//
// Without chunking this is a plain read from the source. With chunking, a
// chunk header ("<hex size>\r\n", preceded by "\r\n" for every chunk but the
// first) is generated from the amount of data currently available, followed
// by that much data; at EOF a terminating "0\r\n\r\n" is generated.
//
// Returns NS_BASE_STREAM_CLOSED once nothing more can be produced, or the
// error from the source (including NS_BASE_STREAM_WOULD_BLOCK).
nsresult
StreamCopier::FillOutputBuffer(char* aBuffer,
uint32_t aCount,
uint32_t* aCountRead)
{
nsresult rv = NS_OK;
// Start a new chunk when chunking is on and nothing is pending: no framing
// text left to write, no data left in the current chunk, and the final
// chunk not yet generated. The loop only repeats via the `continue` below.
while (mChunked && mSeparator.IsEmpty() && !mChunkRemaining &&
!mAddedFinalSeparator) {
uint64_t avail;
rv = mSource->Available(&avail);
if (rv == NS_BASE_STREAM_CLOSED) {
avail = 0;
rv = NS_OK;
}
NS_ENSURE_SUCCESS(rv, rv);
mChunkRemaining = avail > UINT32_MAX ? UINT32_MAX :
static_cast<uint32_t>(avail);
if (!mChunkRemaining) {
// Either it's a non-blocking stream without any data
// currently available, or we're at EOF. Sadly there's no way
// to tell other than to read from the stream.
bool hadData = false;
uint32_t numRead;
rv = mSource->ReadSegments(CheckForEOF, &hadData, 1, &numRead);
if (rv == NS_BASE_STREAM_CLOSED) {
// NOTE(review): avail is not read again after this point, so this
// assignment has no effect.
avail = 0;
rv = NS_OK;
}
NS_ENSURE_SUCCESS(rv, rv);
MOZ_ASSERT(numRead == 0);
if (hadData) {
// The source received data between the call to Available and the
// call to ReadSegments. Restart with a new call to Available
continue;
}
// We're at EOF, write a separator with 0
mAddedFinalSeparator = true;
}
if (mFirstChunk) {
mFirstChunk = false;
MOZ_ASSERT(mSeparator.IsEmpty());
} else {
// For all chunks except the first, add the newline at the end
// of the previous chunk of data
mSeparator.AssignLiteral("\r\n");
}
// Chunk size in hexadecimal, followed by CRLF.
mSeparator.AppendInt(mChunkRemaining, 16);
mSeparator.AppendLiteral("\r\n");
if (mAddedFinalSeparator) {
// The zero-length chunk is followed by an empty line.
mSeparator.AppendLiteral("\r\n");
}
break;
}
// If we're doing chunked encoding, we should either have a chunk size,
// or we should have reached the end of the input stream.
MOZ_ASSERT_IF(mChunked, mChunkRemaining || mAddedFinalSeparator);
// We should only have a separator if we're doing chunked encoding
MOZ_ASSERT_IF(!mSeparator.IsEmpty(), mChunked);
if (!mSeparator.IsEmpty()) {
// Framing text goes out first; it may take several calls if aCount is
// small.
*aCountRead = std::min(mSeparator.Length(), aCount);
memcpy(aBuffer, mSeparator.BeginReading(), *aCountRead);
mSeparator.Cut(0, *aCountRead);
rv = NS_OK;
} else if (mChunked) {
*aCountRead = 0;
if (mChunkRemaining) {
// Never read past the size announced in the chunk header.
rv = mSource->Read(aBuffer,
std::min(aCount, mChunkRemaining),
aCountRead);
mChunkRemaining -= *aCountRead;
}
} else {
rv = mSource->Read(aBuffer, aCount, aCountRead);
}
// Producing zero bytes without an error means there is nothing left.
if (NS_SUCCEEDED(rv) && *aCountRead == 0) {
rv = NS_BASE_STREAM_CLOSED;
}
return rv;
}
/**
 * Copy as much as possible from the source to the sink.
 *
 * Loops over WriteSegments() until one of the following happens:
 *  - the sink would block: wait for it to become writable,
 *  - the sink failed: resolve the promise with the error,
 *  - the source would block: wait for the source to become readable, and for
 *    the sink only to learn about it closing,
 *  - the source is exhausted: detach from both streams, close the source and
 *    resolve the promise with NS_OK,
 *  - the source failed: resolve the promise with the error.
 *
 * @return Always NS_OK; the outcome is reported through the promise.
 */
NS_IMETHODIMP
StreamCopier::Run()
{
  nsresult rv;
  while (1) {
    WriteState state = { this, NS_OK };
    uint32_t written;
    rv = mSink->WriteSegments(FillOutputBufferHelper, &state,
                              mozilla::net::nsIOService::gDefaultSegmentSize,
                              &written);
    MOZ_ASSERT(NS_SUCCEEDED(rv) || NS_SUCCEEDED(state.sourceRv));
    if (rv == NS_BASE_STREAM_WOULD_BLOCK) {
      mSink->AsyncWait(this, 0, 0, mTarget);
      return NS_OK;
    }
    if (NS_FAILED(rv)) {
      mPromise.Resolve(rv, __func__);
      return NS_OK;
    }

    if (state.sourceRv == NS_BASE_STREAM_WOULD_BLOCK) {
      MOZ_ASSERT(mAsyncSource);
      mAsyncSource->AsyncWait(this, 0, 0, mTarget);
      // mSink is an output stream, so use the output stream interface's flag
      // (the original named the input stream interface's constant here).
      mSink->AsyncWait(this, nsIAsyncOutputStream::WAIT_CLOSURE_ONLY,
                       0, mTarget);

      return NS_OK;
    }
    if (state.sourceRv == NS_BASE_STREAM_CLOSED) {
      // We're done!
      // No longer interested in callbacks about either stream closing
      mSink->AsyncWait(nullptr, 0, 0, nullptr);
      if (mAsyncSource) {
        mAsyncSource->AsyncWait(nullptr, 0, 0, nullptr);
      }

      mSource->Close();
      mSource = nullptr;
      mAsyncSource = nullptr;
      mSink = nullptr;

      mPromise.Resolve(NS_OK, __func__);

      return NS_OK;
    }

    if (NS_FAILED(state.sourceRv)) {
      mPromise.Resolve(state.sourceRv, __func__);
      return NS_OK;
    }
  }

  MOZ_ASSUME_UNREACHABLE_MARKER();
}
// The source became readable (or closed).  Continues copying unless the copy
// already finished, in which case all stream members have been released and
// the callback is ignored.
NS_IMETHODIMP
StreamCopier::OnInputStreamReady(nsIAsyncInputStream* aStream)
{
  MOZ_ASSERT(aStream == mAsyncSource ||
             (!mSource && !mAsyncSource && !mSink));
  if (!mSource) {
    return NS_OK;
  }
  return Run();
}
// The sink became writable (or closed).  Continues copying unless the copy
// already finished, in which case all stream members have been released and
// the callback is ignored.
NS_IMETHODIMP
StreamCopier::OnOutputStreamReady(nsIAsyncOutputStream* aStream)
{
  MOZ_ASSERT(aStream == mSink ||
             (!mSource && !mAsyncSource && !mSink));
  if (!mSource) {
    return NS_OK;
  }
  return Run();
}
} // namespace
/**
 * nsIOutputStreamCallback, also called directly to pump the output queue.
 *
 * Writes queued buffers to the socket in order.  String buffers are written
 * inline; if the socket would block the connection waits for it to become
 * writable again.  Stream buffers are handed to a StreamCopier and this
 * function is re-entered when the copy completes.
 *
 * Once the queue is drained and no requests are awaiting a response, the
 * connection is either closed (mCloseAfterRequest) or, after an accepted
 * WebSocket upgrade, its transport and streams are handed over to the
 * transport provider.
 *
 * @return NS_OK, or the result of AsyncWait() when the socket would block.
 */
NS_IMETHODIMP
HttpServer::Connection::OnOutputStreamReady(nsIAsyncOutputStream* aStream)
{
  MOZ_ASSERT(aStream == mOutput || !mOutput);
  if (!mOutput) {
    return NS_OK;
  }

  nsresult rv;

  while (!mOutputBuffers.IsEmpty()) {
    if (!mOutputBuffers[0].mStream) {
      nsCString& buffer = mOutputBuffers[0].mString;
      while (!buffer.IsEmpty()) {
        uint32_t written = 0;
        rv = mOutput->Write(buffer.BeginReading(),
                            buffer.Length(),
                            &written);

        buffer.Cut(0, written);

        if (rv == NS_BASE_STREAM_WOULD_BLOCK) {
          return mOutput->AsyncWait(this, 0, 0, NS_GetCurrentThread());
        }

        if (NS_FAILED(rv)) {
          Close();
          return NS_OK;
        }
      }
      mOutputBuffers.RemoveElementAt(0);
    } else {
      if (mOutputCopy) {
        // we're already copying the stream
        return NS_OK;
      }

      mOutputCopy =
        StreamCopier::Copy(mOutputBuffers[0].mStream,
                           mOutput,
                           mOutputBuffers[0].mChunked);

      RefPtr<Connection> self = this;

      mOutputCopy->
        Then(mServer->mAbstractMainThread,
             __func__,
             [self, this] (nsresult aStatus) {
               LOG_V("HttpServer::Connection::OnOutputStreamReady(%p) - "
                     "Sent body. Status 0x%" PRIx32,
                     this, static_cast<uint32_t>(aStatus));

               mOutputCopy = nullptr;

               if (!mOutput || mOutputBuffers.IsEmpty()) {
                 // Close() ran while the copy was in flight and already
                 // cleared the queue; there is nothing left to remove or
                 // send.
                 return;
               }

               MOZ_ASSERT(mOutputBuffers[0].mStream);
               mOutputBuffers.RemoveElementAt(0);
               OnOutputStreamReady(mOutput);
             },
             [] (bool) { MOZ_ASSERT_UNREACHABLE("Reject unexpected"); });
    }
  }

  if (mPendingRequests.IsEmpty()) {
    if (mCloseAfterRequest) {
      LOG_V("HttpServer::Connection::OnOutputStreamReady(%p) - Closing channel",
            this);
      Close();
    } else if (mWebSocketTransportProvider) {
      // The 101 response is out; stop listening on the streams and hand the
      // socket over to the WebSocket.
      mInput->AsyncWait(nullptr, 0, 0, nullptr);
      mOutput->AsyncWait(nullptr, 0, 0, nullptr);

      mWebSocketTransportProvider->SetTransport(mTransport, mInput, mOutput);
      mTransport = nullptr;
      mInput = nullptr;
      mOutput = nullptr;
      mWebSocketTransportProvider = nullptr;
    }
  }

  return NS_OK;
}
// Tear the connection down. Safe to call more than once: if the transport is
// already gone (closed earlier, or handed over to a WebSocket) this does
// nothing.
//
// Closes the transport with NS_BINDING_ABORTED, closes and releases both
// streams, and discards buffered input, queued output and pending requests.
// It does not remove the connection from the server's connection list.
void
HttpServer::Connection::Close()
{
if (!mTransport) {
MOZ_ASSERT(!mOutput && !mInput);
return;
}
mTransport->Close(NS_BINDING_ABORTED);
if (mInput) {
mInput->Close();
mInput = nullptr;
}
if (mOutput) {
mOutput->Close();
mOutput = nullptr;
}
mTransport = nullptr;
mInputBuffer.Truncate();
mOutputBuffers.Clear();
mPendingRequests.Clear();
}
} // namespace dom
} // namespace mozilla