diff --git a/netwerk/dns/HTTPSSVC.cpp b/netwerk/dns/HTTPSSVC.cpp index 100a4691fcfa..bc30a18a1e5e 100644 --- a/netwerk/dns/HTTPSSVC.cpp +++ b/netwerk/dns/HTTPSSVC.cpp @@ -54,7 +54,7 @@ NS_INTERFACE_MAP_END NS_IMETHODIMP SvcParam::GetType(uint16_t* aType) { *aType = mValue.match( - [](Nothing&) { return SvcParamKeyNone; }, + [](Nothing&) { return SvcParamKeyMandatory; }, [](SvcParamAlpn&) { return SvcParamKeyAlpn; }, [](SvcParamNoDefaultAlpn&) { return SvcParamKeyNoDefaultAlpn; }, [](SvcParamPort&) { return SvcParamKeyPort; }, diff --git a/netwerk/dns/HTTPSSVC.h b/netwerk/dns/HTTPSSVC.h index d4eb734f0875..7c8184dd40c2 100644 --- a/netwerk/dns/HTTPSSVC.h +++ b/netwerk/dns/HTTPSSVC.h @@ -14,7 +14,7 @@ namespace mozilla { namespace net { enum SvcParamKey : uint16_t { - SvcParamKeyNone = 0, + SvcParamKeyMandatory = 0, SvcParamKeyAlpn = 1, SvcParamKeyNoDefaultAlpn = 2, SvcParamKeyPort = 3, @@ -82,7 +82,7 @@ struct SVCB { mSvcDomainName == aOther.mSvcDomainName && mSvcFieldValue == aOther.mSvcFieldValue; } - uint16_t mSvcFieldPriority = SvcParamKeyNone; + uint16_t mSvcFieldPriority = 0; nsCString mSvcDomainName; CopyableTArray mSvcFieldValue; }; diff --git a/netwerk/dns/TRR.cpp b/netwerk/dns/TRR.cpp index ee36c77838de..2933f8e5153d 100644 --- a/netwerk/dns/TRR.cpp +++ b/netwerk/dns/TRR.cpp @@ -1074,7 +1074,8 @@ nsresult TRR::DohDecode(nsCString& aHost) { svcbIndex += len; // If this is an unknown key, we will simply ignore it. - if (key == SvcParamKeyNone || key > SvcParamKeyLast) { + // We also don't need to record SvcParamKeyMandatory + if (key == SvcParamKeyMandatory || key > SvcParamKeyLast) { continue; } parsed.mSvcFieldValue.AppendElement(value); @@ -1202,6 +1203,24 @@ nsresult TRR::DohDecode(nsCString& aHost) { nsresult TRR::ParseSvcParam(unsigned int svcbIndex, uint16_t key, SvcFieldValue& field, uint16_t length) { switch (key) { + case SvcParamKeyMandatory: { + if (length % 2 != 0) { + // This key should encode a list of uint16_t + return NS_ERROR_UNEXPECTED; + } + while (length > 0) { + uint16_t mandatoryKey = get16bit(mResponse, svcbIndex); + length -= 2; + svcbIndex += 2; + + if (mandatoryKey > SvcParamKeyLast) { + LOG(("The mandatory field includes a key we don't support %u", + mandatoryKey)); + return NS_ERROR_UNEXPECTED; + } + } + break; + } case SvcParamKeyAlpn: { field.mValue = AsVariant(SvcParamAlpn{ .mValue = nsCString((const char*)(&mResponse[svcbIndex]), length)}); diff --git a/netwerk/test/unit/test_trr_httpssvc.js b/netwerk/test/unit/test_trr_httpssvc.js index 7dc77795fb71..ecb55107e3cd 100644 --- a/netwerk/test/unit/test_trr_httpssvc.js +++ b/netwerk/test/unit/test_trr_httpssvc.js @@ -505,4 +505,84 @@ add_task(async function test_aliasform() { !Components.isSuccessCode(inStatus2), `${inStatus2} should be an error code` ); + + // mandatory svcparam + await trrServer.registerDoHAnswers("mandatory.com", "HTTPS", [ + { + name: "mandatory.com", + ttl: 55, + type: "HTTPS", + flush: false, + data: { + priority: 1, + name: "h3pool", + values: [ + { key: "mandatory", value: ["key100"] }, + { key: "alpn", value: "h2,h3" }, + { key: "key100" }, + ], + }, + }, + ]); + + listener = new DNSListener(); + request = dns.asyncResolveByType( + "mandatory.com", + dns.RESOLVE_TYPE_HTTPSSVC, + 0, + listener, + mainThread, + defaultOriginAttributes + ); + + [inRequest, inRecord, inStatus2] = await listener; + Assert.equal(inRequest, request, "correct request was used"); + Assert.ok(!Components.isSuccessCode(inStatus2), `${inStatus2} should fail`); + + // mandatory svcparam + await trrServer.registerDoHAnswers("mandatory2.com", "HTTPS", [ + { + name: "mandatory2.com", + ttl: 55, + type: "HTTPS", + flush: false, + data: { + priority: 1, + name: "h3pool", + values: [ + { + key: "mandatory", + value: [ + "alpn", + "no-default-alpn", + "port", + "ipv4hint", + "echconfig", + "ipv6hint", + ], + }, + { key: "alpn", value: "h2,h3" }, + { key: "no-default-alpn" }, + { key: "port", value: 8888 }, + { key: "ipv4hint", value: "1.2.3.4" }, + { key: "echconfig", value: "123..." }, + { key: "ipv6hint", value: "::1" }, + ], + }, + }, + ]); + + listener = new DNSListener(); + request = dns.asyncResolveByType( + "mandatory2.com", + dns.RESOLVE_TYPE_HTTPSSVC, + 0, + listener, + mainThread, + defaultOriginAttributes + ); + + [inRequest, inRecord, inStatus2] = await listener; + Assert.equal(inRequest, request, "correct request was used"); + Assert.ok(Components.isSuccessCode(inStatus2), `${inStatus2} should succeed`); }); diff --git a/testing/xpcshell/dns-packet/index.js b/testing/xpcshell/dns-packet/index.js index a9fa112dd3f1..a7446910d31b 100644 --- a/testing/xpcshell/dns-packet/index.js +++ b/testing/xpcshell/dns-packet/index.js @@ -1276,6 +1276,7 @@ const svcparam = exports.svcparam = {} svcparam.keyToNumber = function(keyName) { switch (keyName.toLowerCase()) { + case 'mandatory': return 0 case 'alpn' : return 1 case 'no-default-alpn' : return 2 case 'port' : return 3 @@ -1293,7 +1294,7 @@ svcparam.keyToNumber = function(keyName) { svcparam.numberToKeyName = function(number) { switch (number) { - case 0 : return '' + case 0 : return 'mandatory' case 1 : return 'alpn' case 2 : return 'no-default-alpn' case 3 : return 'port' @@ -1318,7 +1319,22 @@ svcparam.encode = function(param, buf, offset) { offset += 2; svcparam.encode.bytes = 2; - if (key == 1) { // alpn + if (key == 0) { // mandatory + let values = param.value; + if (!Array.isArray(values)) values = [values]; + buf.writeUInt16BE(values.length*2, offset); + offset += 2; + svcparam.encode.bytes += 2; + + for (let val of values) { + if (typeof val !== 'number') { + val = svcparam.keyToNumber(val); + } + buf.writeUInt16BE(val, offset); + offset += 2; + svcparam.encode.bytes += 2; + } + } else if (key == 1) { // alpn let len = param.value.length buf.writeUInt16BE(len || 0, offset); offset += 2; @@ -1371,7 +1387,7 @@ svcparam.encode = function(param, buf, offset) { } } else { // Unknown option - buf.writeUInt16BE(param.value || 0, offset); + buf.writeUInt16BE(0, offset); // 0 length since we don't know how to encode offset += 2; svcparam.encode.bytes += 2; } @@ -1404,10 +1420,11 @@ svcparam.encodingLength = function (param) { // 2 bytes for type, 2 bytes for length, what's left for the value switch (param.key) { + case 'mandatory' : return 4 + 2*(Array.isArray(param.value) ? param.value.length : 1) case 'alpn' : return 4 + param.value.length case 'no-default-alpn' : return 4 case 'port' : return 4 + 2 - case 'ipv4hint' : return 4+4 * (Array.isArray(param.value) ? param.value.length : 1) + case 'ipv4hint' : return 4 + 4 * (Array.isArray(param.value) ? param.value.length : 1) case 'echconfig' : return 4 + param.value.length case 'ipv6hint' : return 4 + 16 * (Array.isArray(param.value) ? param.value.length : 1) case 'key65535' : return 4