diff --git a/modules/libpref/init/StaticPrefList.yaml b/modules/libpref/init/StaticPrefList.yaml index d34f41da1c58..1a6ec3e532f7 100644 --- a/modules/libpref/init/StaticPrefList.yaml +++ b/modules/libpref/init/StaticPrefList.yaml @@ -8400,6 +8400,14 @@ value: 2000 mirror: always +# When true on Windows DNS resolutions for single label domains +# (domains that don't contain a dot) will be resolved using the DnsQuery +# API instead of PR_GetAddrInfoByName +- name: network.dns.dns_query_single_label + type: RelaxedAtomicBool + value: true + mirror: always + # The proxy type. See nsIProtocolProxyService.idl # PROXYCONFIG_DIRECT = 0 # PROXYCONFIG_MANUAL = 1 diff --git a/netwerk/dns/GetAddrInfo.cpp b/netwerk/dns/GetAddrInfo.cpp index fba22a1e1209..d9b7f4e75696 100644 --- a/netwerk/dns/GetAddrInfo.cpp +++ b/netwerk/dns/GetAddrInfo.cpp @@ -5,6 +5,18 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ #include "GetAddrInfo.h" + +#ifdef DNSQUERY_AVAILABLE +// There is a bug in windns.h where the type of parameter ppQueryResultsSet for +// DnsQuery_A is dependent on UNICODE being set. It should *always* be +// PDNS_RECORDA, but if UNICODE is set it is PDNS_RECORDW. To get around this +// we make sure that UNICODE is unset. +# undef UNICODE +# include +# undef GetAddrInfo +# include +#endif // DNSQUERY_AVAILABLE + #include "mozilla/ClearOnShutdown.h" #include "mozilla/net/DNS.h" #include "NativeDNSResolverOverrideParent.h" @@ -19,17 +31,6 @@ #include "mozilla/Logging.h" #include "mozilla/StaticPrefs_network.h" -#ifdef DNSQUERY_AVAILABLE -// There is a bug in windns.h where the type of parameter ppQueryResultsSet for -// DnsQuery_A is dependent on UNICODE being set. It should *always* be -// PDNS_RECORDA, but if UNICODE is set it is PDNS_RECORDW. To get around this -// we make sure that UNICODE is unset. -# undef UNICODE -# include -# undef GetAddrInfo -# include -#endif - namespace mozilla { namespace net { @@ -42,6 +43,11 @@ static LazyLogModule gGetAddrInfoLog("GetAddrInfo"); MOZ_LOG(gGetAddrInfoLog, LogLevel::Warning, ("[DNS]: " msg, ##__VA_ARGS__)) #ifdef DNSQUERY_AVAILABLE + +# define COMPUTER_NAME_BUFFER_SIZE 100 +static char sDNSComputerName[COMPUTER_NAME_BUFFER_SIZE]; +static char sNETBIOSComputerName[MAX_COMPUTERNAME_LENGTH + 1]; + //////////////////////////// // WINDOWS IMPLEMENTATION // //////////////////////////// @@ -56,47 +62,65 @@ static_assert(PR_AF_INET == AF_INET && PR_AF_INET6 == AF_INET6 && // equal with the one already there. Gets the TTL value by calling // to DnsQuery_A and iterating through the returned // records to find the one with the smallest TTL value. -static MOZ_ALWAYS_INLINE nsresult _GetMinTTLForRequestType_Windows( - const char* aHost, uint16_t aRequestType, unsigned int* aResult) { - MOZ_ASSERT(aHost); - MOZ_ASSERT(aResult); +static MOZ_ALWAYS_INLINE nsresult _CallDnsQuery_A_Windows( + const nsACString& aHost, uint16_t aAddressFamily, DWORD aFlags, + std::function aCallback) { + NS_ConvertASCIItoUTF16 name(aHost); - PDNS_RECORDA dnsData = nullptr; - DNS_STATUS status = DnsQuery_A( - aHost, aRequestType, - (DNS_QUERY_STANDARD | DNS_QUERY_NO_NETBT | DNS_QUERY_NO_HOSTS_FILE | - DNS_QUERY_NO_MULTICAST | DNS_QUERY_ACCEPT_TRUNCATED_RESPONSE | - DNS_QUERY_DONT_RESET_TTL_VALUES), - nullptr, &dnsData, nullptr); - if (status == DNS_INFO_NO_RECORDS || status == DNS_ERROR_RCODE_NAME_ERROR || - !dnsData) { - LOG("No DNS records found for %s. status=%X. aRequestType = %X\n", aHost, - status, aRequestType); - return NS_ERROR_FAILURE; - } else if (status != NOERROR) { - LOG_WARNING("DnsQuery_A failed with status %X.\n", status); - return NS_ERROR_UNEXPECTED; - } - - for (PDNS_RECORDA curRecord = dnsData; curRecord; - curRecord = curRecord->pNext) { - // Only records in the answer section are important - if (curRecord->Flags.S.Section != DnsSectionAnswer) { - continue; + auto callDnsQuery_A = [&](uint16_t reqFamily) { + PDNS_RECORDA dnsData = nullptr; + DNS_STATUS status = DnsQuery_A(aHost.BeginReading(), reqFamily, aFlags, + nullptr, &dnsData, nullptr); + if (status == DNS_INFO_NO_RECORDS || status == DNS_ERROR_RCODE_NAME_ERROR || + !dnsData) { + LOG("No DNS records found for %s. status=%X. reqFamily = %X\n", + aHost.BeginReading(), status, reqFamily); + return NS_ERROR_FAILURE; + } else if (status != NOERROR) { + LOG_WARNING("DnsQuery_A failed with status %X.\n", status); + return NS_ERROR_UNEXPECTED; } - if (curRecord->wType == aRequestType) { - *aResult = std::min(*aResult, curRecord->dwTtl); - } else { - LOG("Received unexpected record type %u in response for %s.\n", - curRecord->wType, aHost); + for (PDNS_RECORDA curRecord = dnsData; curRecord; + curRecord = curRecord->pNext) { + // Only records in the answer section are important + if (curRecord->Flags.S.Section != DnsSectionAnswer) { + continue; + } + if (curRecord->wType != reqFamily) { + continue; + } + + aCallback(curRecord); } + + DnsFree(dnsData, DNS_FREE_TYPE::DnsFreeRecordList); + return NS_OK; + }; + + if (aAddressFamily == PR_AF_UNSPEC || aAddressFamily == PR_AF_INET) { + callDnsQuery_A(DNS_TYPE_A); } - DnsFree(dnsData, DNS_FREE_TYPE::DnsFreeRecordList); + if (aAddressFamily == PR_AF_UNSPEC || aAddressFamily == PR_AF_INET6) { + callDnsQuery_A(DNS_TYPE_AAAA); + } return NS_OK; } +bool recordTypeMatchesRequest(uint16_t wType, uint16_t aAddressFamily) { + if (aAddressFamily == PR_AF_UNSPEC) { + return wType == DNS_TYPE_A || wType == DNS_TYPE_AAAA; + } + if (aAddressFamily == PR_AF_INET) { + return wType == DNS_TYPE_A; + } + if (aAddressFamily == PR_AF_INET6) { + return wType == DNS_TYPE_AAAA; + } + return false; +} + static MOZ_ALWAYS_INLINE nsresult _GetTTLData_Windows(const nsACString& aHost, uint32_t* aResult, uint16_t aAddressFamily) { @@ -110,13 +134,21 @@ static MOZ_ALWAYS_INLINE nsresult _GetTTLData_Windows(const nsACString& aHost, // In order to avoid using ANY records which are not always implemented as a // "Gimme what you have" request in hostname resolvers, we should send A // and/or AAAA requests, based on the address family requested. + const DWORD ttlFlags = + (DNS_QUERY_STANDARD | DNS_QUERY_NO_NETBT | DNS_QUERY_NO_HOSTS_FILE | + DNS_QUERY_NO_MULTICAST | DNS_QUERY_ACCEPT_TRUNCATED_RESPONSE | + DNS_QUERY_DONT_RESET_TTL_VALUES); unsigned int ttl = (unsigned int)-1; - if (aAddressFamily == PR_AF_UNSPEC || aAddressFamily == PR_AF_INET) { - _GetMinTTLForRequestType_Windows(aHost.BeginReading(), DNS_TYPE_A, &ttl); - } - if (aAddressFamily == PR_AF_UNSPEC || aAddressFamily == PR_AF_INET6) { - _GetMinTTLForRequestType_Windows(aHost.BeginReading(), DNS_TYPE_AAAA, &ttl); - } + _CallDnsQuery_A_Windows( + aHost, aAddressFamily, ttlFlags, + [&ttl, &aHost, aAddressFamily](PDNS_RECORDA curRecord) { + if (recordTypeMatchesRequest(curRecord->wType, aAddressFamily)) { + ttl = std::min(ttl, curRecord->dwTtl); + } else { + LOG("Received unexpected record type %u in response for %s.\n", + curRecord->wType, aHost.BeginReading()); + } + }); if (ttl == (unsigned int)-1) { LOG("No useable TTL found."); @@ -126,6 +158,41 @@ static MOZ_ALWAYS_INLINE nsresult _GetTTLData_Windows(const nsACString& aHost, *aResult = ttl; return NS_OK; } + +static MOZ_ALWAYS_INLINE nsresult +_DNSQuery_A_SingleLabel(const nsACString& aCanonHost, uint16_t aAddressFamily, + uint16_t aFlags, AddrInfo** aAddrInfo) { + bool setCanonName = aFlags & nsHostResolver::RES_CANON_NAME; + nsAutoCString canonName; + const DWORD flags = (DNS_QUERY_STANDARD | DNS_QUERY_NO_MULTICAST | + DNS_QUERY_ACCEPT_TRUNCATED_RESPONSE); + nsTArray addresses; + + _CallDnsQuery_A_Windows( + aCanonHost, aAddressFamily, flags, [&](PDNS_RECORDA curRecord) { + MOZ_DIAGNOSTIC_ASSERT(curRecord->wType == DNS_TYPE_A || + curRecord->wType == DNS_TYPE_AAAA); + if (setCanonName) { + canonName.Assign(curRecord->pName); + } + NetAddr addr{}; + addr.inet.family = AF_INET; + addr.inet.ip = curRecord->Data.A.IpAddress; + addresses.AppendElement(addr); + }); + + LOG("Query for: %s has %u results", aCanonHost.BeginReading(), + addresses.Length()); + if (addresses.IsEmpty()) { + return NS_ERROR_UNKNOWN_HOST; + } + RefPtr ai( + new AddrInfo(aCanonHost, canonName, 0, std::move(addresses))); + ai.forget(aAddrInfo); + + return NS_OK; +} + #endif //////////////////////////////////// @@ -153,6 +220,24 @@ _GetAddrInfo_Portable(const nsACString& aCanonHost, uint16_t aAddressFamily, aAddressFamily = PR_AF_UNSPEC; } +#if defined(DNSQUERY_AVAILABLE) + if (StaticPrefs::network_dns_dns_query_single_label() && + !aCanonHost.Contains('.') && aCanonHost != "localhost"_ns) { + // For some reason we can't use DnsQuery_A to get the computer's IP. + if (!aCanonHost.Equals(nsDependentCString(sDNSComputerName), + nsCaseInsensitiveCStringComparator) && + !aCanonHost.Equals(nsDependentCString(sNETBIOSComputerName), + nsCaseInsensitiveCStringComparator)) { + // This is a single label name resolve without a dot. + // We use DNSQuery_A for these. + LOG("Resolving %s using DnsQuery_A (computername: %s)\n", + aCanonHost.BeginReading(), sDNSComputerName); + return _DNSQuery_A_SingleLabel(aCanonHost, aAddressFamily, aFlags, + aAddrInfo); + } + } +#endif + PRAddrInfo* prai = PR_GetAddrInfoByName(aCanonHost.BeginReading(), aAddressFamily, prFlags); @@ -184,6 +269,19 @@ _GetAddrInfo_Portable(const nsACString& aCanonHost, uint16_t aAddressFamily, ////////////////////////////////////// nsresult GetAddrInfoInit() { LOG("Initializing GetAddrInfo.\n"); + +#ifdef DNSQUERY_AVAILABLE + DWORD namesize = COMPUTER_NAME_BUFFER_SIZE; + if (!GetComputerNameEx(ComputerNameDnsHostname, sDNSComputerName, + &namesize)) { + sDNSComputerName[0] = 0; + } + namesize = MAX_COMPUTERNAME_LENGTH + 1; + if (!GetComputerNameEx(ComputerNameNetBIOS, sNETBIOSComputerName, + &namesize)) { + sNETBIOSComputerName[0] = 0; + } +#endif return NS_OK; } diff --git a/netwerk/dns/moz.build b/netwerk/dns/moz.build index e53672b6a832..126c6187c312 100644 --- a/netwerk/dns/moz.build +++ b/netwerk/dns/moz.build @@ -52,6 +52,7 @@ EXPORTS.mozilla.net += [ ] SOURCES += [ + "GetAddrInfo.cpp", # Undefines UNICODE "nsEffectiveTLDService.cpp", # Excluded from UNIFIED_SOURCES due to special build flags. "nsHostResolver.cpp", # Redefines LOG ] @@ -63,7 +64,6 @@ UNIFIED_SOURCES += [ "DNSRequestChild.cpp", "DNSRequestParent.cpp", "DNSResolverInfo.cpp", - "GetAddrInfo.cpp", "HTTPSSVC.cpp", "IDNBlocklistUtils.cpp", "NativeDNSResolverOverrideChild.cpp",