diff --git a/toolkit/components/url-classifier/src/nsUrlClassifierDBService.cpp b/toolkit/components/url-classifier/src/nsUrlClassifierDBService.cpp index 5ddf57f32b26..a07b272e6bba 100644 --- a/toolkit/components/url-classifier/src/nsUrlClassifierDBService.cpp +++ b/toolkit/components/url-classifier/src/nsUrlClassifierDBService.cpp @@ -489,9 +489,12 @@ private: nsresult GetLookupFragments(const nsCSubstring& spec, nsTArray& fragments); + // Check for a canonicalized IP address. + PRBool IsCanonicalizedIP(const nsACString& host); + // Get the database key for a given URI. This is the top three // domain components if they exist, otherwise the top two. - // hostname.com/foo/bar -> hostname + // hostname.com/foo/bar -> hostname.com // mail.hostname.com/foo/bar -> mail.hostname.com // www.mail.hostname.com/foo/bar -> mail.hostname.com nsresult GetKey(const nsACString& spec, nsUrlClassifierHash& hash); @@ -812,39 +815,47 @@ nsUrlClassifierDBServiceWorker::DoLookup(const nsACString& spec, } const nsCSubstring& host = Substring(begin, iter++); - nsCStringArray hostComponents; - hostComponents.ParseString(PromiseFlatCString(host).get(), "."); - if (hostComponents.Count() < 2) { - // no host or toplevel host, this won't match anything in the db - c->HandleEvent(EmptyCString()); - return NS_OK; - } - - // First check with two domain components - PRInt32 last = hostComponents.Count() - 1; - nsCAutoString lookupHost; - lookupHost.Assign(*hostComponents[last - 1]); - lookupHost.Append("."); - lookupHost.Append(*hostComponents[last]); - lookupHost.Append("/"); - nsUrlClassifierHash hash; - hash.FromPlaintext(lookupHost, mCryptoHash); - - // we ignore failures from CheckKey because we'd rather try to find - // more results than fail. nsTArray resultTables; - CheckKey(spec, hash, resultTables); - - // Now check with three domain components - if (hostComponents.Count() > 2) { - nsCAutoString lookupHost2; - lookupHost2.Assign(*hostComponents[last - 2]); - lookupHost2.Append("."); - lookupHost2.Append(lookupHost); - hash.FromPlaintext(lookupHost2, mCryptoHash); + nsUrlClassifierHash hash; + if (IsCanonicalizedIP(host)) { + // Don't break up the host into components + hash.FromPlaintext(host, mCryptoHash); CheckKey(spec, hash, resultTables); + } else { + nsCStringArray hostComponents; + hostComponents.ParseString(PromiseFlatCString(host).get(), "."); + + if (hostComponents.Count() < 2) { + // no host or toplevel host, this won't match anything in the db + c->HandleEvent(EmptyCString()); + return NS_OK; + } + + // First check with two domain components + PRInt32 last = hostComponents.Count() - 1; + nsCAutoString lookupHost; + lookupHost.Assign(*hostComponents[last - 1]); + lookupHost.Append("."); + lookupHost.Append(*hostComponents[last]); + lookupHost.Append("/"); + hash.FromPlaintext(lookupHost, mCryptoHash); + + // we ignore failures from CheckKey because we'd rather try to find + // more results than fail. + CheckKey(spec, hash, resultTables); + + // Now check with three domain components + if (hostComponents.Count() > 2) { + nsCAutoString lookupHost2; + lookupHost2.Assign(*hostComponents[last - 2]); + lookupHost2.Append("."); + lookupHost2.Append(lookupHost); + hash.FromPlaintext(lookupHost2, mCryptoHash); + + CheckKey(spec, hash, resultTables); + } } nsCAutoString result; @@ -1151,6 +1162,21 @@ nsUrlClassifierDBServiceWorker::WriteEntry(nsUrlClassifierEntry& entry) return NS_OK; } +PRBool +nsUrlClassifierDBServiceWorker::IsCanonicalizedIP(const nsACString& host) +{ + // The canonicalization process will have left IP addresses in dotted + // decimal with no surprises. + PRUint32 i1, i2, i3, i4; + char c; + if (PR_sscanf(PromiseFlatCString(host).get(), "%u.%u.%u.%u%c", + &i1, &i2, &i3, &i4, &c) == 4) { + return (i1 <= 0xFF && i1 <= 0xFF && i1 <= 0xFF && i1 <= 0xFF); + } + + return PR_FALSE; +} + nsresult nsUrlClassifierDBServiceWorker::GetKey(const nsACString& spec, nsUrlClassifierHash& hash) @@ -1164,7 +1190,12 @@ nsUrlClassifierDBServiceWorker::GetKey(const nsACString& spec, return NS_OK; } - const nsCSubstring& host = Substring(begin, iter++); + const nsCSubstring& host = Substring(begin, iter); + + if (IsCanonicalizedIP(host)) { + return hash.FromPlaintext(host, mCryptoHash); + } + nsCStringArray hostComponents; hostComponents.ParseString(PromiseFlatCString(host).get(), ".");