gecko-dev/netwerk/dns/TRR.cpp

1112 строки
32 KiB
C++

/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim:set ts=4 sw=2 sts=2 et cin: */
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "DNS.h"
#include "nsCharSeparatedTokenizer.h"
#include "nsContentUtils.h"
#include "nsHostResolver.h"
#include "nsIHttpChannel.h"
#include "nsIHttpChannelInternal.h"
#include "nsIIOService.h"
#include "nsIInputStream.h"
#include "nsISupportsBase.h"
#include "nsISupportsUtils.h"
#include "nsIUploadChannel2.h"
#include "nsNetUtil.h"
#include "nsStringStream.h"
#include "nsThreadUtils.h"
#include "nsURLHelper.h"
#include "TRR.h"
#include "TRRService.h"
#include "mozilla/Base64.h"
#include "mozilla/DebugOnly.h"
#include "mozilla/Logging.h"
#include "mozilla/Preferences.h"
#include "mozilla/Telemetry.h"
#include "mozilla/TimeStamp.h"
#include "mozilla/Tokenizer.h"
namespace mozilla {
namespace net {
// Route this file's LOG() output to the gHostResolverLog log module at Debug
// level; any LOG macro left over from an earlier header is dropped first.
#undef LOG
extern mozilla::LazyLogModule gHostResolverLog;
#define LOG(args) MOZ_LOG(gHostResolverLog, mozilla::LogLevel::Debug, args)
// True when Debug-level logging is enabled for gHostResolverLog; lets callers
// skip work that only feeds a log line.
#define LOG_ENABLED() \
MOZ_LOG_TEST(mozilla::net::gHostResolverLog, mozilla::LogLevel::Debug)
// XPCOM interface table for TRR.
NS_IMPL_ISUPPORTS(TRR, nsIHttpPushListener, nsIInterfaceRequestor,
nsIStreamListener, nsIRunnable)
// DNS CLASS value "IN" (the Internet): written into queries by DohEncode() and
// required of every answer record by DohDecode().
const uint8_t kDNS_CLASS_IN = 1;
// Timer callback for the request timeout armed in SendHTTPRequest().
// Any timer other than mTimeout is a programming error and crashes.
NS_IMETHODIMP
TRR::Notify(nsITimer *aTimer) {
  if (aTimer != mTimeout) {
    MOZ_CRASH("Unknown timer");
  }

  // The timeout fired: drop our reference to it and abort the request.
  mTimeout = nullptr;
  Cancel();
  return NS_OK;
}
// Convert the pending request (mHost / mType) into a DOH 'body', i.e. a DNS
// query in wire format.
//
// aBody        receives the encoded query; previous content is discarded.
// aDisableECS  when true, an EDNS(0) OPT record carrying an EDNS Client
//              Subnet option with a zero-length source prefix is appended
//              and ARCOUNT is set to 1.
//
// Returns NS_ERROR_ILLEGAL_VALUE when a label of mHost is longer than 63
// octets (aBody is then left partially written), NS_OK otherwise.
nsresult TRR::DohEncode(nsCString &aBody, bool aDisableECS) {
  aBody.Truncate();
  // Header
  aBody += '\0';
  aBody += '\0';  // 16 bit id
  aBody += 0x01;  // |QR|   Opcode  |AA|TC|RD| Set the RD bit
  aBody += '\0';  // |RA|   Z    |   RCODE   |
  aBody += '\0';
  aBody += 1;  // QDCOUNT (number of entries in the question section)
  aBody += '\0';
  aBody += '\0';  // ANCOUNT
  aBody += '\0';
  aBody += '\0';  // NSCOUNT
  aBody += '\0';  // ARCOUNT
  aBody += aDisableECS ? 1 : '\0';  // ARCOUNT low byte for EDNS(0)

  // Question
  // The input host name should be converted to a sequence of labels, where
  // each label consists of a length octet followed by that number of
  // octets.  The domain name terminates with the zero length octet for the
  // null label of the root.
  // Followed by 16 bit QTYPE and 16 bit QCLASS
  int32_t index = 0;
  int32_t offset = 0;
  do {
    bool dotFound = false;
    int32_t labelLength;
    index = mHost.FindChar('.', offset);
    if (kNotFound != index) {
      dotFound = true;
      labelLength = index - offset;
    } else {
      labelLength = mHost.Length() - offset;
    }
    if (labelLength > 63) {
      // too long label!
      return NS_ERROR_ILLEGAL_VALUE;
    }
    // Skip empty labels (empty host, leading/doubled/trailing dot). A zero
    // length octet is the name terminator, so emitting one here would either
    // end the name early or produce a doubled terminator and corrupt the
    // QTYPE/QCLASS that follow.
    if (labelLength > 0) {
      aBody += static_cast<unsigned char>(labelLength);
      nsDependentCSubstring label = Substring(mHost, offset, labelLength);
      aBody.Append(label);
    }
    if (!dotFound) {
      aBody += '\0';  // terminate with a final zero
      break;
    }
    offset += labelLength + 1;  // move over label and dot
  } while (true);

  aBody += '\0';  // upper 8 bit TYPE
  aBody += static_cast<uint8_t>(mType);
  aBody += '\0';           // upper 8 bit CLASS
  aBody += kDNS_CLASS_IN;  // IN - "the Internet"

  if (aDisableECS) {
    // EDNS(0) is RFC 6891, ECS is RFC 7871
    aBody += '\0';  // NAME       | domain name  | MUST be 0 (root domain) |
    aBody += '\0';
    aBody += 41;  // TYPE       | u_int16_t    | OPT (41)                     |
    aBody += 16;  // CLASS      | u_int16_t    | requestor's UDP payload size |
    aBody +=
        '\0';  // advertise 4K (high-byte: 16 | low-byte: 0), ignored by DoH
    aBody += '\0';  // TTL        | u_int32_t    | extended RCODE and flags |
    aBody += '\0';
    aBody += '\0';
    aBody += '\0';

    aBody += '\0';  // upper 8 bit RDLEN
    aBody += 8;     // RDLEN      | u_int16_t    | length of all RDATA |

    // RDATA      | octet stream | {attribute,value} pairs |
    // The RDATA is just the ECS option setting zero subnet prefix

    aBody += '\0';  // upper 8 bit OPTION-CODE ECS
    aBody += 8;     // OPTION-CODE, 2 octets, for ECS is 8

    aBody += '\0';  // upper 8 bit OPTION-LENGTH
    aBody += 4;  // OPTION-LENGTH, 2 octets, contains the length of the payload
                 // after OPTION-LENGTH
    aBody += '\0';  // upper 8 bit FAMILY. IANA Address Family Numbers registry,
                    // not the AF_* constants!
    aBody += 1;     // FAMILY (Ipv4), 2 octets

    aBody += '\0';  // SOURCE PREFIX-LENGTH      |     SCOPE PREFIX-LENGTH |
    aBody += '\0';

    // ADDRESS, minimum number of octets == nothing because zero bits
  }
  return NS_OK;
}
// nsIRunnable entry point (main thread only): issue the DoH request, and
// report failure to the resolver when there is no TRR service or the request
// could not be sent. Always returns NS_OK.
NS_IMETHODIMP
TRR::Run() {
  MOZ_ASSERT(NS_IsMainThread());
  const bool requestSent =
      (gTRRService != nullptr) && NS_SUCCEEDED(SendHTTPRequest());
  if (!requestSent) {
    FailData(NS_ERROR_FAILURE);
    // The dtor will now be run
  }
  return NS_OK;
}
// Build and asynchronously open the HTTP channel carrying the DoH query for
// mHost / mType, then arm the request timeout. Main thread only.
//
// Returns:
//   NS_OK                  the channel was opened
//   NS_ERROR_FAILURE       mType is not A, AAAA, NS or TXT
//   NS_ERROR_UNKNOWN_HOST  an A/AAAA host is blacklisted for TRR (no request
//                          is issued)
//   NS_ERROR_UNEXPECTED    the channel lacks a required interface or
//                          AsyncOpen() failed
//   any other failure code returned by a setup step is passed through
nsresult TRR::SendHTTPRequest() {
// This is essentially the "run" method - created from nsHostResolver
MOZ_ASSERT(NS_IsMainThread(), "wrong thread");
if ((mType != TRRTYPE_A) && (mType != TRRTYPE_AAAA) &&
(mType != TRRTYPE_NS) && (mType != TRRTYPE_TXT)) {
// limit the calling interface because nsHostResolver has explicit slots for
// these types
return NS_ERROR_FAILURE;
}
if ((mType == TRRTYPE_A) || (mType == TRRTYPE_AAAA)) {
// let NS resolves skip the blacklist check
MOZ_ASSERT(mRec);
if (gTRRService->IsTRRBlacklisted(mHost, mOriginSuffix, mPB, true)) {
if (mType == TRRTYPE_A) {
// count only blacklist for A records to avoid double counts
Telemetry::Accumulate(Telemetry::DNS_TRR_BLACKLISTED, true);
}
// not really an error but no TRR is issued
return NS_ERROR_UNKNOWN_HOST;
} else {
if (mType == TRRTYPE_A) {
Telemetry::Accumulate(Telemetry::DNS_TRR_BLACKLISTED, false);
}
}
}
nsresult rv;
nsCOMPtr<nsIIOService> ios(do_GetIOService(&rv));
NS_ENSURE_SUCCESS(rv, rv);
// GET puts the encoded query in the URI; otherwise it is uploaded as the
// request body further down.
bool useGet = gTRRService->UseGET();
nsAutoCString body;
nsCOMPtr<nsIURI> dnsURI;
bool disableECS = gTRRService->DisableECS();
LOG(("TRR::SendHTTPRequest resolve %s type %u\n", mHost.get(), mType));
if (useGet) {
nsAutoCString tmp;
rv = DohEncode(tmp, disableECS);
NS_ENSURE_SUCCESS(rv, rv);
/* For GET requests, the outgoing packet needs to be Base64url-encoded and
then appended to the end of the URI. */
rv = Base64URLEncode(tmp.Length(),
reinterpret_cast<const unsigned char *>(tmp.get()),
Base64URLEncodePaddingPolicy::Omit, body);
NS_ENSURE_SUCCESS(rv, rv);
nsAutoCString uri;
gTRRService->GetURI(uri);
uri.Append(NS_LITERAL_CSTRING("?dns="));
uri.Append(body);
LOG(("TRR::SendHTTPRequest GET dns=%s\n", body.get()));
rv = NS_NewURI(getter_AddRefs(dnsURI), uri);
} else {
// body keeps the raw wire-format query for the upload stream below
rv = DohEncode(body, disableECS);
NS_ENSURE_SUCCESS(rv, rv);
nsAutoCString uri;
gTRRService->GetURI(uri);
rv = NS_NewURI(getter_AddRefs(dnsURI), uri);
}
if (NS_FAILED(rv)) {
LOG(("TRR:SendHTTPRequest: NewURI failed!\n"));
return rv;
}
// The TRR object itself is passed as the channel's notification callbacks
// (see GetInterface()); caching is inhibited when mPB is set.
rv = NS_NewChannel(
getter_AddRefs(mChannel), dnsURI, nsContentUtils::GetSystemPrincipal(),
nsILoadInfo::SEC_ALLOW_CROSS_ORIGIN_DATA_IS_NULL,
nsIContentPolicy::TYPE_OTHER,
nullptr, // nsICookieSettings
nullptr, // PerformanceStorage
nullptr, // aLoadGroup
this,
nsIRequest::LOAD_ANONYMOUS | (mPB ? nsIRequest::INHIBIT_CACHING : 0),
ios);
if (NS_FAILED(rv)) {
LOG(("TRR:SendHTTPRequest: NewChannel failed!\n"));
return rv;
}
nsCOMPtr<nsIHttpChannel> httpChannel = do_QueryInterface(mChannel);
if (!httpChannel) {
return NS_ERROR_UNEXPECTED;
}
rv = httpChannel->SetRequestHeader(
NS_LITERAL_CSTRING("Accept"),
NS_LITERAL_CSTRING("application/dns-message"), false);
NS_ENSURE_SUCCESS(rv, rv);
// Optional credentials configured on the TRR service
nsAutoCString cred;
gTRRService->GetCredentials(cred);
if (!cred.IsEmpty()) {
rv = httpChannel->SetRequestHeader(NS_LITERAL_CSTRING("Authorization"),
cred, false);
NS_ENSURE_SUCCESS(rv, rv);
}
nsCOMPtr<nsIHttpChannelInternal> internalChannel =
do_QueryInterface(mChannel);
if (!internalChannel) {
return NS_ERROR_UNEXPECTED;
}
// setting a small stream window means the h2 stack won't pipeline a window
// update with each HEADERS or reply to a DATA with a WINDOW UPDATE
rv = internalChannel->SetInitialRwin(127 * 1024);
NS_ENSURE_SUCCESS(rv, rv);
rv = internalChannel->SetTrr(true);
NS_ENSURE_SUCCESS(rv, rv);
// Snapshot the setting now; DohDecode() passes it to mDNS.Add() when the
// response is parsed.
mAllowRFC1918 = gTRRService->AllowRFC1918();
if (useGet) {
rv = httpChannel->SetRequestMethod(NS_LITERAL_CSTRING("GET"));
NS_ENSURE_SUCCESS(rv, rv);
} else {
rv = httpChannel->SetRequestHeader(NS_LITERAL_CSTRING("Cache-Control"),
NS_LITERAL_CSTRING("no-store"), false);
NS_ENSURE_SUCCESS(rv, rv);
nsCOMPtr<nsIUploadChannel2> uploadChannel = do_QueryInterface(httpChannel);
if (!uploadChannel) {
return NS_ERROR_UNEXPECTED;
}
// Take the length before body is moved into the stream
uint32_t streamLength = body.Length();
nsCOMPtr<nsIInputStream> uploadStream;
rv =
NS_NewCStringInputStream(getter_AddRefs(uploadStream), std::move(body));
NS_ENSURE_SUCCESS(rv, rv);
rv = uploadChannel->ExplicitSetUploadStream(
uploadStream, NS_LITERAL_CSTRING("application/dns-message"),
streamLength, NS_LITERAL_CSTRING("POST"), false);
NS_ENSURE_SUCCESS(rv, rv);
}
// set the *default* response content type
if (NS_FAILED(httpChannel->SetContentType(
NS_LITERAL_CSTRING("application/dns-message")))) {
LOG(("TRR::SendHTTPRequest: couldn't set content-type!\n"));
}
// On success arm a one-shot timeout; Notify() cancels the request when it
// fires. The timer-creation result is not checked.
if (NS_SUCCEEDED(httpChannel->AsyncOpen(this))) {
NS_NewTimerWithCallback(getter_AddRefs(mTimeout), this,
gTRRService->GetRequestTimeout(),
nsITimer::TYPE_ONE_SHOT);
return NS_OK;
}
// AsyncOpen failed: forget the channel so Cancel() has nothing to cancel
mChannel = nullptr;
return NS_ERROR_UNEXPECTED;
}
// nsIInterfaceRequestor: the only interface handed out is
// nsIHttpPushListener (this object, with a reference added for the caller).
// Every other IID yields NS_ERROR_NO_INTERFACE.
NS_IMETHODIMP
TRR::GetInterface(const nsIID &iid, void **result) {
  if (iid.Equals(NS_GET_IID(nsIHttpPushListener))) {
    nsCOMPtr<nsIHttpPushListener> self(this);
    *result = self.forget().take();
    return NS_OK;
  }
  return NS_ERROR_NO_INTERFACE;
}
// Recover host name and record type from the query string of a DoH GET URI.
//
// query  URI query string; its "dns=" parameter holds the base64url-encoded
//        DNS query message.
// host   receives the question name as dot-separated labels.
// type   receives the 16 bit QTYPE of the first question.
//
// Returns NS_ERROR_ILLEGAL_VALUE when there is no "dns=" parameter,
// NS_ERROR_FAILURE for a bad header (too short, non-zero QR/opcode bits, no
// question), NS_ERROR_UNEXPECTED when the question runs past the end of the
// data, or the base64 decoder's error code.
nsresult TRR::DohDecodeQuery(const nsCString &query, nsCString &host,
                             enum TrrType &type) {
  LOG(("TRR::DohDecodeQuery %s!\n", query.get()));

  // extract "dns=" from the query string
  nsAutoCString encoded;
  bool haveDnsParam = false;
  nsCCharSeparatedTokenizer params(query, '&');
  while (!haveDnsParam && params.hasMoreTokens()) {
    const nsACString &param = params.nextToken();
    nsAutoCString key(Substring(param, 0, 4));
    if (key.Equals("dns=")) {
      encoded = Substring(param, 4, -1);
      haveDnsParam = true;
    }
  }
  if (!haveDnsParam) {
    LOG(("TRR::DohDecodeQuery no dns= in pushed URI query string\n"));
    return NS_ERROR_ILLEGAL_VALUE;
  }

  FallibleTArray<uint8_t> packet;
  nsresult rv =
      Base64URLDecode(encoded, Base64URLDecodePaddingPolicy::Ignore, packet);
  NS_ENSURE_SUCCESS(rv, rv);

  const uint32_t packetLen = packet.Length();
  if (packetLen < 12) {
    // shorter than a DNS header
    return NS_ERROR_FAILURE;
  }
  // check the query bit and the opcode
  if ((packet[2] & 0xf8) != 0) {
    return NS_ERROR_FAILURE;
  }
  const uint32_t questions = (packet[4] << 8) + packet[5];
  if (!questions) {
    return NS_ERROR_FAILURE;
  }

  // Walk the labels of the first question, which starts after the header.
  uint32_t pos = 12;
  host.Truncate();
  for (;;) {
    if (packetLen < (pos + 1)) {
      return NS_ERROR_UNEXPECTED;
    }
    const uint32_t labelLen = packet[pos];
    if (labelLen) {
      if (host.Length()) {
        host.Append(".");
      }
      if (packetLen < (pos + 1 + labelLen)) {
        return NS_ERROR_UNEXPECTED;
      }
      host.Append((const char *)(&packet[0]) + pos + 1, labelLen);
    }
    pos += 1 + labelLen;  // skip length byte + label
    if (!labelLen) {
      break;  // zero length label ends the name
    }
  }
  LOG(("TRR::DohDecodeQuery host %s\n", host.get()));

  // 16 bit QTYPE, big endian
  if (packetLen < (pos + 2)) {
    return NS_ERROR_UNEXPECTED;
  }
  const uint16_t qtype =
      static_cast<uint16_t>((packet[pos] << 8) + packet[pos + 1]);
  type = (enum TrrType)qtype;
  LOG(("TRR::DohDecodeQuery type %d\n", (int)type));
  return NS_OK;
}
// Adopt a pushed HTTP channel as the response for this (new) TRR object.
//
// The host and type are taken from the pushed URI's query string; a host
// record is obtained from the resolver using the flags/af/pb/originSuffix of
// pushedRec, a TRR lookup is registered for it and the pushed channel is
// opened with this object as listener.
//
// Returns NS_ERROR_UNEXPECTED when there is no resolver, the query cannot be
// decoded, the host is an IP literal or the type is not A/AAAA/TXT; otherwise
// the first failing step's error code, or NS_OK.
nsresult TRR::ReceivePush(nsIHttpChannel *pushed, nsHostRecord *pushedRec) {
  if (!mHostResolver) {
    return NS_ERROR_UNEXPECTED;
  }

  LOG(("TRR::ReceivePush: PUSH incoming!\n"));

  nsCOMPtr<nsIURI> pushedUri;
  pushed->GetURI(getter_AddRefs(pushedUri));
  nsAutoCString query;
  if (pushedUri) {
    pushedUri->GetQuery(query);
  }

  if (NS_FAILED(DohDecodeQuery(query, mHost, mType))) {
    LOG(("TRR::ReceivePush failed to decode %s\n", mHost.get()));
    return NS_ERROR_UNEXPECTED;
  }
  PRNetAddr literalProbe;
  if (PR_StringToNetAddr(mHost.get(), &literalProbe) == PR_SUCCESS) {
    // literal
    LOG(("TRR::ReceivePush failed to decode %s\n", mHost.get()));
    return NS_ERROR_UNEXPECTED;
  }

  const bool knownType = (mType == TRRTYPE_A) || (mType == TRRTYPE_AAAA) ||
                         (mType == TRRTYPE_TXT);
  if (!knownType) {
    LOG(("TRR::ReceivePush unknown type %d\n", mType));
    return NS_ERROR_UNEXPECTED;
  }

  RefPtr<nsHostRecord> newRecord;
  nsresult rv = mHostResolver->GetHostRecord(
      mHost, (mType == TRRTYPE_TXT) ? nsIDNSService::RESOLVE_TYPE_TXT : 0,
      pushedRec->flags, pushedRec->af, pushedRec->pb, pushedRec->originSuffix,
      getter_AddRefs(newRecord));
  if (NS_FAILED(rv)) {
    return rv;
  }

  rv = mHostResolver->TrrLookup_unlocked(newRecord, this);
  if (NS_FAILED(rv)) {
    return rv;
  }

  rv = pushed->AsyncOpen(this);
  if (NS_FAILED(rv)) {
    return rv;
  }

  // OK!
  mChannel = pushed;
  mRec.swap(newRecord);
  return NS_OK;
}
// nsIHttpPushListener: a stream was pushed on our channel. A fresh TRR
// object (same resolver and mPB) takes over the pushed channel; without a
// host record of our own there is nothing to base it on and the push fails.
NS_IMETHODIMP
TRR::OnPush(nsIHttpChannel *associated, nsIHttpChannel *pushed) {
  LOG(("TRR::OnPush entry\n"));
  MOZ_ASSERT(associated == mChannel);
  if (mRec) {
    RefPtr<TRR> pushedTrr = new TRR(mHostResolver, mPB);
    return pushedTrr->ReceivePush(pushed, mRec);
  }
  return NS_ERROR_FAILURE;
}
// nsIStreamListener: the response has started. Only logs the event and
// records the current time in mStartTime.
NS_IMETHODIMP
TRR::OnStartRequest(nsIRequest *aRequest) {
  LOG(("TRR::OnStartRequest %p %s %d\n", this, mHost.get(), mType));

  mStartTime = TimeStamp::Now();

  return NS_OK;
}
// Read the big-endian 16 bit value stored at aData[index..index+1].
// The caller guarantees both bytes are inside the buffer.
static uint16_t get16bit(unsigned char *aData, int index) {
  const unsigned int high = aData[index];
  const unsigned int low = aData[index + 1];
  return static_cast<uint16_t>((high << 8) | low);
}
// Read the big-endian 32 bit value stored at aData[index..index+3].
// The caller guarantees all four bytes are inside the buffer.
static uint32_t get32bit(unsigned char *aData, int index) {
  // Widen to uint32_t before shifting: an unsigned char promotes to signed
  // int, and shifting a value >= 0x80 left by 24 would move a bit into the
  // sign position.
  return (static_cast<uint32_t>(aData[index]) << 24) |
         (static_cast<uint32_t>(aData[index + 1]) << 16) |
         (static_cast<uint32_t>(aData[index + 2]) << 8) |
         static_cast<uint32_t>(aData[index + 3]);
}
// Advance 'index' past one domain name in mResponse without decoding it.
// The name ends either at a zero length label or at the first compression
// pointer (two bytes). Returns NS_ERROR_ILLEGAL_VALUE when the name runs
// past mBodySize or a length byte has only one of its top two bits set.
nsresult TRR::PassQName(unsigned int &index) {
  for (;;) {
    if (mBodySize < (index + 1)) {
      LOG(("TRR: PassQName:%d fail at index %d\n", __LINE__, index));
      return NS_ERROR_ILLEGAL_VALUE;
    }
    const uint8_t length = static_cast<uint8_t>(mResponse[index]);
    const uint8_t topBits = length & 0xc0;

    if (topBits == 0xc0) {
      // name pointer, advance over it and be done
      if (mBodySize < (index + 2)) {
        return NS_ERROR_ILLEGAL_VALUE;
      }
      index += 2;
      return NS_OK;
    }
    if (topBits) {
      LOG(("TRR: illegal label length byte (%x) at index %d\n", length, index));
      return NS_ERROR_ILLEGAL_VALUE;
    }

    // pass label
    if (mBodySize < (index + 1 + length)) {
      LOG(("TRR: PassQName:%d fail at index %d\n", __LINE__, index));
      return NS_ERROR_ILLEGAL_VALUE;
    }
    index += 1 + length;
    if (!length) {
      // that was the terminating root label
      return NS_OK;
    }
  }
}
// GetQname: decodes the domain name that starts at 'aIndex' in mResponse,
// following compression pointers, and appends it to 'aQname' as
// dot-separated labels. On success 'aIndex' is moved to the first byte after
// the name as stored at the original position (i.e. after the first pointer
// when one was followed). Fails with NS_ERROR_ILLEGAL_VALUE on out-of-bounds
// data, a bad length byte, or when 128 labels/pointers were consumed without
// reaching the end of the name.
nsresult TRR::GetQname(nsAutoCString &aQname, unsigned int &aIndex) {
  unsigned int pos = aIndex;
  unsigned int hopsLeft = 128;  // a valid DNS name can never loop this much
  unsigned int resumeAt = 0;    // index position after this data

  for (;;) {
    if (pos >= mBodySize) {
      LOG(("TRR: bad cname packet\n"));
      return NS_ERROR_ILLEGAL_VALUE;
    }
    const uint8_t lengthByte = static_cast<uint8_t>(mResponse[pos]);
    const uint8_t topBits = lengthByte & 0xc0;

    if (topBits == 0xc0) {
      // name pointer, get the new offset (14 bits)
      if ((pos + 1) >= mBodySize) {
        return NS_ERROR_ILLEGAL_VALUE;
      }
      // extract the new index position for the next label
      uint16_t target = (lengthByte & 0x3f) << 8 | mResponse[pos + 1];
      if (!resumeAt) {
        // only update on the first "jump"
        resumeAt = pos + 2;
      }
      pos = target;
    } else if (topBits) {
      // any of those bits set individually is an error
      LOG(("TRR: bad cname packet\n"));
      return NS_ERROR_ILLEGAL_VALUE;
    } else {
      pos++;
      if (!lengthByte) {
        // zero length label: end of name
        break;
      }
      if (!aQname.IsEmpty()) {
        aQname.Append(".");
      }
      if ((pos + lengthByte) > mBodySize) {
        return NS_ERROR_ILLEGAL_VALUE;
      }
      aQname.Append((const char *)(&mResponse[pos]), lengthByte);
      pos += lengthByte;  // skip label
    }

    // Every pointer and every non-empty label uses up one hop.
    if (!--hopsLeft) {
      LOG(("TRR::DohDecode pointer loop error\n"));
      return NS_ERROR_ILLEGAL_VALUE;
    }
  }

  // Without a "jump" the name ends where parsing stopped.
  aIndex = resumeAt ? resumeAt : pos;
  return NS_OK;
}
//
// DohDecode() collects the TTL and the IP addresses in the response
//
// Parses the DNS message held in mResponse (mBodySize bytes). Answer records
// whose owner name equals aHost contribute data: A/AAAA addresses go to mDNS,
// the first CNAME target to mCname, TXT strings to mTxt (with the lowest TTL
// kept in mTxtTtl). Records for other names are skipped. The authority and
// additional sections are walked only to verify that the whole body parses.
//
// Returns:
//   NS_ERROR_ILLEGAL_VALUE  truncated or malformed message, non-zero ID, or
//                           trailing bytes after the last record
//   NS_ERROR_FAILURE        non-zero RCODE, or nothing was stored (for a
//                           non-NS request)
//   NS_ERROR_UNEXPECTED     answer of a type other than the requested one or
//                           CNAME, class other than IN, or bad A/AAAA length
//   the error from mDNS.Add() when an address is rejected
//
nsresult TRR::DohDecode(nsCString &aHost) {
// The response has a 12 byte header followed by the question (returned)
// and then the answer. The answer section itself contains the name, type
// and class again and THEN the record data.
// www.example.com response:
// header:
// abcd 8180 0001 0001 0000 0000
// the question:
// 0377 7777 0765 7861 6d70 6c65 0363 6f6d 0000 0100 01
// the answer:
// 03 7777 7707 6578 616d 706c 6503 636f 6d00 0001 0001
// 0000 0080 0004 5db8 d822
unsigned int index = 12;
uint8_t length;
nsAutoCString host;
nsresult rv;
LOG(("doh decode %s %d bytes\n", aHost.get(), mBodySize));
mCname.Truncate();
// Need a full header, and the first two bytes (the 16 bit id, which
// DohEncode() writes as zero) must be zero.
if (mBodySize < 12 || mResponse[0] || mResponse[1]) {
LOG(("TRR bad incoming DOH, eject!\n"));
return NS_ERROR_ILLEGAL_VALUE;
}
// RCODE is the low nibble of header byte 3; anything but 0 is a failure
uint8_t rcode = mResponse[3] & 0x0F;
if (rcode) {
LOG(("TRR Decode %s RCODE %d\n", aHost.get(), rcode));
return NS_ERROR_FAILURE;
}
uint16_t questionRecords = get16bit(mResponse, 4); // qdcount
// iterate over the single(?) host name in question
// The name collected in 'host' is only used for logging below; labels are
// read as plain length-prefixed strings here (no pointer handling).
while (questionRecords) {
do {
if (mBodySize < (index + 1)) {
return NS_ERROR_ILLEGAL_VALUE;
}
length = static_cast<uint8_t>(mResponse[index]);
if (length) {
if (host.Length()) {
host.Append(".");
}
if (mBodySize < (index + 1 + length)) {
return NS_ERROR_ILLEGAL_VALUE;
}
host.Append(((char *)mResponse) + index + 1, length);
}
index += 1 + length; // skip length byte + label
} while (length);
if (mBodySize < (index + 4)) {
return NS_ERROR_ILLEGAL_VALUE;
}
index += 4; // skip question's type, class
questionRecords--;
}
// Figure out the number of answer records from ANCOUNT
uint16_t answerRecords = get16bit(mResponse, 6);
LOG(("TRR Decode: %d answer records (%u bytes body) %s index=%u\n",
answerRecords, mBodySize, host.get(), index));
// Each answer record: NAME, TYPE(16), CLASS(16), TTL(32), RDLENGTH(16), RDATA
while (answerRecords) {
nsAutoCString qname;
// owner name of this record; index ends up just past it
rv = GetQname(qname, index);
if (NS_FAILED(rv)) {
return rv;
}
// 16 bit TYPE
if (mBodySize < (index + 2)) {
LOG(("TRR: Dohdecode:%d fail at index %d\n", __LINE__, index + 2));
return NS_ERROR_ILLEGAL_VALUE;
}
uint16_t TYPE = get16bit(mResponse, index);
if ((TYPE != TRRTYPE_CNAME) && (TYPE != static_cast<uint16_t>(mType))) {
// Not the same type as was asked for nor CNAME
LOG(("TRR: Dohdecode:%d asked for type %d got %d\n", __LINE__, mType,
TYPE));
return NS_ERROR_UNEXPECTED;
}
index += 2;
// 16 bit class
if (mBodySize < (index + 2)) {
LOG(("TRR: Dohdecode:%d fail at index %d\n", __LINE__, index + 2));
return NS_ERROR_ILLEGAL_VALUE;
}
uint16_t CLASS = get16bit(mResponse, index);
if (kDNS_CLASS_IN != CLASS) {
LOG(("TRR bad CLASS (%u) at index %d\n", CLASS, index));
return NS_ERROR_UNEXPECTED;
}
index += 2;
// 32 bit TTL (seconds)
if (mBodySize < (index + 4)) {
LOG(("TRR: Dohdecode:%d fail at index %d\n", __LINE__, index));
return NS_ERROR_ILLEGAL_VALUE;
}
uint32_t TTL = get32bit(mResponse, index);
index += 4;
// 16 bit RDLENGTH
if (mBodySize < (index + 2)) {
LOG(("TRR: Dohdecode:%d fail at index %d\n", __LINE__, index));
return NS_ERROR_ILLEGAL_VALUE;
}
uint16_t RDLENGTH = get16bit(mResponse, index);
index += 2;
// The whole RDATA must be inside the body; the per-type code below relies
// on this when it reads from mResponse[index ...].
if (mBodySize < (index + RDLENGTH)) {
LOG(("TRR: Dohdecode:%d fail RDLENGTH=%d at index %d\n", __LINE__,
RDLENGTH, index));
return NS_ERROR_ILLEGAL_VALUE;
}
// Only records for the name we were asked to decode are used
if (qname.Equals(aHost)) {
// RDATA
// - A (TYPE 1): 4 bytes
// - AAAA (TYPE 28): 16 bytes
// - NS (TYPE 2): N bytes
switch (TYPE) {
case TRRTYPE_A:
if (RDLENGTH != 4) {
LOG(("TRR bad length for A (%u)\n", RDLENGTH));
return NS_ERROR_UNEXPECTED;
}
rv = mDNS.Add(TTL, mResponse, index, RDLENGTH, mAllowRFC1918);
if (NS_FAILED(rv)) {
LOG(
("TRR:DohDecode failed: local IP addresses or unknown IP "
"family\n"));
return rv;
}
break;
case TRRTYPE_AAAA:
if (RDLENGTH != 16) {
LOG(("TRR bad length for AAAA (%u)\n", RDLENGTH));
return NS_ERROR_UNEXPECTED;
}
rv = mDNS.Add(TTL, mResponse, index, RDLENGTH, mAllowRFC1918);
if (NS_FAILED(rv)) {
LOG(("TRR got unique/local IPv6 address!\n"));
return rv;
}
break;
case TRRTYPE_NS:
// nothing is extracted from NS records
break;
case TRRTYPE_CNAME:
// Only the first CNAME for aHost is remembered
if (mCname.IsEmpty()) {
// NOTE: this 'qname' shadows the record owner name declared above; it
// receives the CNAME target stored in RDATA. A separate index is used
// so the outer 'index' still advances by RDLENGTH below.
nsAutoCString qname;
unsigned int qnameindex = index;
rv = GetQname(qname, qnameindex);
if (NS_FAILED(rv)) {
return rv;
}
if (!qname.IsEmpty()) {
mCname = qname;
LOG(("TRR::DohDecode CNAME host %s => %s\n", host.get(),
mCname.get()));
} else {
LOG(("TRR::DohDecode empty CNAME for host %s!\n", host.get()));
}
} else {
LOG(("TRR::DohDecode CNAME - ignoring another entry\n"));
}
break;
case TRRTYPE_TXT: {
// TXT record RRDATA sections are a series of character-strings
// each character string is a length byte followed by that many data
// bytes
nsAutoCString txt;
unsigned int txtIndex = index;
uint16_t available = RDLENGTH;
while (available > 0) {
uint8_t characterStringLen = mResponse[txtIndex++];
available--;
// A string claiming more bytes than remain in RDATA ends the scan; what
// was collected so far is still stored below.
if (characterStringLen > available) {
LOG(("TRR::DohDecode MALFORMED TXT RECORD\n"));
break;
}
txt.Append((const char *)(&mResponse[txtIndex]),
characterStringLen);
txtIndex += characterStringLen;
available -= characterStringLen;
}
mTxt.AppendElement(txt);
// keep the lowest TTL seen across TXT records
if (mTxtTtl > TTL) {
mTxtTtl = TTL;
}
LOG(("TRR::DohDecode TXT host %s => %s\n", host.get(), txt.get()));
break;
}
default:
// skip unknown record types
LOG(("TRR unsupported TYPE (%u) RDLENGTH %u\n", TYPE, RDLENGTH));
break;
}
} else {
LOG(("TRR asked for %s data but got %s\n", aHost.get(), qname.get()));
}
// step over RDATA regardless of whether it was used
index += RDLENGTH;
LOG(("done with record type %u len %u index now %u of %u\n", TYPE, RDLENGTH,
index, mBodySize));
answerRecords--;
}
// NSCOUNT
// Authority records are skipped, with bounds checks only
uint16_t nsRecords = get16bit(mResponse, 8);
LOG(("TRR Decode: %d ns records (%u bytes body)\n", nsRecords, mBodySize));
while (nsRecords) {
rv = PassQName(index);
if (NS_FAILED(rv)) {
return rv;
}
if (mBodySize < (index + 8)) {
return NS_ERROR_ILLEGAL_VALUE;
}
index += 2; // type
index += 2; // class
index += 4; // ttl
// 16 bit RDLENGTH
if (mBodySize < (index + 2)) {
return NS_ERROR_ILLEGAL_VALUE;
}
uint16_t RDLENGTH = get16bit(mResponse, index);
index += 2;
if (mBodySize < (index + RDLENGTH)) {
return NS_ERROR_ILLEGAL_VALUE;
}
index += RDLENGTH;
LOG(("done with nsRecord now %u of %u\n", index, mBodySize));
nsRecords--;
}
// additional resource records
// Skipped the same way as the authority records
uint16_t arRecords = get16bit(mResponse, 10);
LOG(("TRR Decode: %d additional resource records (%u bytes body)\n",
arRecords, mBodySize));
while (arRecords) {
rv = PassQName(index);
if (NS_FAILED(rv)) {
return rv;
}
if (mBodySize < (index + 8)) {
return NS_ERROR_ILLEGAL_VALUE;
}
index += 2; // type
index += 2; // class
index += 4; // ttl
// 16 bit RDLENGTH
if (mBodySize < (index + 2)) {
return NS_ERROR_ILLEGAL_VALUE;
}
uint16_t RDLENGTH = get16bit(mResponse, index);
index += 2;
if (mBodySize < (index + RDLENGTH)) {
return NS_ERROR_ILLEGAL_VALUE;
}
index += RDLENGTH;
LOG(("done with additional rr now %u of %u\n", index, mBodySize));
arRecords--;
}
// All sections consumed: the parse position must sit exactly at the end
if (index != mBodySize) {
LOG(("DohDecode failed to parse entire response body, %u out of %u bytes\n",
index, mBodySize));
// failed to parse 100%, do not continue
return NS_ERROR_ILLEGAL_VALUE;
}
// Except for NS requests, a response that yielded no CNAME, no address and
// no TXT data counts as a failure.
if ((mType != TRRTYPE_NS) && mCname.IsEmpty() &&
!mDNS.mAddresses.getFirst() && mTxt.IsEmpty()) {
// no entries were stored!
LOG(("TRR: No entries were stored!\n"));
return NS_ERROR_FAILURE;
}
return NS_OK;
}
// Hand the decoded answer to the host resolver and detach from it.
//
// For TXT requests the collected strings (mTxt) and their TTL (mTxtTtl) are
// passed on; for every other type the addresses gathered in mDNS are moved
// into a new AddrInfo whose ttl is the lowest TTL among them
// (AddrInfo::NO_TTL_DATA when there were none).
//
// Returns NS_ERROR_FAILURE when there is no resolver to report to, NS_OK
// otherwise. The resolver's own result is deliberately ignored.
nsresult TRR::ReturnData() {
  // Both paths below dereference the resolver, so check it up front.
  if (!mHostResolver) {
    return NS_ERROR_FAILURE;
  }

  if (mType != TRRTYPE_TXT) {
    // create and populate an AddrInfo instance to pass on
    RefPtr<AddrInfo> ai(new AddrInfo(mHost, mType));
    uint32_t ttl = AddrInfo::NO_TTL_DATA;
    for (;;) {
      // popFirst() unlinks the entry, so nothing else can reach it any more:
      // take ownership here so it is freed once its data has been copied.
      nsAutoPtr<DOHaddr> item(
          static_cast<DOHaddr *>(mDNS.mAddresses.popFirst()));
      if (!item) {
        break;
      }
      PRNetAddr prAddr;
      NetAddrToPRNetAddr(&item->mNet, &prAddr);
      auto *addrElement = new NetAddrElement(&prAddr);
      ai->AddAddress(addrElement);
      if (item->mTtl < ttl) {
        // While the DNS packet might return individual TTLs for each address,
        // we can only return one value in the AddrInfo class so pick the
        // lowest number.
        ttl = item->mTtl;
      }
    }
    ai->ttl = ttl;
    (void)mHostResolver->CompleteLookup(mRec, NS_OK, ai, mPB, mOriginSuffix);
  } else {
    (void)mHostResolver->CompleteLookupByType(mRec, NS_OK, &mTxt, mTxtTtl, mPB);
  }

  // The lookup has been completed; drop the references for both request
  // kinds, as FailData() does.
  mHostResolver = nullptr;
  mRec = nullptr;
  return NS_OK;
}
// Report 'error' for this request to the host resolver and detach from it.
// Returns NS_ERROR_FAILURE when there is no resolver (nothing is reported),
// NS_OK otherwise.
nsresult TRR::FailData(nsresult error) {
  if (!mHostResolver) {
    return NS_ERROR_FAILURE;
  }

  if (mType != TRRTYPE_TXT) {
    // create and populate an TRR AddrInfo instance to pass on to signal that
    // this comes from TRR
    RefPtr<AddrInfo> trrInfo = new AddrInfo(mHost, mType);
    (void)mHostResolver->CompleteLookup(mRec, error, trrInfo, mPB,
                                        mOriginSuffix);
  } else {
    (void)mHostResolver->CompleteLookupByType(mRec, error, nullptr, 0, mPB);
  }

  mHostResolver = nullptr;
  mRec = nullptr;
  return NS_OK;
}
// Handle the body of an HTTP 200 response.
//
// Decodes the body for mHost. When that yields data (or the request is TXT)
// the result is returned to the resolver. When it only yields a CNAME, the
// same body is first searched for records of the CNAME target; failing that
// a new TRR request for the target is dispatched to the main thread, as long
// as the mCnameLoop budget is not used up.
//
// Returns NS_OK when data was returned or a follow-up request was
// dispatched, NS_ERROR_FAILURE otherwise (the caller then reports failure).
nsresult TRR::On200Response() {
  // decode body and create an AddrInfo struct for the response
  nsresult rv = DohDecode(mHost);
  if (NS_FAILED(rv)) {
    LOG(("TRR::On200Response DohDecode %x\n", (int)rv));
    return NS_ERROR_FAILURE;
  }

  const bool cnameOnly = !mDNS.mAddresses.getFirst() && !mCname.IsEmpty() &&
                         mType != TRRTYPE_TXT;
  if (!cnameOnly) {
    // pass back the response data
    ReturnData();
    return NS_OK;
  }

  nsCString target = mCname;
  LOG(("TRR: check for CNAME record for %s within previous response\n",
       target.get()));
  rv = DohDecode(target);
  if (NS_SUCCEEDED(rv) && mDNS.mAddresses.getFirst()) {
    LOG(("TRR: Got the CNAME record without asking for it\n"));
    ReturnData();
    return NS_OK;
  }

  // restore mCname as DohDecode() change it
  mCname = target;
  if (--mCnameLoop) {
    LOG(("TRR::On200Response CNAME %s => %s (%u)\n", mHost.get(),
         mCname.get(), mCnameLoop));
    RefPtr<TRR> followUp =
        new TRR(mHostResolver, mRec, mCname, mType, mCnameLoop, mPB);
    rv = NS_DispatchToMainThread(followUp);
    if (NS_SUCCEEDED(rv)) {
      return rv;
    }
  } else {
    LOG(("TRR::On200Response CNAME loop, eject!\n"));
  }
  return NS_ERROR_FAILURE;
}
// nsIStreamListener: the request has finished.
//
// Tells the TRR service whether the HTTP exchange itself worked, then, for a
// successful exchange with an acceptable content type and HTTP status 200,
// decodes the body via On200Response(). Every path that does not end with
// data being returned (or a CNAME follow-up being dispatched) reports a
// failure to the resolver through FailData().
//
// Returns NS_ERROR_UNEXPECTED when aRequest is not an HTTP channel, otherwise
// NS_OK (or On200Response()'s success code).
NS_IMETHODIMP
TRR::OnStopRequest(nsIRequest *aRequest, nsresult aStatusCode) {
  // The dtor will be run after the function returns
  LOG(("TRR:OnStopRequest %p %s %d failed=%d code=%X\n", this, mHost.get(),
       mType, mFailed, (unsigned int)aStatusCode));
  // Drop mChannel now; the local reference keeps the channel alive until
  // this function returns.
  nsCOMPtr<nsIChannel> channel;
  channel.swap(mChannel);

  // Bad content is still considered "okay" if the HTTP response is okay
  gTRRService->TRRIsOkay(NS_SUCCEEDED(aStatusCode) ? TRRService::OKAY_NORMAL
                                                   : TRRService::OKAY_BAD);

  // if status was "fine", parse the response and pass on the answer
  if (!mFailed && NS_SUCCEEDED(aStatusCode)) {
    nsCOMPtr<nsIHttpChannel> httpChannel = do_QueryInterface(aRequest);
    if (!httpChannel) {
      // Complete the lookup like every other failure path does, so the
      // resolver is not left waiting for an answer that never comes.
      FailData(NS_ERROR_UNEXPECTED);
      return NS_ERROR_UNEXPECTED;
    }
    nsAutoCString contentType;
    httpChannel->GetContentType(contentType);
    if (contentType.Length() &&
        !contentType.LowerCaseEqualsLiteral("application/dns-message")) {
      LOG(("TRR:OnStopRequest %p %s %d wrong content type %s\n", this,
           mHost.get(), mType, contentType.get()));
      FailData(NS_ERROR_UNEXPECTED);
      return NS_OK;
    }

    // Initialized so the log line below never reads an indeterminate value
    // when GetResponseStatus() fails.
    uint32_t httpStatus = 0;
    nsresult rv = httpChannel->GetResponseStatus(&httpStatus);
    if (NS_SUCCEEDED(rv) && httpStatus == 200) {
      rv = On200Response();
      if (NS_SUCCEEDED(rv)) {
        return rv;
      }
    } else {
      LOG(("TRR:OnStopRequest:%d %p rv %x httpStatus %d\n", __LINE__, this,
           (int)rv, httpStatus));
    }
  }

  LOG(("TRR:OnStopRequest %p status %x mFailed %d\n", this, (int)aStatusCode,
       mFailed));
  FailData(NS_ERROR_UNKNOWN_HOST);
  return NS_OK;
}
// nsIStreamListener: append aCount bytes of response data to mResponse.
//
// The request is marked failed (mFailed) and an error is returned when the
// data would not fit into the kMaxSize byte buffer or the stream read fails.
// Once failed, further data is rejected with NS_ERROR_FAILURE.
NS_IMETHODIMP
TRR::OnDataAvailable(nsIRequest *aRequest, nsIInputStream *aInputStream,
                     uint64_t aOffset, const uint32_t aCount) {
  LOG(("TRR:OnDataAvailable %p %s %d failed=%d aCount=%u\n", this, mHost.get(),
       mType, mFailed, (unsigned int)aCount));
  // receive DNS response into the local buffer
  if (mFailed) {
    return NS_ERROR_FAILURE;
  }

  // Do the size check in 64 bits: a huge aCount must not wrap the 32 bit sum
  // around and slip past the limit.
  if (static_cast<uint64_t>(aCount) + mBodySize > kMaxSize) {
    LOG(("TRR::OnDataAvailable:%d fail\n", __LINE__));
    mFailed = true;
    return NS_ERROR_FAILURE;
  }

  uint32_t count = 0;
  nsresult rv =
      aInputStream->Read((char *)mResponse + mBodySize, aCount, &count);
  if (NS_FAILED(rv)) {
    LOG(("TRR::OnDataAvailable:%d fail\n", __LINE__));
    mFailed = true;
    return rv;
  }
  MOZ_ASSERT(count == aCount);
  // Account only for the bytes actually stored, so a short read can never
  // make the decoder parse buffer content that was not received.
  mBodySize += count;
  return NS_OK;
}
// Append one address from a DNS answer to mAddresses.
//
// TTL            TTL of the record, stored with the address.
// dns, index     buffer and offset at which the raw address starts.
// len            4 for an IPv4 address, 16 for IPv6.
// aLocalAllowed  when false, addresses for which IsIPAddrLocal() is true are
//                rejected.
//
// Returns NS_ERROR_UNEXPECTED for any other 'len', NS_ERROR_FAILURE for a
// rejected local address, NS_OK once the address has been stored.
nsresult DOHresp::Add(uint32_t TTL, unsigned char *dns, int index, uint16_t len,
                      bool aLocalAllowed) {
  // Owned by 'entry' until it is handed over to the list at the very end, so
  // the early returns below free it.
  nsAutoPtr<DOHaddr> entry(new DOHaddr);
  NetAddr *net = &entry->mNet;

  switch (len) {
    case 4:
      // IPv4
      net->inet.family = AF_INET;
      net->inet.port = 0;  // unknown
      net->inet.ip = ntohl(get32bit(dns, index));
      break;
    case 16:
      // IPv6
      net->inet6.family = AF_INET6;
      net->inet6.port = 0;      // unknown
      net->inet6.flowinfo = 0;  // unknown
      net->inet6.scope_id = 0;  // unknown
      for (int octet = 0; octet < 16; ++octet) {
        net->inet6.ip.u8[octet] = dns[index + octet];
      }
      break;
    default:
      return NS_ERROR_UNEXPECTED;
  }

  if (IsIPAddrLocal(net) && !aLocalAllowed) {
    return NS_ERROR_FAILURE;
  }
  entry->mTtl = TTL;

  if (LOG_ENABLED()) {
    char text[128];
    NetAddrToString(net, text, sizeof(text));
    LOG(("DOHresp:Add %s\n", text));
  }
  mAddresses.insertBack(entry.forget());
  return NS_OK;
}
// Runnable used by TRR::Cancel() to re-issue the cancel on the main thread.
// Holds a strong reference to the TRR object until it has run.
class ProxyCancel : public Runnable {
 public:
  explicit ProxyCancel(TRR *aTRR) : Runnable("proxyTrrCancel"), mTRR(aTRR) {}

  NS_IMETHOD Run() override {
    mTRR->Cancel();
    // release the TRR object right away rather than with the runnable
    mTRR = nullptr;
    return NS_OK;
  }

 private:
  RefPtr<TRR> mTRR;
};
// Abort the in-flight request, if any. Off the main thread the call is
// forwarded to the main thread through a ProxyCancel runnable. Cancelling an
// open channel is also reported to the TRR service as a timeout.
void TRR::Cancel() {
  if (NS_IsMainThread()) {
    if (mChannel) {
      LOG(("TRR: %p canceling Channel %p %s %d\n", this, mChannel.get(),
           mHost.get(), mType));
      mChannel->Cancel(NS_ERROR_ABORT);
      gTRRService->TRRIsOkay(TRRService::OKAY_TIMEOUT);
    }
    return;
  }

  NS_DispatchToMainThread(new ProxyCancel(this));
}
#undef LOG
// namespace
} // namespace net
} // namespace mozilla